How to get the input size for an operator in a PyTorch scripted model?


I use this code to convert the model to a scripted model:

scripted_model = torch.jit.trace(detector.model, images).eval()

Then I print the scripted_model. Part of the output is as follows:

 (base): DLA(
    original_name=DLA
    (base_layer): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level0): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level1): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level2): Tree(
      original_name=Tree
      (tree1): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (tree2): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (root): Root(
        original_name=Root
        (conv): Conv2d(original_name=Conv2d)
        (bn): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
      )
      (downsample): MaxPool2d(original_name=MaxPool2d)
      (project): Sequential(
        original_name=Sequential
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
    ) 
...

I just want to get the input size for an operator, for example how many inputs (and what shape) the operator (0): Conv2d(original_name=Conv2d) receives. When I print the graph of this scripted model, the output is as follows:

  %4770 : __torch__.torch.nn.modules.module.___torch_mangle_11.Module = prim::GetAttr[name="wh"](%self.1)
  %4762 : __torch__.torch.nn.modules.module.___torch_mangle_15.Module = prim::GetAttr[name="tracking"](%self.1)
  %4754 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="rot"](%self.1)
  %4746 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="reg"](%self.1)
  %4738 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="hm"](%self.1)
  %4730 : __torch__.torch.nn.modules.module.___torch_mangle_27.Module = prim::GetAttr[name="dim"](%self.1)
  %4722 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="dep"](%self.1)
  %4714 : __torch__.torch.nn.modules.module.___torch_mangle_31.Module = prim::GetAttr[name="amodel_offset"](%self.1)
  %4706 : __torch__.torch.nn.modules.module.___torch_mangle_289.Module = prim::GetAttr[name="ida_up"](%self.1)
  %4645 : __torch__.torch.nn.modules.module.___torch_mangle_262.Module = prim::GetAttr[name="dla_up"](%self.1)
  %4461 : __torch__.torch.nn.modules.module.___torch_mangle_180.Module = prim::GetAttr[name="base"](%self.1)
  %5100 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4461, %input.1)
  %5082 : Tensor, %5083 : Tensor, %5084 : Tensor, %5085 : Tensor, %5086 : Tensor, %5087 : Tensor, %5088 : Tensor, %5089 : Tensor = prim::TupleUnpack(%5100)
  %5101 : (Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4645, %5082, %5083, %5084, %5085, %5086, %5087, %5088, %5089)
  %5097 : Tensor, %5098 : Tensor, %5099 : Tensor = prim::TupleUnpack(%5101)
  %3158 : None = prim::Constant()
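
A graph dump like the one above can be produced with something along these lines (a minimal sketch, reusing the scripted_model from earlier):

print(scripted_model.graph)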

I can even find the operator name. How can I get the input size for a specific operator in the scripted model?

There is 1 answer below.

Hamzah (Best Answer):

One solution is to try summary from torchinfo: the output shape of each layer is the input shape of the next one, and so on:

!pip install torchinfo

from torchinfo import summary

batch_size = 64  # matches the example output below
summary(model, input_size=(batch_size, 3, 224, 224))  # the input size of your network

# output

    ===============================================================================================
    Layer (type:depth-idx)                        Output Shape              Param #
    ===============================================================================================
    ResNet50                                      --                        --
    ├─ResNet: 1-1                                 [64, 10]                  --
    │    └─Conv2d: 2-1                            [64, 64, 112, 112]        9,408
    │    └─BatchNorm2d: 2-2                       [64, 64, 112, 112]        128
    │    └─ReLU: 2-3                              [64, 64, 112, 112]        --
    │    └─MaxPool2d: 2-4                         [64, 64, 56, 56]          --
    │    └─Sequential: 2-5                        [64, 64, 56, 56]          --
    │    │    └─BasicBlock: 3-1                   [64, 64, 56, 56]          73,984
    │    │    └─BasicBlock: 3-2                   [64, 64, 56, 56]          73,984
    │    └─Sequential: 2-6                        [64, 128, 28, 28]         --
    │    │    └─BasicBlock: 3-3                   [64, 128, 28, 28]         230,144
    │    │    └─BasicBlock: 3-4                   [64, 128, 28, 28]         295,424
    │    └─Sequential: 2-7                        [64, 256, 14, 14]         --
    │    │    └─BasicBlock: 3-5                   [64, 256, 14, 14]         919,040
    │    │    └─BasicBlock: 3-6                   [64, 256, 14, 14]         1,180,672
    │    └─Sequential: 2-8                        [64, 512, 7, 7]           --
    │    │    └─BasicBlock: 3-7                   [64, 512, 7, 7]           3,673,088
    │    │    └─BasicBlock: 3-8                   [64, 512, 7, 7]           4,720,640
    │    └─AdaptiveAvgPool2d: 2-9                 [64, 512, 1, 1]           --
    │    └─Linear: 2-10                           [64, 10]                  5,130
    ===============================================================================================
    Total params: 11,181,642
    Trainable params: 11,181,642
    Non-trainable params: 0
    Total mult-adds (G): 116.07
    ===============================================================================================
    Input size (MB): 38.54
    Forward/backward pass size (MB): 2543.33
    Params size (MB): 44.73
    Estimated Total Size (MB): 2626.59
    ===============================================================================================
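
Since the question is specifically about input sizes, note that torchinfo can also show an input-size column directly, if your installed version supports the col_names argument:

summary(model, input_size=(batch_size, 3, 224, 224), col_names=("input_size", "output_size", "num_params"))

Another option, if you still have the original eager model (detector.model) and the example images tensor from the question, is to register forward hooks on the eager model and record the input shapes each submodule receives during one forward pass. This is only a minimal sketch (it hooks the eager model rather than the traced one, since hook support on traced/scripted modules varies by PyTorch version, and it assumes a single example input tensor):

import torch

def record_input_shapes(model, example_input):
    # Run one forward pass and record the input shapes seen by every leaf submodule.
    shapes = {}
    handles = []

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules: Conv2d, BatchNorm2d, ReLU, ...
            def hook(mod, inputs, output, name=name):
                shapes[name] = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
            handles.append(module.register_forward_hook(hook))

    with torch.no_grad():
        model(example_input)

    for h in handles:
        h.remove()
    return shapes

shapes = record_input_shapes(detector.model, images)
# e.g. shapes["base.base_layer.0"] holds the input shape(s) of that Conv2d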