Suppose I have the following model, built from this synthetic data:
import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X, y = make_classification(n_samples=1000,
                           n_features=50,
                           n_informative=9,
                           n_redundant=0,
                           n_repeated=0,
                           n_classes=10,
                           n_clusters_per_class=1,
                           class_sep=9,
                           flip_y=0.2,
                           random_state=17)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
Then I calculate the SHAP values of the features:
explainer = shap.Explainer(model)
shap_values = explainer.shap_values(X_test)
type(shap_values)
list
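The steps below assume this list holds one (n_samples, n_features) array per class. My understanding is that some newer SHAP releases may instead return a single array with the class axis last, i.e. (n_samples, n_features, n_classes), so it is worth checking the layout first. The following is a minimal sketch of a hypothetical helper, `to_class_first`, that normalizes either layout to (n_classes, n_samples, n_features); it is exercised on random dummy arrays, not real SHAP output.

```python
import numpy as np

def to_class_first(shap_values):
    """Hypothetical helper: return SHAP values shaped (n_classes, n_samples, n_features).

    Assumption: a list means one (n_samples, n_features) array per class, and a
    3-D ndarray means the newer layout with the class axis last.
    """
    if isinstance(shap_values, list):
        # One array per class: stack them along a new leading class axis.
        return np.stack(shap_values, axis=0)
    arr = np.asarray(shap_values)
    if arr.ndim == 3:
        # Assumed layout (n_samples, n_features, n_classes): move classes to the front.
        return np.moveaxis(arr, -1, 0)
    raise ValueError(f"expected a list or a 3-D array, got shape {arr.shape}")

# Dummy stand-ins with the question's sizes: 10 classes, 250 records, 50 features.
rng = np.random.default_rng(0)
as_list = [rng.normal(size=(250, 50)) for _ in range(10)]
as_array = np.stack(as_list, axis=-1)  # shape (250, 50, 10)

print(to_class_first(as_list).shape)   # (10, 250, 50)
print(to_class_first(as_array).shape)  # (10, 250, 50)
```

Both inputs end up with the same values in the same class-first order, so the axis numbers used in the rest of this post apply to either.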
To calculate the mean absolute SHAP values for each class separately, I do:
abs_sv = np.abs(shap_values)
avg_feature_importance_per_class = np.mean(abs_sv, axis=1)
avg_feature_importance_per_class.shape
(10, 50)
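To make the axis=1 reduction concrete, here is a toy sketch where a hypothetical array of shape (2, 3, 4) stands in for np.abs(shap_values), laid out as (classes, records, features). Averaging over axis 1 (the records) leaves one importance vector per class, and sorting each row shows that different classes can rank the features differently.

```python
import numpy as np

# Hypothetical tiny example: 2 classes, 3 records, 4 features,
# laid out as (n_classes, n_samples, n_features) like np.abs(shap_values).
sv = np.arange(24, dtype=float).reshape(2, 3, 4) - 12.0

abs_sv = np.abs(sv)
# axis=1 is the records axis: averaging over it leaves one row per class.
per_class = abs_sv.mean(axis=1)
print(per_class)         # [[8. 7. 6. 5.]
                         #  [4. 5. 6. 7.]]

# Feature indices sorted from most to least important, per class.
ranking = np.argsort(-per_class, axis=1)
print(ranking)           # [[0 1 2 3]
                         #  [3 2 1 0]]
```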
Question
Now, how do I calculate the mean of the absolute SHAP values across all classes, which I can treat as the model's overall feature importance (derived from SHAP values)?
I did it like this:
feature_importance_overall = np.mean(abs_sv, axis=0)
But then I got confused. Am I really doing this right? Especially when I look at the shape:
feature_importance_overall.shape
(250, 50)
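The (250, 50) shape follows from which axis np.mean collapses. Here is a quick check on a random dummy array with the same (classes, records, features) layout; it is not real SHAP output, just an illustration of the axis semantics.

```python
import numpy as np

# Dummy array with the question's layout: (n_classes, n_samples, n_features) = (10, 250, 50).
rng = np.random.default_rng(17)
abs_sv = np.abs(rng.normal(size=(10, 250, 50)))

# axis=0 only collapses the class axis, so one row per test record survives.
print(abs_sv.mean(axis=0).shape)   # (250, 50)

# Collapsing both the class axis and the record axis leaves one value per feature.
overall = abs_sv.mean(axis=(0, 1))
print(overall.shape)               # (50,)

# Every class is averaged over the same 250 records, so this matches
# averaging the per-class importances from axis=1.
per_class = abs_sv.mean(axis=1)    # (10, 50)
print(np.allclose(overall, per_class.mean(axis=0)))  # True
```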
I was expecting something with the shape (number_of_features,), similar to what I get from:
model.feature_importances_.shape
(50,)
avg_feature_importance_per_class.shape also has one entry per feature, but with an extra leading dimension for the number of classes (i.e. (10, 50)), since it is computed for each class separately.
To understand how to perform the calculation you mentioned, let's take a look at the shape of shap_values once it is converted to an array (for example, abs_sv.shape), which is (10, 250, 50). This numpy array has three dimensions, which represent:
10: number of classes
250: number of records in your test data
50: number of features
Thus, to get the desired result, the "mean absolute SHAP values across all classes", you need to average this array across the 10 classes (the first dimension, at index 0) as well as across all 250 records (the second dimension, at index 1), which you can do with the following operation: