Is it possible to use XLA in Tensorflow with variable input shape?


I am trying to use XLA to further speed up training of my model in TF 2.10. However, my input data shape varies between batches, i.e. batch.shape = TensorShape([X, 4]) with X differing from batch to batch.

Without XLA, i.e. decorating the Python update-step function with @tf.function(jit_compile=False) and obtaining a concrete function for an input signature with shape = [None, 4], retracing is avoided for every new X. With jit_compile=True, no retracing occurs for calls with new shapes either, but the first update_step call on each new X takes a very long time, since XLA compiles a separate program for each concrete shape it encounters.
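For comparison, the same relaxed signature can also be pinned directly on the decorator via input_signature. A minimal sketch, showing only the tensor argument (input_signature requires every argument to be a tensor or TensorSpec, so the model and optimizer arguments from the code below are omitted, and the body is a stand-in):

import tensorflow as tf

# One trace covers every X in batch.shape = [X, 4]. With
# jit_compile=True, XLA would still compile once per concrete shape.
@tf.function(
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
    jit_compile=False,
)
def update_step(batch):
    return tf.reduce_sum(batch)  # stand-in for the real loss computation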

The question is whether there is any way to avoid these long XLA compilation times on a tf.Graph when new shapes are encountered.

Code:

import tensorflow as tf

# The update function. Note: the decorator reads IS_XLA when
# update_step is defined, so the flag must be set beforehand.
@tf.function(jit_compile=IS_XLA)
def update_step(model, optim, batch):
    ...
    return loss

# Training function
def train_model(model, optim, all_batches):
    # Non-tensor arguments (model, optim) are baked into the concrete
    # function; only the batch tensor is passed at call time.
    concrete_update_step = update_step.get_concrete_function(
        model=model, optim=optim,
        batch=tf.TensorSpec(shape=(None, 4), dtype=tf.float32))

    for batch in all_batches:
        loss = concrete_update_step(batch)

    return None



if __name__ == '__main__':

    # Run without XLA
    IS_XLA = False
    for epoch in range(N_epochs):
        train_model(model, optim, all_batches)

    ### Tracing occurs only for the call on the first batch.

    # Run with XLA
    IS_XLA = True
    for epoch in range(N_epochs):
        train_model(model, optim, all_batches)

    ### Tracing again occurs only for the call on the first batch, but
    ### for every new batch.shape[0] the first concrete_update_step
    ### call takes a huge amount of time.
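For reference, a commonly suggested way to bound the number of XLA compilations with variable-length inputs is bucketing: pad each batch up to the nearest of a few fixed sizes, so the compiler only ever sees that many distinct shapes. A minimal sketch, not from the original post, assuming X never exceeds the largest bucket and that the padded rows can be masked out of the loss inside update_step:

import tensorflow as tf

BUCKETS = (128, 256, 512, 1024)  # assumption: X never exceeds 1024

def pad_to_bucket(batch):
    # Runs eagerly, outside the compiled function: round the leading
    # dimension up to the nearest bucket and zero-pad, so the XLA
    # program is compiled at most len(BUCKETS) times instead of once
    # per distinct X.
    n = batch.shape[0]
    target = next(b for b in BUCKETS if b >= n)
    padded = tf.pad(batch, [[0, target - n], [0, 0]])
    return padded, n  # n is needed to mask padded rows in the loss

# Usage sketch: the padded rows must be excluded from the loss, e.g.
# via a mask built from the returned true length.
# padded, true_len = pad_to_bucket(batch)
# loss = concrete_update_step(padded)

The trade-off is wasted compute on the padded rows; whether that beats per-shape compilation depends on how widely X varies.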