I am trying to use XLA to further speed up the training of my model in TF 2.10. However, my input data shape varies, i.e. batch.shape = TensorShape([X, 4]) with X differing between batches.
Without XLA, i.e. decorating the Python update-step function with @tf.function(jit_compile=False) and obtaining a concrete function for an input signature with shape=[None, 4], retracing is avoided for every new X. With jit_compile=True, no retracing occurs for calls with new shapes either, but the first update_step call for each new X takes a very long time.
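For illustration, here is a minimal, self-contained sketch of that behaviour (the toy xla_step function and the example shapes are placeholders, not my real update step): with an input signature of (None, 4) there is a single trace, yet the first call for each new leading dimension is slow.

import time
import tensorflow as tf

trace_count = 0

@tf.function(
    jit_compile=True,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
)
def xla_step(batch):
    global trace_count
    trace_count += 1                      # executes only while tracing
    return tf.reduce_sum(batch ** 2)

for x in (8, 16, 32):
    batch = tf.random.normal((x, 4))
    t0 = time.perf_counter()
    xla_step(batch)                       # first call for this X: slow (fresh XLA compilation)
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    xla_step(batch)                       # same X again: fast (cached executable)
    second = time.perf_counter() - t0
    print(f"X={x}  traces={trace_count}  first={first:.3f}s  second={second:.3f}s")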
The question: is there any way to avoid the long XLA compilation times for the tf.Graph when newly encountered shapes arrive?
Code:
# the update function
@tf.function(jit_compile=IS_XLA)  # note: IS_XLA must be defined before this decorator is evaluated
def update_step(model, optim, batch):
    ...
    return loss
# Training function
def train_model(model, optim, all_batches):
    # trace once for a dynamic leading dimension, so new X values do not retrace
    concrete_update_step = update_step.get_concrete_function(
        model=model, optim=optim,
        batch=tf.TensorSpec(shape=(None, 4), dtype=tf.float32))
    for batch in all_batches:
        loss = concrete_update_step(batch)
    return None
if __name__ == '__main__':
    # run without XLA
    IS_XLA = False
    for epoch in range(N_epochs):
        train_model(model, optim, all_batches)
    ### Tracing occurs only for the call on the first batch

    # run with XLA
    IS_XLA = True
    for epoch in range(N_epochs):
        train_model(model, optim, all_batches)
    ### Tracing again occurs only for the call on the first batch, but for every new batch.shape[0] the first call to concrete_update_step with that shape takes a huge amount of time
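To see where the time goes, the XLA program can be dumped with the (experimental) experimental_get_compiler_ir API of tf.function; the toy function f below is only a stand-in for update_step. The emitted HLO is specialised to the concrete input shape (f32[8,4] vs f32[16,4]), i.e. a separate XLA executable is built per leading dimension:

import tensorflow as tf

@tf.function(jit_compile=True)
def f(batch):
    return tf.reduce_sum(batch ** 2)

# the HLO text contains the concrete input shape, so each new leading
# dimension corresponds to a separately compiled program
print(f.experimental_get_compiler_ir(tf.random.normal((8, 4)))(stage="hlo"))
print(f.experimental_get_compiler_ir(tf.random.normal((16, 4)))(stage="hlo"))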