Pruned model not only fails to decrease inference time but actually increases it


I pruned my pre-trained model and I can clearly see the impact on the model's parameters, but the inference time actually gets worse and worse as I increase the pruning percentage. Can anyone review my code and tell me whether I am doing something wrong? (How I measure inference time is shown at the end of the post.)

import torch
import torch.nn.utils.prune as prune

import dl_coding_model  # my model definition module

def pytorch_pruning(model):
    prune_percentage_synthesis = 0.9
    prune_percentage_hyper_synthesis = 0.1

    # Collect the synthesis_transform conv layers: the top-level convs plus
    # the convs inside each residual block.
    synthesis_layers = [model.synthesis_transform.conv0, model.synthesis_transform.conv1,
                        model.synthesis_transform.conv2, model.synthesis_transform.conv3]
    for res in [model.synthesis_transform.res0, model.synthesis_transform.res1, model.synthesis_transform.res2]:
        synthesis_layers += [res.conv_a1,
                             res.conv_b1, res.conv_b2,
                             res.conv_c1, res.conv_c2, res.conv_c3,
                             res.conv_d1, res.conv_d2, res.conv_d3, res.conv_d4]

    # Prune synthesis_transform filters
    for conv_layer in synthesis_layers:
        
        # L1-structured pruning along dim=1; prune.remove then makes the zeros permanent.
        if hasattr(conv_layer, 'kernel'):
            prune.ln_structured(conv_layer, 'kernel', amount=prune_percentage_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'kernel')
        elif hasattr(conv_layer, 'kernel_orig'):
            prune.ln_structured(conv_layer, 'kernel_orig', amount=prune_percentage_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'kernel_orig')
        elif hasattr(conv_layer, 'bias'):
            prune.ln_structured(conv_layer, 'bias', amount=prune_percentage_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'bias')
    # Prune hyper_synthesis_transform filters
    for conv_layer in [model.hyper_synthesis_transform.conv0,
                       model.hyper_synthesis_transform.conv1,
                       model.hyper_synthesis_transform.conv2]:
        if hasattr(conv_layer, 'kernel'):
            prune.ln_structured(conv_layer, 'kernel', amount=prune_percentage_hyper_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'kernel')
        elif hasattr(conv_layer, 'kernel_orig'):
            prune.ln_structured(conv_layer, 'kernel_orig', amount=prune_percentage_hyper_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'kernel_orig')
        elif hasattr(conv_layer, 'bias'):
            prune.ln_structured(conv_layer, 'bias', amount=prune_percentage_hyper_synthesis, dim=1, n=1)
            prune.remove(conv_layer, 'bias')
    return model
# for i in range (len(l)):
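For completeness, check_pruning (called at the bottom of the script) is just a small helper that reports what fraction of all weights is exactly zero; a simplified sketch of what I use:

def check_pruning(model):
    # Report the percentage of weights that are exactly zero after pruning.
    total_weights = 0
    zero_weights = 0
    for name, param in model.named_parameters():
        total_weights += param.numel()
        zero_weights += (param == 0).sum().item()
    return 100.0 * zero_weights / total_weights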

# Load the state_dict of the pre-trained model
# ('l' is a list of experiment folder names and 'device' is set earlier in my script)
checkpoint_path1 = '/home/mg/Downloads/code/VMSparse_Pruning/models/Pruned_L1_Pe8_Earlystopping_FreezingEncoder/'+l[0]+'/regularized_checkpoint_best_loss.pth.tar'
checkpoint1 = torch.load(checkpoint_path1, map_location=device)

# Create an instance of the model class and load the pre-trained weights
model1 = dl_coding_model.CodingModel(32, 1)
model1.load_state_dict(checkpoint1["state_dict"])
# Prune filters
model1 = pytorch_pruning(model1)

# Save the pruned model
pruned_checkpoint_path = '/home/mg/Downloads/code/VMSparse_Pruning/models/Pruned_L1_Pe8_Earlystopping_FreezingEncoder/' + l[0] + '/pruned90%_00%.pth.tar'
torch.save({
    'state_dict': model1.state_dict()
}, pruned_checkpoint_path)
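This is roughly how I confirm that pruning actually changed the parameters, by comparing non-zero counts in the original and pruned state_dicts (a simplified sketch):

def count_nonzero(state_dict):
    # Total vs. non-zero entries across all tensors of a state_dict.
    total = sum(t.numel() for t in state_dict.values())
    nonzero = sum((t != 0).sum().item() for t in state_dict.values())
    return total, nonzero

for label, sd in [('original', checkpoint1['state_dict']), ('pruned', model1.state_dict())]:
    total, nonzero = count_nonzero(sd)
    print(f'{label}: {nonzero}/{total} non-zero parameters')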

# Check what fraction of the weights was actually zeroed out by the pruning above
pruning_percentage = check_pruning(model1)
print(f"Percentage of pruned weights: {pruning_percentage:.2f}%")
