I'm converting a Pytorch model to ONNX format. I have the model weights in a .pt file and the architecture in a .yaml file. Using that, I wrote the model architecture as follows:
import torch
import torch.nn as nn
from models.common import ReOrg, Conv, Concat, SPPCSPC
from models.yolo import IKeypoint
class YOLOv7(nn.Module):
def __init__(self, nc=1, nkpt=17, anchors=None, depth_multiple=1.0, width_multiple=1.0):
super(YOLOv7, self).__init__()
self.nc = nc
self.nkpt = nkpt
self.anchors = anchors if anchors is not None else [[], [], [], []] # Default to empty lists if None
self.depth_multiple = depth_multiple
self.width_multiple = width_multiple
self.backbone = nn.ModuleList([
ReOrg(),
Conv(3, int(64 * width_multiple), 3, 1), # 1-P1/2
Conv(int(64 * width_multiple), int(128 * width_multiple), 3, 2), # 2-P2/4
Conv(int(128 * width_multiple), int(64 * width_multiple), 1, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 1, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Concat([-1, -3, -5, -6]),
Conv(int(256 * width_multiple), int(128 * width_multiple), 1, 1), # 10
Conv(int(128 * width_multiple), int(256 * width_multiple), 3, 2), # 11-P3/8
Conv(int(256 * width_multiple), int(128 * width_multiple), 1, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 1, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Concat([-1, -3, -5, -6]),
Conv(int(512 * width_multiple), int(256 * width_multiple), 1, 1), # 19
Conv(int(256 * width_multiple), int(512 * width_multiple), 3, 2), # 20-P4/16
Conv(int(512 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Concat([-1, -3, -5, -6]),
Conv(int(1024 * width_multiple), int(512 * width_multiple), 1, 1), # 28
Conv(int(512 * width_multiple), int(768 * width_multiple), 3, 2), # 29-P5/32
Conv(int(768 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 3, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 3, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 3, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 3, 1),
Concat([-1, -3, -5, -6]),
Conv(int(1536 * width_multiple), int(768 * width_multiple), 1, 1), # 37
Conv(int(768 * width_multiple), int(1024 * width_multiple), 3, 2), # 38-P6/64
Conv(int(1024 * width_multiple), int(512 * width_multiple), 1, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 1, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 3, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 3, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 3, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 3, 1),
Concat([-1, -3, -5, -6]),
Conv(int(2048 * width_multiple), int(1024 * width_multiple), 1, 1) # 46
])
self.head = nn.ModuleList([
SPPCSPC(int(1024 * width_multiple), int(512 * width_multiple), k=(5, 9, 13)), # 47
Conv(int(512 * width_multiple), int(384 * width_multiple), 1, 1),
nn.Upsample(scale_factor=2, mode='nearest'), # 49
Conv(int(768 * width_multiple), int(384 * width_multiple), 1, 1), # Connect from 37
Concat([-1, -2]),
Conv(int(384 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(int(1152 * width_multiple), int(384 * width_multiple), 1, 1), # 59
Conv(int(384 * width_multiple), int(256 * width_multiple), 1, 1),
nn.Upsample(scale_factor=2, mode='nearest'), # 61
Conv(int(512 * width_multiple), int(256 * width_multiple), 1, 1), # Connect from 28
Concat([-1, -2]),
Conv(int(256 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(int(768 * width_multiple), int(256 * width_multiple), 1, 1), # 71
Conv(int(256 * width_multiple), int(128 * width_multiple), 1, 1),
nn.Upsample(scale_factor=2, mode='nearest'), # 73
Conv(int(256 * width_multiple), int(128 * width_multiple), 1, 1), # Connect from 19
Concat([-1, -2]),
Conv(int(128 * width_multiple), int(128 * width_multiple), 1, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 1, 1),
Conv(int(128 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Conv(int(64 * width_multiple), int(64 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(int(384 * width_multiple), int(128 * width_multiple), 1, 1), # 83
Conv(int(128 * width_multiple), int(256 * width_multiple), 3, 2), # Connect from 83
Concat([-1, 71]), # Connect with 71
Conv(int(256 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 1, 1),
Conv(int(256 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Conv(int(128 * width_multiple), int(128 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(int(768 * width_multiple), int(256 * width_multiple), 1, 1), # 93
Conv(int(256 * width_multiple), int(384 * width_multiple), 3, 2), # Connect from 93
Concat([-1, 59]), # Connect with 59
Conv(int(384 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(384 * width_multiple), 1, 1),
Conv(int(384 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Conv(int(192 * width_multiple), int(192 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(int(1152 * width_multiple), int(384 * width_multiple), 1, 1), # 103
Conv(int(384 * width_multiple), int(512 * width_multiple), 3, 2), # Connect from 103
Concat([-1, 47]), # Connect with 47
Conv(int(512 * width_multiple), int(512 * width_multiple), 1, 1),
Conv(int(512 * width_multiple), int(512 * width_multiple), 1, 1),
Conv(int(512 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Conv(int(256 * width_multiple), int(256 * width_multiple), 3, 1),
Concat([-1, -2, -3, -4, -5, -6]),
Conv(512 * int(width_multiple), 512, 1, 1), # 113
Conv(256 * int(width_multiple), 256, 3, 1), # Layer connected from 83
Conv(512 * int(width_multiple), 512, 3, 1), # Layer connected from 93
Conv(768 * int(width_multiple), 768, 3, 1), # Layer connected from 103
Conv(1024 * int(width_multiple), 1024, 3, 1), # Layer connected from 113
])
self.detect = IKeypoint(nc=nc, anchors=anchors, nkpt=nkpt)
def forward(self, x):
x_backbone = x
for layer in self.backbone:
x_backbone = layer(x_backbone)
x_head = x_backbone
for layer in self.head:
x_head = layer(x_head)
# Detection (IKeypoint)
output = self.detect(x_head)
return output
anchors = [
[19, 27, 44, 40, 38, 94], # P3/8
[96, 68, 86, 152, 180, 137], # P4/16
[140, 301, 303, 264, 238, 542], # P5/32
[436, 615, 739, 380, 925, 792] # P6/64
]
The following script is for converting Pytorch to ONNX:
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.onnx
def load_model_weights(model_path, model):
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint['model'].state_dict() if 'model' in checkpoint else checkpoint
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] if k.startswith('module.') else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=False)
model = YOLOv7(nc=1, nkpt=17, depth_multiple=1.0, width_multiple=1.0, anchors=anchors)
load_model_weights('yolov7-w6-pose.pt', model)
model.eval()
dummy_input = torch.randn(1, 3, 640, 640)
torch.onnx.export(model,
dummy_input,
"yolov7-w6-pose.onnx",
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['detections', 'keypoints'],
dynamic_axes={'input': {0: 'batch_size'},
'detections': {0: 'batch_size'},
'keypoints': {0: 'batch_size'}})
Running this script returns following error:
RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[1, 12, 320, 320] to have 3 channels, but got 12 channels instead
I tried changing the input channels but to no avail.