I need to get the importance my PyTorch AutoEncoder gives to each input variable. I am working with a tabular dataset, not images.
My AutoEncoder is as follows:
    class AE(torch.nn.Module):
        def __init__(self, input_size, hidden_layer, latent_layer):
            super().__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(input_size, hidden_layer),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_layer, latent_layer)
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(latent_layer, hidden_layer),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_layer, input_size)
            )

        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return decoded
To spare you unnecessary detail, I simply call the following function to get my model:

    average_loss, model, train_losses, test_losses = fullAE(
        batch_size=128, input_size=genes_tensor.shape[1],
        learning_rate=0.0001, weight_decay=0,
        epochs=50, verbose=False, dataset=genes_tensor, betas_value=(0.9, 0.999),
        train_dataset=genes_tensor_train, test_dataset=genes_tensor_test)
Where "model" is a trained instance of the previous AutoEncoder:
    model = AE(input_size=input_size, hidden_layer=int(input_size * 0.75),
               latent_layer=int(input_size * 0.5)).to(device)
Now I need to get the importance that model gives to each input variable in my original "genes_tensor" dataset, but I don't know how. I have researched how to do it and found a way with the SHAP library:
    e = shap.DeepExplainer(model, genes_tensor)
    shap_values = e.shap_values(genes_tensor)
    shap.summary_plot(shap_values, genes_tensor, feature_names=features)
The problems with this implementation are the following: 1) I don't know if what I am actually doing is correct. 2) It takes forever to finish; the dataset contains 950 samples, and even with a single sample it takes a long time. The result using a single sample is as follows:
I have seen that there are other options for obtaining the importance of the input variables, such as Captum, but Captum only lets you get the importance for neural networks with a single output neuron, and in my case there are many.
The options for AEs or VAEs that I have seen on GitHub do not work for me, since they target specific use cases, and almost always images, for example:
https://github.com/peterparity/PDE-VAE-pytorch
https://github.com/FengNiMa/VAE-TracIn-pytorch
Is my shap implementation correct?
Edit:
I have run the shap code with only 4 samples and got the following result:
I don't understand why it is not the typical shap summary_plot that appears everywhere.
I have been looking at the shap documentation, and it is because my model is multi-output, having more than one neuron in the output layer.
Not commenting much on SHAP below, but I have some thoughts on potential alternatives. Example code at the end.
Since SHAP is taking so long, I think it's worth considering other techniques that you can iterate on more quickly, if you think they can provide useful information.
One approach is to run permutation importance tests (example code at the end). Start by training a 'good' reference model, and getting the model's reconstruction and reconstruction error using the original data. Then, for each feature, shuffle that feature's values across the samples and recompute the reconstruction and reconstruction error; shuffling mostly negates the feature's effect on the subsequent forward pass. This information will allow you to plot feature vs. change in reconstruction, or feature vs. change in reconstruction error. The first plot tells you how each feature impacts the model's output, and can be viewed as an approximation of SHAP (though I view it as a distinct and useful method in its own right). The second plot tells you how each feature impacts reconstruction accuracy. This method is relatively fast, as you only need to train the model once.
A limitation of this method is that if features are highly correlated, permutation tests can underestimate or miss a feature's importance (SHAP doesn't). There are ways of mitigating this, such as assessing correlations in advance and removing or grouping related ones.
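As a rough sketch of that kind of check, assuming genes_tensor and features from the question, you could flag strongly correlated pairs before running the permutation tests and decide whether to drop or group them:

    import numpy as np

    # Feature-by-feature Pearson correlation matrix (rows = samples, columns = features).
    # With very many features this matrix can get large.
    X = genes_tensor.cpu().numpy()
    corr = np.corrcoef(X, rowvar=False)

    # Flag pairs above an (arbitrary) threshold as candidates to drop or permute as a group
    threshold = 0.9
    for i in range(len(features)):
        for j in range(i + 1, len(features)):
            if abs(corr[i, j]) > threshold:
                print(features[i], features[j], round(corr[i, j], 3))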
An alternative way of assessing feature importance for an autoencoder is to record the latent representation of each sample. You can run a mutual information analysis to see the strength of association between a feature and the latent space representation. Some features might explain more of the compressed representation than others, suggesting a relative importance.
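A rough sketch of that idea using scikit-learn's mutual_info_regression, assuming model, genes_tensor, features and device from the question (this can be slow if you have very many features):

    import numpy as np
    import torch
    from sklearn.feature_selection import mutual_info_regression

    # Latent representation of every sample
    with torch.no_grad():
        latents = model.encoder(genes_tensor.to(device)).cpu().numpy()
    X = genes_tensor.cpu().numpy()

    # Mutual information between each input feature and each latent dimension,
    # summed over the latent dimensions to give one score per feature
    mi_scores = np.zeros(X.shape[1])
    for k in range(latents.shape[1]):
        mi_scores += mutual_info_regression(X, latents[:, k])

    # Print the 20 features with the strongest association to the latent space
    for name, score in sorted(zip(features, mi_scores), key=lambda t: -t[1])[:20]:
        print(f"{name}: {score:.3f}")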
Other techniques could look at the size of the weight learnt for each feature (perhaps in combination with a sparsity penalty), or activation sizes.
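For the weight-based idea, a minimal sketch that looks at the first encoder layer of the AE from the question (each column of a Linear layer's weight matrix corresponds to one input feature):

    # nn.Linear stores weights as (out_features, in_features), so reducing over dim 0
    # gives the L2 norm of the weights attached to each input feature. Larger norms
    # suggest the encoder leans more heavily on that feature.
    first_layer = model.encoder[0]
    weight_norms = first_layer.weight.detach().cpu().norm(dim=0)

    for name, w in sorted(zip(features, weight_norms.tolist()), key=lambda t: -t[1])[:20]:
        print(f"{name}: {w:.3f}")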
For any given method, consider running it on just a portion of the dataset in order to save time, or training for only a few epochs. The results will be more approximate, but may be good enough for assessing relative feature importances.
To minimise overfitting, you might want to run the fitting on part of the data, and then get your recons and recon errors using an unseen validation sample.
The code below trains an autoencoder on petal features and runs a permutation test on the features. In this example some of the features were highly correlated, and since I didn't handle that, I'm not going to rely on the results below. The figures are just illustrative of what the code does.
Imports and data preparation:
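Something along these lines; I'm assuming scikit-learn's Iris data as the source of the petal measurements, and holding out a validation split as suggested above:

    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    # Petal length and petal width from the Iris dataset
    iris = load_iris()
    feature_names = iris.feature_names[2:]
    X = iris.data[:, 2:]

    # Scale the data and hold out a validation set for the permutation tests
    X_train, X_val = train_test_split(X, test_size=0.3, random_state=0)
    scaler = StandardScaler().fit(X_train)
    X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
    X_val = torch.tensor(scaler.transform(X_val), dtype=torch.float32)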
Define a simple autoencoder and a training loop:
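For example, a scaled-down version of the AE from the question plus a basic full-batch training loop (all sizes and hyperparameters here are arbitrary choices):

    class SmallAE(torch.nn.Module):
        def __init__(self, input_size=2, hidden_size=4, latent_size=1):
            super().__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(input_size, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, latent_size),
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(latent_size, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, input_size),
            )

        def forward(self, x):
            return self.decoder(self.encoder(x))


    def train(model, X, epochs=500, lr=1e-2):
        optimiser = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = torch.nn.MSELoss()
        for epoch in range(epochs):
            optimiser.zero_grad()
            loss = loss_fn(model(X), X)
            loss.backward()
            optimiser.step()
        return loss.item()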
Train the model. Calling it good at ~13% reconstruction error.
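For example, using plain MSE as the error metric in this sketch rather than the percentage error quoted above (the exact value will vary with seeds and hyperparameters):

    torch.manual_seed(0)
    model = SmallAE(input_size=X_train.shape[1])
    train(model, X_train)

    # Baseline reconstruction and reconstruction error on the unseen validation data
    with torch.no_grad():
        baseline_recon = model(X_val)
    baseline_error = torch.mean((baseline_recon - X_val) ** 2).item()
    print(f"Baseline reconstruction MSE: {baseline_error:.3f}")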
On that trained model, run permutation tests for each feature and plot the results: the model's drop in performance per feature, and a bar plot of the normalised results (which can be interpreted as feature importances). These figures are shown at the start of this answer.
Permutation tests:
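One way to write the permutation test: shuffle one feature at a time across the validation samples and record the increase in reconstruction error relative to the baseline.

    def permutation_test(model, X, n_repeats=20, seed=0):
        """Mean increase in reconstruction MSE when each feature is shuffled in turn."""
        rng = np.random.default_rng(seed)
        with torch.no_grad():
            baseline = torch.mean((model(X) - X) ** 2).item()

        increases = []
        for i in range(X.shape[1]):
            deltas = []
            for _ in range(n_repeats):
                X_perm = X.clone()
                idx = torch.as_tensor(rng.permutation(len(X)))
                X_perm[:, i] = X[idx, i]   # break feature i's link to the other features
                with torch.no_grad():
                    error = torch.mean((model(X_perm) - X_perm) ** 2).item()
                deltas.append(error - baseline)
            increases.append(np.mean(deltas))
        return np.array(increases)

    importances = permutation_test(model, X_val)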
Plotting:
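And a sketch of the plotting step, showing the raw drop in performance and a normalised version that can be read as relative feature importances:

    # Normalise so the scores sum to 1 and can be read as relative importances
    normalised = importances / importances.sum()

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].bar(feature_names, importances)
    axes[0].set_title("Increase in reconstruction MSE when shuffled")
    axes[1].bar(feature_names, normalised)
    axes[1].set_title("Normalised feature importance")
    plt.tight_layout()
    plt.show()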