Reshaping a SHAP shap._explanation.Explanation element

I am applying SHAP to a multivariate time series classification problem, so my instances are intervals, each containing w timepoints. As a workaround, I am using this command:

explanation = explainer(X_test[sample_idx, :, :].reshape(sample_size, v * w))

where sample_size is the number of intervals, v is the number of features and w is the number of timepoints per interval. This gives me the local explanation of every feature at every timepoint of an interval, as a shap._explanation.Explanation object that looks like this:

.values =
array([[[-0.06085803,  0.06085803],
        [ 0.01253338, -0.01253338],
        [ 0.        ,  0.        ],
        ...,
        [ 0.04145166, -0.04145166],
        [-0.03110505,  0.03110505],
        [ 0.01649869, -0.0164987 ]],

       [[-0.00939477,  0.00939477],
        [ 0.00384894, -0.00384894],
        [ 0.        ,  0.        ],
        ...,
        [ 0.02929746, -0.02929746],
        [ 0.02393174, -0.02393174],
        [-0.01241742,  0.01241742]],

       [[ 0.00730818, -0.00730819],
        [ 0.01542984, -0.01542984],
        [-0.03578429,  0.03578429],
        ...,
        [ 0.03123155, -0.03123155],
        [-0.0098867 ,  0.0098867 ],
        [ 0.        ,  0.        ]],

       ...,

       [[ 0.        ,  0.        ],
        [ 0.05053705, -0.05053705],
        [ 0.0198657 , -0.0198657 ],
        ...,
        [ 0.00958598, -0.00958598],
        [-0.05249643,  0.05249643],
        [-0.02039118,  0.02039118]],

       [[-0.03896002,  0.03896002],
        [ 0.01164171, -0.01164171],
        [ 0.01400289, -0.01400289],
        ...,
        [-0.02907301,  0.02907301],
        [ 0.        ,  0.        ],
        [ 0.03117007, -0.03117007]],

       [[ 0.        ,  0.        ],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ],
        ...,
        [-0.01277748,  0.01277748],
        [ 0.02368447, -0.02368447],
        [ 0.01660404, -0.01660404]]])

.base_values =
array([[0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085],
       [0.50142914, 0.49857085]])

.data =
array([[ 0.03,  0.05,  0.08, ..., -0.6 , -0.6 , -0.6 ],
       [ 0.  , -0.01, -0.06, ..., -0.6 , -0.6 , -0.6 ],
       [-0.03, -0.01, -0.06, ..., -0.6 , -0.6 , -0.6 ],
       ...,
       [-0.07, -0.02, -0.01, ..., -0.4 , -0.4 , -0.4 ],
       [-0.08, -0.08, -0.06, ..., -0.4 , -0.4 , -0.4 ],
       [ 0.04,  0.04, -0.01, ..., -0.4 , -0.4 , -0.4 ]])
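The printed .values has shape (sample_size, v * w, 2): one row per interval, one column per (feature, timepoint) pair and one slot per class, while .data has shape (sample_size, v * w). Which column belongs to which feature and timepoint depends on the axis order of X_test, which is not stated above. A minimal sketch, assuming X_test is laid out as (n_intervals, w, v) and using hypothetical sizes, showing that NumPy's row-major reshape puts feature f of timepoint t in column t * v + f:

```python
import numpy as np

# Hypothetical sizes: 3 intervals, w = 4 timepoints, v = 2 features.
n_intervals, w, v = 3, 4, 2
# Assumed layout: (intervals, timepoints, features).
X_test = np.arange(n_intervals * w * v, dtype=float).reshape(n_intervals, w, v)

sample_idx = np.array([0, 2])  # hypothetical selection of intervals
sample_size = len(sample_idx)

# Same flattening as in the question: one row per interval, v * w columns.
X_flat = X_test[sample_idx, :, :].reshape(sample_size, v * w)


def column_of(t, f, v):
    """Column index of feature f at timepoint t after row-major flattening."""
    return t * v + f


# Every entry lands in the column that column_of predicts.
for i in range(sample_size):
    for t in range(w):
        for f in range(v):
            assert X_flat[i, column_of(t, f, v)] == X_test[sample_idx[i], t, f]
```

If X_test is instead laid out as (n_intervals, v, w), the same reshape puts feature f of timepoint t in column f * w + t.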

However, since I would like to use SHAP's built-in plots to visualize the SHAP values, I need to convert this shap._explanation.Explanation object into a flattened structure like this:

.values =
array([[ 5.95425583e-01, -4.58633694e-01,  3.87175008e-01, ...,
        -8.70954162e-02, -1.99409482e-01,  3.42254378e-02],
       [ 1.04104975e+00, -3.48955375e-01,  5.63579733e-01, ...,
        -9.94380733e-02, -1.62471390e+00,  4.11868767e-02],
       [ 3.57663240e-01, -2.74999500e-04, -4.93422823e-01, ...,
        -8.80035447e-02, -1.44724054e-01,  2.05118780e-02],
       ...,
       [ 1.18296099e+00,  3.35005086e-04, -3.73861502e-01, ...,
        -8.02874353e-02, -1.59453756e-01,  5.42146152e-02],
       [-1.91623858e+00, -2.50196535e-02, -4.12581469e-01, ...,
        -9.75276305e-02, -1.58804272e+00, -2.38128416e-03],
       [ 8.52208949e-01,  1.19914864e-01, -2.49582733e-01, ...,
        -3.12689449e-02,  9.36303758e-02,  3.73075812e-02]])

.base_values =
array([-2.48532888, -2.48532888, -2.48532888, ..., -2.48532888,
       -2.48532888, -2.48532888])

.data =
array([[39.,  7., 13., ...,  0., 40., 39.],
       [50.,  6., 13., ...,  0., 13., 39.],
       [38.,  4.,  9., ...,  0., 40., 39.],
       ...,
       [58.,  4.,  9., ...,  0., 40., 39.],
       [22.,  4.,  9., ...,  0., 20., 39.],
       [52.,  5.,  9., ...,  0., 40., 39.]])

where the interval dimension is removed by treating each timepoint as an instance.
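A minimal sketch of that transformation with plain NumPy (the function name intervals_to_timepoints and its class_idx and time_major parameters are hypothetical). Because the axis order of X_test is not stated, it supports both (n, w, v) and (n, v, w). It keeps a single class, since the target structure has one output per row; for a binary classifier like the one printed above, the two class columns are mirror images of each other, so either class carries the same information. Base values are repeated so that every timepoint row gets the base value of its interval:

```python
import numpy as np


def intervals_to_timepoints(values, base_values, data, w, v,
                            class_idx=1, time_major=True):
    """Turn interval-level SHAP arrays into one row per timepoint.

    Expected inputs (from explainer(X.reshape(n, v * w))):
      values:      (n, v * w, n_classes)
      base_values: (n, n_classes)
      data:        (n, v * w)
    time_major=True assumes X was laid out as (n, w, v), so column j = t * v + f;
    time_major=False assumes (n, v, w), so column j = f * w + t.
    Returns values (n * w, v), base_values (n * w,), data (n * w, v),
    where row i * w + t is timepoint t of interval i.
    """
    values = np.asarray(values)
    base_values = np.asarray(base_values)
    data = np.asarray(data)
    n = data.shape[0]

    # Keep a single output class, as the target structure has one output per row.
    vals = values[:, :, class_idx]

    if time_major:
        # Columns are already grouped by timepoint: split them into w blocks of v.
        vals = vals.reshape(n * w, v)
        dat = data.reshape(n * w, v)
    else:
        # Columns are grouped by feature: move the time axis in front first.
        vals = vals.reshape(n, v, w).transpose(0, 2, 1).reshape(n * w, v)
        dat = data.reshape(n, v, w).transpose(0, 2, 1).reshape(n * w, v)

    # Every timepoint inherits the base value of the interval it came from.
    base = np.repeat(base_values[:, class_idx], w)
    return vals, base, dat
```

With the object from the question this would be called as `vals, base, dat = intervals_to_timepoints(explanation.values, explanation.base_values, explanation.data, w, v)`. One caveat: the v * w SHAP values of an interval plus its base value add up (up to approximation error) to the model output for that interval, but after the split a single row no longer does. A beeswarm plot, which shows the distribution of values per feature, is unaffected, whereas single-row plots such as waterfall plots would likely be misleading.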

My first idea was to change reshape(sample_size, v * w) in

explanation = explainer(X_test[sample_idx, :, :].reshape(sample_size, v * w))

into reshape(sample_size * w, v) to remove the interval dimension, but the model was trained on intervals and cannot handle this different input shape. So I need to compute the SHAP values per interval and then transform the resulting arrays so that only timepoints are left as instances. Does anyone have an idea how to do that? I have already done this by hand, which gives me three separate arrays for .data, .values and .base_values. However, to apply a SHAP plot like this

shap.plots.beeswarm(shap_values)

I need to have a shap._explanation.Explanation object.
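shap.Explanation can be constructed directly from NumPy arrays, so the three hand-built arrays can be wrapped in one object and passed to shap.plots.beeswarm. A minimal sketch: make_explanation is a hypothetical helper; it checks that values and data have the flattened shape (n_rows, v) and base_values has shape (n_rows,), and only then imports shap, so the checks do not depend on shap being installed.

```python
import numpy as np


def make_explanation(values, base_values, data, feature_names=None):
    """Wrap per-timepoint arrays in a shap.Explanation after checking shapes.

    values: (n_rows, v), base_values: (n_rows,), data: (n_rows, v)
    """
    values = np.asarray(values)
    base_values = np.asarray(base_values)
    data = np.asarray(data)
    if values.ndim != 2:
        raise ValueError(f"values must be 2-D, got shape {values.shape}")
    if data.shape != values.shape:
        raise ValueError(f"data shape {data.shape} != values shape {values.shape}")
    if base_values.shape != (values.shape[0],):
        raise ValueError(
            f"base_values must have shape ({values.shape[0]},), got {base_values.shape}"
        )
    if feature_names is not None and len(feature_names) != values.shape[1]:
        raise ValueError("feature_names needs one entry per column")

    import shap  # deferred so the shape checks work without shap installed

    return shap.Explanation(
        values=values,
        base_values=base_values,
        data=data,
        feature_names=feature_names,
    )
```

Usage, with vals, base and dat being the per-timepoint arrays and hypothetical feature names: `shap.plots.beeswarm(make_explanation(vals, base, dat, feature_names=[f"feature_{k}" for k in range(v)]))`.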
