How do I implement distinctiveness pruning for a neural network built with PyTorch?


I built a neural network with PyTorch to predict the final marks based on students' assignment marks, lab marks, etc.

Although PyTorch has some built-in pruning modules, I intend to implement distinctiveness pruning for my network, which prunes neurons based on the cosine similarity between pairs of neurons.
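To make the similarity criterion concrete, here is a minimal, dependency-free sketch of how one pair of neurons could be compared. It assumes each neuron is represented by its activation vector (one output per input sample) and that the vectors are centred around zero. The function names and the 15 and 165 degree thresholds are illustrative assumptions based on commonly quoted descriptions of distinctiveness pruning, not values taken from this question.

```python
import math

def angle_degrees(u, v):
    """Angle in degrees between two equal-length activation vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        # Assumption: a neuron that is always zero is treated as unrelated.
        return 90.0
    # Clamp to guard against rounding slightly outside [-1, 1].
    cos = max(-1.0, min(1.0, dot / (norm_u * norm_v)))
    return math.degrees(math.acos(cos))

def classify_pair(u, v, low=15.0, high=165.0):
    """Label a pair of neurons using assumed angle thresholds."""
    angle = angle_degrees(u, v)
    if angle < low:
        return "similar"        # candidates for merging
    if angle > high:
        return "complementary"  # candidates for removing both
    return "distinct"           # keep both

a = [0.4, -0.3, 0.1, -0.2]
b = [0.8, -0.6, 0.2, -0.4]   # same direction as a
c = [-0.4, 0.3, -0.1, 0.2]   # opposite direction to a
d = [0.3, 0.4, 0.0, 0.0]     # orthogonal to a

print(classify_pair(a, b), classify_pair(a, c), classify_pair(a, d))
```

The thresholds are tunable; looser values prune more neurons at a higher risk of changing the network's behaviour.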

I'm really new to deep learning. In fact, I'm not even very familiar with Python itself, but I'm working hard to learn it. I have done a lot of research online and now understand the theory and mechanism behind distinctiveness pruning, but I can't work out how to code it in Python.

Could you please give me some help? Anything from pseudocode to a more detailed explanation would be much appreciated.

I think my main problems are that I don't know how to get the activation vector of each neuron (the output of a neuron as a vector), how to add the weights of one neuron to another if the two are similar (or remove both neurons if they do opposite things), and how to integrate this pruning process into my existing neural network.
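For the first of these points, one possible approach is a forward hook, which records a layer's output during a normal forward pass. The sketch below is only an illustration: `model`, `X`, `save_output`, and `pairwise_angles` are hypothetical names, the small `nn.Sequential` network is a stand-in with 14 inputs and 4 classes, and `X` is random data standing in for the real training set. Note that ReLU outputs are non-negative, so without centring the angles cannot exceed 90 degrees. The optional `center=True` subtracts each neuron's mean, which is an assumption rather than part of any standard recipe.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in network: 14 features, 8 hidden units, 4 classes.
model = nn.Sequential(nn.Linear(14, 8), nn.ReLU(), nn.Linear(8, 4))
X = torch.randn(100, 14)  # placeholder for the real training inputs

captured = {}

def save_output(module, inputs, output):
    # Called by PyTorch right after the hooked module's forward pass.
    captured["hidden"] = output.detach()

# Hook the ReLU so we record the activations the next layer actually sees.
handle = model[1].register_forward_hook(save_output)
model.eval()
with torch.no_grad():
    model(X)
handle.remove()  # stop recording once the activations are stored

def pairwise_angles(acts, center=False):
    """Angles in degrees between the columns of acts (n_samples, n_units)."""
    if center:
        acts = acts - acts.mean(dim=0, keepdim=True)
    units = F.normalize(acts.t(), dim=1)      # one unit-length row per neuron
    cos = (units @ units.t()).clamp(-1.0, 1.0)
    return torch.acos(cos) * 180.0 / math.pi

angles = pairwise_angles(captured["hidden"])  # shape (8, 8)
print(angles.shape)
```

Each column of `captured["hidden"]` is one neuron's activation vector, and `angles[i, j]` is the angle between neurons `i` and `j`. Pairs with very small or very large angles are the candidates for pruning.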

My existing neural network is shown below:

import torch.nn as nn


class Mark_Predict(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super().__init__()
        # Two fully connected hidden layers followed by an output layer
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size2, num_classes)

    def forward(self, x):
        out = self.hidden1(x)
        out = self.relu(out)
        out = self.hidden2(out)
        out = self.relu(out)
        # Raw class scores (logits); CrossEntropyLoss applies softmax itself
        out = self.out(out)
        return out
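For a network made of consecutive `nn.Linear` layers like this one, merging or removing hidden neurons could be sketched as rebuilding a pair of adjacent layers with fewer units. The helper below is a hypothetical illustration, not an existing PyTorch API: `prune_hidden_units`, `merge_pairs`, and `remove_units` are made-up names. It assumes that "adding the weights" means adding the dropped neuron's outgoing weights to those of the kept neuron, and that each neuron appears in at most one pair.

```python
import torch
import torch.nn as nn

def prune_hidden_units(layer_in, layer_out, merge_pairs=(), remove_units=()):
    """Return smaller copies of two consecutive nn.Linear layers.

    merge_pairs:  (keep, drop) index pairs of similar units; the outgoing
                  weights of `drop` are added to those of `keep`.
    remove_units: indices of units to delete outright, e.g. both members
                  of a complementary pair.
    """
    w_in = layer_in.weight.detach().clone()    # (hidden, in_features)
    b_in = layer_in.bias.detach().clone()      # (hidden,)
    w_out = layer_out.weight.detach().clone()  # (out_features, hidden)
    b_out = layer_out.bias.detach().clone()

    dropped = set(remove_units)
    for keep, drop in merge_pairs:
        w_out[:, keep] += w_out[:, drop]       # kept unit takes over the load
        dropped.add(drop)

    kept = [i for i in range(layer_in.out_features) if i not in dropped]

    new_in = nn.Linear(layer_in.in_features, len(kept))
    new_out = nn.Linear(len(kept), layer_out.out_features)
    with torch.no_grad():
        new_in.weight.copy_(w_in[kept])
        new_in.bias.copy_(b_in[kept])
        new_out.weight.copy_(w_out[:, kept])
        new_out.bias.copy_(b_out)
    return new_in, new_out

# Check on a toy pair of layers where unit 1 is an exact copy of unit 0.
torch.manual_seed(0)
hidden, out = nn.Linear(14, 6), nn.Linear(6, 4)
with torch.no_grad():
    hidden.weight[1] = hidden.weight[0]
    hidden.bias[1] = hidden.bias[0]

new_hidden, new_out = prune_hidden_units(hidden, out, merge_pairs=[(0, 1)])
x = torch.randn(5, 14)
before = out(torch.relu(hidden(x)))
after = new_out(torch.relu(new_hidden(x)))
print(torch.allclose(before, after, atol=1e-5))
```

Applied to the network above, the same helper could be used as `model.hidden1, model.hidden2 = prune_hidden_units(model.hidden1, model.hidden2, ...)`, and likewise for `hidden2` and `out`. Because the parameters change, the optimizer would need to be recreated before any further training.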

I have 14 features, and the output is either 0, 1, 2, or 3 (A, B, C, or D as the final mark). I used CrossEntropyLoss as my loss function and Adam as my optimizer.

I implemented 5-fold cross-validation, and the current average test accuracy is around 70%. I'm looking forward to seeing the results after pruning.
