How to reuse an operation in TensorFlow?


Keras layers can be reused, i.e. if I have l = keras.layers.Dense(5) I can apply it multiple times to different tensors, like t1 = l(t1); t2 = l(t2).
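For illustration, a minimal sketch of the Keras behavior described above: calling the same layer object on two different tensors builds the weights once and reuses them on both calls.

```python
import tensorflow as tf

# One Dense layer object owns one set of weights.
l = tf.keras.layers.Dense(5)

t1 = tf.ones((2, 3))
t2 = tf.ones((4, 3))

o1 = l(t1)  # first call creates the kernel (3, 5) and bias (5,)
o2 = l(t2)  # second call reuses those same variables

# Only one kernel and one bias exist, despite two calls.
print(len(l.weights))  # kernel + bias
```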

Is there anything similar in TensorFlow without using Keras?

Why do I need it: I am in non-eager (graph) mode and want to create a static .pb graph file. Suppose I have a function f(t) that is huge and long, and it performs transformations on a tensor t. Inside the graph it creates a large sub-graph of operations with tensors flowing along its paths. Now I want to reuse it, meaning I don't want to call it for every input t, because each call would form a new sub-graph, just a duplicate with different inputs. I want to somehow reuse the same sub-graph and direct different tensors into it as inputs. Reuse would also avoid calling the huge function to build the same structure for every possible input tensor, which is slow.

Another important reason for reusing the same operation is that the same weights and heavy parameters can then serve many calls of the operation on many inputs. It is often important, and sometimes required, that the weights are identical for all inputs in order to train the neural network correctly.
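As a sketch of the weight-sharing requirement in graph mode (assuming TensorFlow 1.x-style APIs via tf.compat.v1, with f being a hypothetical stand-in for the huge function): tf.variable_scope with reuse=AUTO_REUSE makes repeated calls fetch the same variables instead of creating new ones. Note this shares the *weights*; each call still adds its own ops to the graph.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

def f(t):
    # Hypothetical transformation; AUTO_REUSE means the second and
    # later calls reuse the variable created by the first call.
    with tf1.variable_scope("f", reuse=tf1.AUTO_REUSE):
        w = tf1.get_variable("w", shape=[3, 5])
        return tf.matmul(t, w)

x1 = tf1.placeholder(tf.float32, [None, 3])
x2 = tf1.placeholder(tf.float32, [None, 3])
y1 = f(x1)
y2 = f(x2)

# Only one weight variable exists even though f was called twice.
print(len(tf1.global_variables()))
```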

The real reason for reuse is not only to save the space occupied by the graph, but also that the number of calls to f(t) may vary depending on the input. Suppose we have a keras.layers.Input(...) placeholder as input. Its 0-th (batch) dimension is always None (unknown) at graph-construction time; the real value is only known when actual data is fed through sess.run(...). When data is fed, I want to perform as many transformations (calls to f(t)) as the size of the batch dimension, in other words to call f(t) for every sub-tensor in the batch. E.g., for a batch of images I want to call f(t) for every single image in the batch. Hence there will be a different number of calls to f(t) for different batch sizes. How do I achieve this? Could it be achieved through tf.while_loop, and if so, how do I use a while loop in my case?
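A minimal sketch of the per-batch-element scenario above, again assuming tf.compat.v1 graph mode and a hypothetical f: tf.map_fn builds f's sub-graph once and iterates over the 0-th dimension at run time (it is implemented on top of tf.while_loop), so the number of applications adapts to whatever batch size is fed in.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

def f(t):
    # Hypothetical per-example transformation standing in for the
    # huge function from the question.
    return t * 2.0

# Batch (0-th) dimension is None at graph-construction time.
images = tf1.placeholder(tf.float32, [None, 4, 4])

# One copy of f's sub-graph, looped over the batch at run time.
out = tf.map_fn(f, images)

with tf1.Session() as sess:
    r = sess.run(out, feed_dict={images: np.ones((3, 4, 4), np.float32)})
    print(r.shape)  # matches whatever batch size was fed
```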
