I am sorry to post this without a toy reproducible example. I am slowly adding @tf.function decorators to my code, and I came across a new error that I don't even know how to begin attacking.
Does this smell like anything you have seen before? Any pointers in the right direction would be very useful.
Error message
This popped up when calling tape.gradient. 16 is the expected size of the input to softmax (it is the value of n_to_keep if you are following the code below).
ERROR Incompatible shapes: [16] vs. [2087866608]
[[{{node gradients/Softmax_grad/mul}}]]
tf2xla conversion failed while converting __inference___backward_get_broadcast_inputs_146142_146192[_XlaMustCompile=true,config_proto=3175580994766145631,executor_type=11160318154034397263]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions. [Op:__inference___backward_get_broadcast_inputs_146142_146192]
Problem code (I think)
Unfortunately I can't share my entire script, but this is the only function that uses softmax. Please pardon the code quality.
import tensorflow as tf

def extract(foo, sorted_I, n_to_keep, num_partitions):
    # split by segment into a list called bar
    bar = tf.dynamic_partition(
        foo,
        sorted_I,
        num_partitions
    )
    # unsplit after extracting the best few elements of each segment
    best = tf.concat(
        [A[:n_to_keep] for A in bar],
        axis=-1
    )
    return best, bar
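For concreteness, here is a toy eager-mode run of extract (these values are made up for this post, not from my real pipeline):

# hypothetical toy inputs: 6 elements spread across 2 partitions
foo = tf.constant([10, 11, 12, 13, 14, 15])
sorted_I = tf.constant([0, 1, 0, 1, 0, 1])  # partition id per element

best, bar = extract(foo, sorted_I, n_to_keep=2, num_partitions=2)
# bar  -> [ [10, 12, 14], [11, 13, 15] ]
# best -> [10, 12, 11, 13]  (first 2 of each partition, re-concatenated)

So extract splits foo by partition id, keeps the first n_to_keep entries of each partition, and concatenates them back together. In eager mode this behaves exactly as I expect.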
@tf.function(jit_compile=True)
def get_broadcast_inputs(ci_edge_heats, ci_edges_I, n_to_keep, num_partitions):
    ci_edge_heats = tf.reshape(ci_edge_heats, [-1])
    sorted_indices = tf.argsort(ci_edge_heats, axis=0, direction="DESCENDING")
    sorted_I = tf.gather(ci_edges_I, sorted_indices)

    node_ids = tf.range(tf.shape(ci_edge_heats)[0])
    sorted_node_ids = tf.gather(node_ids, sorted_indices)

    best_indices, _ = extract(
        sorted_node_ids,
        sorted_I,
        n_to_keep,
        num_partitions
    )
    best_I, _ = extract(
        sorted_I,
        sorted_I,
        n_to_keep,
        num_partitions
    )

    sorted_heats = tf.gather(ci_edge_heats, sorted_indices)
    partitioned_heats = tf.dynamic_partition(
        sorted_heats,
        sorted_I,
        num_partitions
    )
    # softmax over the kept slice of each partition, then re-concatenate
    weights = tf.concat(
        [tf.nn.softmax(A[:n_to_keep]) for A in partitioned_heats],
        axis=-1
    )
    return best_indices, weights, best_I
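The call site that blows up looks roughly like this. The tensors and loss below are stand-ins I made up for this post, not my real training loop, but the shapes mirror my setup: 4 partitions of 16 elements each, so each softmax input has size 16, matching the [16] in the error.

heats = tf.Variable(tf.random.normal([64]))        # stand-in for ci_edge_heats
edges_I = tf.constant([i % 4 for i in range(64)])  # stand-in partition ids

with tf.GradientTape() as tape:
    best_indices, weights, best_I = get_broadcast_inputs(heats, edges_I, 16, 4)
    loss = tf.reduce_sum(weights)  # placeholder loss
grads = tape.gradient(loss, [heats])  # <- this is where the error fires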
What did you try and what were you expecting?
I am not fluent enough to try dumping the compiled functions as the error message suggests. I'd love advice on how to generate the dumps and how to sift through them to find clues to this bug.
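For what it's worth, my best guess at enabling the dump is something like the following. TF_DUMP_GRAPH_PREFIX comes straight from the error message; TF_CPP_VMODULE is my guess at how to pass the --vmodule flag, and I haven't verified it.

import os
# these need to be set before TensorFlow initializes
os.environ["TF_DUMP_GRAPH_PREFIX"] = "/tmp/tf_xla_dump"  # hypothetical dump dir
os.environ["TF_CPP_VMODULE"] = "xla_compiler=2"          # guess at --vmodule=xla_compiler=2
import tensorflow as tf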