Channel importance in sequence classification


I have an ONNX model that takes an input of shape [1, 35, 4], i.e. [batch_size, num_channels, seq_len], and outputs [1, 3], i.e. [batch_size, num_classes]. I need to assign a number to each of the 35 channels indicating how important it is to the model's predictions. I thought I would use shap's PermutationExplainer for this, but I'm not sure how to tell it that I don't care about the seq_len dimension.

I tried doing this:

import onnxruntime as ort
import shap
import numpy as np

# Load the ONNX model
model_path = 'explain-example.onnx'
sess = ort.InferenceSession(model_path)

# Define the model function to handle batching
def model(x):
    x = x.reshape(-1, 35, 4).astype(np.float32)
    # Model expects input [1, 35, 4] and returns [1, 3]
    outputs = [sess.run(None, {'input': x[i:i+1]})[0] for i in range(x.shape[0])]
    return np.concatenate(outputs, axis=0)

# Create sample input
X = np.random.rand(1000, 35, 4).astype(np.float32)

output_names = [f"output_{i}" for i in range(3)]
feature_names = [f"Channel {i}" for i in range(35)]

def masker_fn(mask, x):
    # problem: mask shape is (140,), but x shape is (35, 4), so there is
    # one mask entry per (channel, seq_len) position instead of one per channel
    masked_x = x.copy()
    for i in range(x.shape[0]):  # iterate over the 35 channels
        if mask[i] == 0:
            masked_x[i, :] = 0  # zero out the whole channel
    return masked_x

# Declare the explainer using the custom masker for channels
explainer = shap.PermutationExplainer(model, masker_fn, feature_names=feature_names, output_names=output_names)

# Compute shap values
shap_values = explainer(X)

but the problem is that the masks passed to the masker have shape (140,) instead of (35,). I could of course merge the masks across the seq_len dimension somehow, but I suspect that 1. the explainer is doing unnecessary work by permuting 140 features instead of 35, and 2. this might somehow break its inner algorithm.
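Just to make point 1 concrete, the kind of merging I have in mind is something like this (assuming the 140 mask entries are flattened channel-major, i.e. 4 consecutive entries per channel, and treating a channel as kept only when all 4 of its positions are kept; both are guesses on my part):

def masker_fn(mask, x):
    # Regroup the flat (140,) mask into (35, 4): one row per channel.
    channel_mask = np.asarray(mask).astype(bool).reshape(35, 4).all(axis=1)
    masked_x = x.copy()
    masked_x[~channel_mask, :] = 0  # zero out fully-masked channels
    return masked_x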

How can I properly tell it that no, there aren't 140 features, but 35?
