Feature Importance for FT Transformer


I hope someone can shed light on this. I was reading through the code at the link below, which extracts feature importance from attention scores.

https://github.com/aruberts/TabTransformerTF/blob/main/tabtransformertf/models/fttransformer.py

and I came across this snippet.

    # Prepare for Transformer
    transformer_inputs = tf.concat(transformer_inputs, axis=1)
    importances = []
    
    # Pass through Transformer blocks
    for transformer in self.transformers:
        if self.explainable:
            transformer_inputs, att_weights = transformer(transformer_inputs)
            importances.append(tf.reduce_sum(att_weights[:, :, 0, :], axis=1))
        else:
            transformer_inputs = transformer(transformer_inputs)

After cross-checking with ChatGPT, this is what it replied:

"The third dimension in att_weights[:, :, 0, :] is set to 0 to select the attention weights of the first attention head."

Basically, what I want to ask is whether this makes sense, since it appends only the first attention head, which may not capture the full information in the attention scores or weights. Shouldn't we instead take the scores from all the attention heads and compute their mean? Is this reasoning correct?
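To make the question concrete, here is a small sanity check of what that slice actually selects. It assumes the score layout that Keras' `MultiHeadAttention` returns with `return_attention_scores=True`, namely `(batch, num_heads, query_len, key_len)`; the tensor values are random and purely illustrative:

```python
import numpy as np

# Toy attention weights in the layout Keras' MultiHeadAttention returns
# when return_attention_scores=True: (batch, num_heads, query_len, key_len).
rng = np.random.default_rng(0)
batch, num_heads, seq_len = 2, 8, 5
att = rng.random((batch, num_heads, seq_len, seq_len))
att = att / att.sum(axis=-1, keepdims=True)  # each query row sums to 1

# att[:, :, 0, :] fixes axis 2 (the *query* position) at 0 and keeps
# ALL heads on axis 1; summing over axis=1 then aggregates the heads.
importances_sum = att[:, :, 0, :].sum(axis=1)    # shape (batch, key_len)

# A mean over heads differs from the sum only by the constant factor
# num_heads, so the per-sample feature ranking is unchanged.
importances_mean = att[:, :, 0, :].mean(axis=1)  # shape (batch, key_len)

assert importances_sum.shape == (batch, seq_len)
assert np.allclose(importances_sum, num_heads * importances_mean)
```

If this layout assumption holds, the `0` in the third axis picks out query position 0 rather than a head, and the sum-vs-mean choice only rescales the importances.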

This would seem like the straightforward approach, yet the author of the GitHub repo didn't take it, so I wonder what the reasons could be.
