Handling Tuple Values in Word Attributions for Transformers Interpret


I am working on interpreting word attributions using the transformers_interpret library for a fine-tuned model in a natural language processing task. I want to print all the attribution scores and their visualizations for instances where the cumulative attribution score for a text surpasses a specified threshold.
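
For context, cls_explainer in the code below is a transformers_interpret SequenceClassificationExplainer. A minimal setup sketch (the checkpoint name is only a placeholder for my fine-tuned model):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Placeholder checkpoint; in practice this is the fine-tuned model
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)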

Here's a simplified version of my code:

# Set the desired attribution score threshold
attribution_threshold = 2.50

# Set the desired maximum sequence length
max_sequence_length = 512

for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT'] 
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Check if the overall attribution score is above the threshold
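    # NOTE: this is where it breaks - each item in word_attributions is a
    # (word, score) tuple, so abs(attribution) raises a TypeError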
    overall_score = sum(abs(attribution) for attribution in word_attributions)
    if overall_score > attribution_threshold:
        cls_explainer.visualize()

However, I've run into an issue: word_attributions contains (word, score) tuples rather than a plain score per word, so the sum above fails with a TypeError. My goal is to modify this code so that it handles the tuple values properly and still performs the check on individual attribution scores.

How can I adapt the code to print all the attribution scores and their visualizations for instances where any individual attribution score within each tuple surpasses the specified threshold?
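
For reference, word_attributions comes back from cls_explainer as a list of (token, score) tuples, roughly like this (tokens and values are illustrative, not real output):

word_attributions = [
    ('[CLS]', 0.00),
    ('the', 0.12),
    ('movie', -0.34),
    ('was', 0.05),
    ('great', 0.87),
    ('[SEP]', 0.00),
]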


1 Answer

Answered by DavidS:

I found a solution. The problem is that the documentation is a bit unclear on this point.

Here is the code that solves it:

# Set the desired maximum sequence length
max_sequence_length = 512
attribution_threshold = 1.8  # Adjust the threshold as needed

# List to store word attributions for each row
all_word_attributions = []

# Loop through all rows in the dataset
for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT']  # Assuming 'TEXT' is the correct column name
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Sum of attributions for the entire text
    total_attribution = cls_explainer.attributions.attributions_sum.sum()

    # Print the sum of attributions
    print(f'Total Attribution for Row {row_index}: {total_attribution}')

    # Store the word attributions (list of (word, score) tuples) for this row
    all_word_attributions.append(word_attributions)

    # Visualize if the total attribution is above the threshold
    if abs(total_attribution) > attribution_threshold:
        cls_explainer.visualize()

As you can see, the line

total_attribution = cls_explainer.attributions.attributions_sum.sum()

is the part that does the job: attributions_sum holds one attribution score per token, and calling .sum() on it collapses them into a single total for the text.
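
If you want to check individual scores instead of the total, the (word, score) tuples returned by cls_explainer can be unpacked directly. A minimal sketch (the per-token threshold and filtering are just an example, not part of the library):

# word_attributions is a list of (token, score) tuples,
# so per-token scores can be checked and filtered directly
token_threshold = 0.5  # example per-token threshold
high_impact_tokens = [
    (token, score)
    for token, score in word_attributions
    if abs(score) > token_threshold
]
print(high_impact_tokens)

# The total can also be computed straight from the tuples,
# without going through cls_explainer.attributions:
total_attribution = sum(score for _, score in word_attributions)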

Hope this is useful.