I pruned my pre-trained model and I can clearly see the impact on the total number of model parameters, but the inference time does not improve: the more I increase the pruning percentage, the worse the inference time gets. Can anyone review my code and tell me whether I am doing something wrong?
import torch
from torch.nn.utils import prune

def pytorch_pruning(model):
    prune_percentage_synthesis = 0.9
    prune_percentage_hyper_synthesis = 0.1

    def prune_layer(conv_layer, amount):
        # L1-structured pruning along dim=1; prune.remove() then folds the
        # mask into the parameter so the zeroing is permanent.
        if hasattr(conv_layer, 'kernel'):
            prune.ln_structured(conv_layer, 'kernel', amount=amount, dim=1, n=1)
            prune.remove(conv_layer, 'kernel')
        elif hasattr(conv_layer, 'kernel_orig'):
            prune.ln_structured(conv_layer, 'kernel_orig', amount=amount, dim=1, n=1)
            prune.remove(conv_layer, 'kernel_orig')
        elif hasattr(conv_layer, 'bias'):
            prune.ln_structured(conv_layer, 'bias', amount=amount, dim=1, n=1)
            prune.remove(conv_layer, 'bias')

    # Prune synthesis_transform filters
    synthesis = model.synthesis_transform
    synthesis_layers = [synthesis.conv0, synthesis.conv1, synthesis.conv2, synthesis.conv3]
    for res in (synthesis.res0, synthesis.res1, synthesis.res2):
        synthesis_layers += [res.conv_a1, res.conv_b1, res.conv_b2,
                             res.conv_c1, res.conv_c2, res.conv_c3,
                             res.conv_d1, res.conv_d2, res.conv_d3, res.conv_d4]
    for conv_layer in synthesis_layers:
        prune_layer(conv_layer, prune_percentage_synthesis)

    # Prune hyper_synthesis_transform filters
    hyper = model.hyper_synthesis_transform
    for conv_layer in (hyper.conv0, hyper.conv1, hyper.conv2):
        prune_layer(conv_layer, prune_percentage_hyper_synthesis)

    return model
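Note that torch.nn.utils.prune only zeroes entries in place rather than removing them, which can be seen on a stock nn.Conv2d (standalone sketch; it uses the standard 'weight' parameter name instead of this model's 'kernel'):

import torch
import torch.nn as nn
from torch.nn.utils import prune

conv = nn.Conv2d(16, 32, kernel_size=3)
print(conv.weight.shape)                  # torch.Size([32, 16, 3, 3])
prune.ln_structured(conv, 'weight', amount=0.9, dim=1, n=1)
prune.remove(conv, 'weight')
print(conv.weight.shape)                  # still torch.Size([32, 16, 3, 3])
print((conv.weight == 0).float().mean())  # ~0.88: entries are zeroed, not removed

The tensor keeps its dense shape, so the forward pass still performs the same number of multiply-adds.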
# 'device', 'l' (a list of model names), 'dl_coding_model', and 'check_pruning'
# are defined elsewhere in my script.
# for i in range(len(l)):
# Load the state_dict of the trained model
checkpoint_path1 = '/home/mg/Downloads/code/VMSparse_Pruning/models/Pruned_L1_Pe8_Earlystopping_FreezingEncoder/' + l[0] + '/regularized_checkpoint_best_loss.pth.tar'
checkpoint1 = torch.load(checkpoint_path1, map_location=device)

# Create an instance of the model class and load the weights
model1 = dl_coding_model.CodingModel(32, 1)
model1.load_state_dict(checkpoint1["state_dict"])

# Prune filters
model1 = pytorch_pruning(model1)

# Save the pruned model
pruned_checkpoint_path = '/home/mg/Downloads/code/VMSparse_Pruning/models/Pruned_L1_Pe8_Earlystopping_FreezingEncoder/' + l[0] + '/pruned90%_00%.pth.tar'
torch.save({'state_dict': model1.state_dict()}, pruned_checkpoint_path)
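For reference, this is roughly how I would time the pruned model (minimal sketch: the 1x3x256x256 input shape is a placeholder, and the forward pass is assumed to take a single image tensor):

import time

dummy_input = torch.randn(1, 3, 256, 256, device=device)  # placeholder shape
model1.to(device).eval()
with torch.no_grad():
    for _ in range(10):       # warm-up so one-off setup costs don't skew the timing
        model1(dummy_input)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # flush queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(100):
        model1(dummy_input)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"Mean inference time: {(time.perf_counter() - start) / 100 * 1000:.2f} ms")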
# Check how much of the model was actually pruned
pruning_percentage = check_pruning(model1)
print(f"Percentage of pruned weights: {pruning_percentage:.2f}%")