Loading a model inside a TensorFlow while loop


I am trying to use @tf.function(jit_compile=True) to build a TF graph that contains a while loop; the pseudocode is below. I can't provide working code because it has a lot of dependencies.

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while()
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while()
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
  

Code 1 above makes memory usage explode because the model is loaded inside the while loop. When I load the model outside both while loops (Code 2), I get the error RuntimeError: Cannot get session inside Tensorflow graph function. What is the best way to prevent the memory explosion?
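
For reference, a minimal sketch of what "loading outside" can look like so that load_model never runs inside a graph: the model is loaded in ordinary eager Python before the decorated function is first traced, and the traced function only calls it. The tiny Dense model, the temporary .keras file and run_model are made up for this sketch; the real model, its path and the loops are not shown in the question.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for the real model: build and save a tiny model once.
path = os.path.join(tempfile.mkdtemp(), "tiny_model.keras")
tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(3)]).save(path)

# Load in eager Python, once, before any tf.function is traced.
model = tf.keras.models.load_model(path)


@tf.function(jit_compile=True)
def run_model(inputs):
    # Only calls the already-loaded model; no load_model() inside the graph.
    return model(inputs)


out = run_model(tf.ones((8, 3)))
print(out.shape)
```

The same loaded model object can then be called from inside any tf.while_loop body in that function.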

Edit 1: The inputs are tensors. The problem is that I need to pass a large batch at once. For this, I wrote a while loop, thinking that it would run in parallel (a Keras model() can only process 32 samples at once). I am not sure why a Keras model does not take the batch size as an input. In the code above, is it preferable to load the weights and other values directly and compute the outputs manually? In the case of Code 2, will each model call get its own graph because it is called inside the while loop?
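
On the batch-size point: the 32-sample limit most likely comes from Model.predict, whose batch_size argument defaults to 32 when not given. Calling the model directly, model(x), runs the whole tensor through in one forward pass, so a loop over chunks is not needed for that. A minimal sketch with a made-up Dense model standing in for the real one:

```python
import numpy as np
import tensorflow as tf

# Hypothetical model standing in for the real one.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(4)])

x = tf.random.uniform((1000, 3))

# Direct call: all 1000 rows go through in a single forward pass.
direct = model(x)

# predict(): the input is split into mini-batches; batch_size defaults to 32
# but can be set explicitly.
batched = model.predict(x, batch_size=256, verbose=0)

print(direct.shape, batched.shape)
print(np.allclose(direct.numpy(), batched, atol=1e-5))
```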

Edit 2: function3 computes the gradient of out3 with respect to inputs.
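
Since function3 is not shown, this is only a guess at its general shape: a sketch that takes the gradient of the model output with respect to the inputs using tf.GradientTape inside a jit_compile=True function. grad_wrt_inputs and the single-Dense model are hypothetical; a linear layer is used because the gradient of its summed output is just the kernel, which makes the result easy to check.

```python
import numpy as np
import tensorflow as tf

dense = tf.keras.layers.Dense(1)  # hypothetical single-layer model
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), dense])


@tf.function(jit_compile=True)
def grad_wrt_inputs(inputs):
    with tf.GradientTape() as tape:
        tape.watch(inputs)  # inputs is a plain tensor, so it must be watched
        out = model(inputs)
    # For a non-scalar target, the gradient of its sum is returned.
    return tape.gradient(out, inputs)


x = tf.random.uniform((5, 3))
g = grad_wrt_inputs(x)

# Linear layer: d(sum(x @ W + b)) / dx equals the kernel column for every row.
expected = np.tile(dense.kernel.numpy()[:, 0], (5, 1))
print(np.allclose(g.numpy(), expected))
```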

Answer by Djinn

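Two possible solutions:

According to the TensorFlow documentation, your problem may be caused by the tensors that are stored for backpropagation. Setting swap_memory=True on the while loop lets TensorFlow swap these tensors out from GPU to CPU memory.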

Try:

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while_loop(..., swap_memory=True)
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while_loop(..., swap_memory=True)
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
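
In actual API terms, swap_memory is an argument of tf.while_loop, which takes cond and body callables plus a tuple of loop variables. A minimal sketch, assuming a model whose output has the same shape as its input (so it can be fed back in), with tf.tanh standing in for the unknown function2; the model and run_loop are hypothetical. It also records how often body is traced: the body is traced into a graph when the function is traced, not once per iteration, which bears on whether each model call in the loop gets its own graph. Whether swap_memory=True actually lowers memory use under TF2 control flow or XLA is not something this sketch demonstrates.

```python
import numpy as np
import tensorflow as tf

# Hypothetical model: 3 features in, 3 features out, so outputs can be fed back.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(3)])
body_traces = []  # Python side effect: runs only while body is being traced


@tf.function(jit_compile=True)
def run_loop(inputs, num_steps):
    def cond(i, x):
        return i < num_steps

    def body(i, x):
        body_traces.append(1)
        return i + 1, tf.tanh(model(x))  # tf.tanh stands in for function2

    _, out = tf.while_loop(cond, body, (tf.constant(0), inputs), swap_memory=True)
    return out


x = tf.random.uniform((4, 3))
out = run_loop(x, tf.constant(10))
traces_after_first_call = len(body_traces)
run_loop(x, tf.constant(10))  # same shapes and dtypes: reuses the traced graph

# Eager reference: the same 10 steps as a plain Python loop.
ref = x
for _ in range(10):
    ref = tf.tanh(model(ref))

print(traces_after_first_call, len(body_traces))
print(np.allclose(out.numpy(), ref.numpy(), atol=1e-4))
```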

Or, at the end of your loops, try calling K.clear_session() to reset the Keras state.

from tensorflow.keras import backend as K

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while()
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while()
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
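
K.clear_session() (the same function as tf.keras.backend.clear_session()) resets Keras' global state, and the Keras documentation suggests it when many models are created in a loop. Note that plain Python calls inside a tf.function run only while the function is being traced. A hedged sketch of the documented pattern, calling it from ordinary eager Python between iterations; build_and_run and its model are hypothetical stand-ins for one load_model() call per iteration.

```python
import tensorflow as tf


def build_and_run(x):
    # Hypothetical: a fresh model per iteration (e.g. one load_model() call each).
    model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(2)])
    return model(x)


x = tf.ones((4, 3))
results = []
for step in range(3):
    results.append(build_and_run(x).numpy())
    # Reset Keras global state before the next model is built.
    tf.keras.backend.clear_session()

print([r.shape for r in results])
```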