I've created a simple NLP model in PyTorch and trained it, and it works as expected in Python. I then exported it to TorchScript with torch.jit.trace. Loading the traced model back into Python works fine, and it still produces the expected results. But when I try to run it from Rust with tch-rs (Rust bindings for the PyTorch C++ API), the following error occurs and I have no idea how to debug it:
Error: Torch("The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File \"code/__torch__/___torch_mangle_469.py\", line 17, in forward
dropout = self.dropout
bert = self.bert
_0 = (dropout).forward((bert).forward(input_id, mask, ), )
~~~~~~~~~~~~~ <--- HERE
_1 = (relu).forward((linear).forward(_0, ), )
return _1
File \"code/__torch__/transformers/models/bert/modeling_bert/___torch_mangle_465.py\", line 19, in forward
batch_size = ops.prim.NumToTensor(torch.size(input_id, 0))
_0 = int(batch_size)
seq_length = ops.prim.NumToTensor(torch.size(input_id, 1))
~~~~~~~~~~ <--- HERE
_1 = int(seq_length)
_2 = int(seq_length)
Traceback of TorchScript, original code (most recent call last):
/user/.conda/envs/tch/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py(954): forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1176): _slow_forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1192): _call_impl
/var/folders/zs/vmmy3w4n0ns1c0kj91skmfnm0000gn/T/ipykernel_10987/868892765.py(17): forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1176): _slow_forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1192): _call_impl
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/jit/_trace.py(957): trace_module
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/jit/_trace.py(753): trace
/var/folders/zs/vmmy3w4n0ns1c0kj91skmfnm0000gn/T/ipykernel_10987/749605851.py(1): <module>
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3430): run_code
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3341): run_ast_nodes
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3168): run_cell_async
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2970): _run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2941): run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/zmqshell.py(531): run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/ipkernel.py(380): do_execute
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(700): execute_request
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(383): dispatch_shell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(496): process_one
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
/user/.conda/envs/tch/lib/python3.10/asyncio/events.py(80): _run
/user/.conda/envs/tch/lib/python3.10/asyncio/base_events.py(1868): _run_once
/user/.conda/envs/tch/lib/python3.10/asyncio/base_events.py(597): run_forever
/user/.conda/envs/tch/lib/python3.10/site-packages/tornado/platform/asyncio.py(212): start
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelapp.py(701): start
/user/.conda/envs/tch/lib/python3.10/site-packages/traitlets/config/application.py(990): launch_instance
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel_launcher.py(12): <module>
/user/.conda/envs/tch/lib/python3.10/runpy.py(75): _run_code
/user/.conda/envs/tch/lib/python3.10/runpy.py(191): _run_module_as_main
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
")
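The last line of the error is the most informative part. For a tensor with a single dimension, the only valid dimension indices are 0 and -1, so torch.size(input_id, 1) fails. This suggests that the input_id tensor reaching BertModel has shape [seq_len] instead of the [batch, seq_len] shape that BERT expects. The same message can be reproduced in plain PyTorch. The token ids below are arbitrary placeholders:

```python
import torch

# A flat list of token ids becomes a 1-D tensor of shape [seq_len].
input_id = torch.tensor([101, 7592, 2088, 102])

try:
    input_id.size(1)  # dim 1 does not exist on a 1-D tensor
except (IndexError, RuntimeError) as err:
    message = str(err)

print(message)

# Adding a leading batch dimension gives shape [1, seq_len], so dim 1 is valid.
batched = input_id.unsqueeze(0)
print(tuple(batched.shape), batched.size(1))
```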
Here is the model I'm trying to run:
from torch import nn
from transformers import BertModel


class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
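torch.jit.trace records the operations that run for the example inputs. This means the exported module keeps BertModel's assumption that input_id and mask are 2-D [batch, seq_len] tensors. The following sketch uses a small hypothetical stand-in, ToyClassifier (not the real BERT model, so it runs without downloading weights). It shows that the traced module accepts other batch sizes and sequence lengths but fails on a 1-D input, raising the same "Dimension out of range" message seen in the Rust error:

```python
import torch
from torch import nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for BertClassifier; like BERT, it indexes dim 1 of its inputs."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8, num_classes: int = 5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.linear = nn.Linear(hidden, num_classes)

    def forward(self, input_id: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        emb = self.embed(input_id) * mask.unsqueeze(-1)  # [batch, seq_len, hidden]
        # Mean over non-padded tokens; mask.sum(dim=1) needs a 2-D mask.
        pooled = emb.sum(dim=1) / mask.sum(dim=1, keepdim=True)
        return torch.relu(self.linear(pooled))


model = ToyClassifier().eval()
example = (torch.randint(0, 100, (2, 6)), torch.ones(2, 6, dtype=torch.long))
traced = torch.jit.trace(model, example)

# Other batch sizes and sequence lengths work, as long as the inputs stay 2-D.
out = traced(torch.randint(0, 100, (3, 9)), torch.ones(3, 9, dtype=torch.long))
print(tuple(out.shape))  # (3, 5)

# A 1-D input (shape [seq_len]) fails inside the TorchScript interpreter.
try:
    traced(torch.randint(0, 100, (9,)), torch.ones(9, dtype=torch.long))
except Exception as err:  # catch broadly; only the message matters here
    error = str(err)
else:
    error = ""
print("Dimension out of range" in error)  # True
```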
I'm new to ML and can't find any documentation on how to debug TorchScript runtime errors, so I'd appreciate any help solving this problem.
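One way to debug this kind of error without Rust in the loop is to follow three steps. First, load the exported module back with torch.jit.load. Second, print its generated TorchScript source (the .code attribute, the same kind of code quoted in the "serialized code" part of the traceback). Third, call it with tensors shaped exactly like the ones the Rust code builds. The sketch below makes several assumptions. It uses a hypothetical Toy module saved to an in-memory buffer in place of the real model file. The check_inputs helper is also hypothetical. It assumes the Rust side builds its tensors from a flat list of token ids (that code is not shown in the question). If that assumption holds, the corresponding Rust change would presumably be calling unsqueeze(0) on both tensors before passing them to the module.

```python
import io

import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical stand-in for the exported model; like BERT, it reads dim 1 of its inputs."""

    def forward(self, input_id: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return (input_id * mask).sum(dim=1)


# Export and reload through an in-memory buffer instead of a .pt file on disk.
example = (torch.ones(2, 6, dtype=torch.long), torch.ones(2, 6, dtype=torch.long))
traced = torch.jit.trace(Toy(), example)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The generated TorchScript source of the loaded module's forward method.
print(loaded.code)


def check_inputs(input_id: torch.Tensor, mask: torch.Tensor) -> None:
    """Fail early with a readable message instead of an error deep inside the interpreter."""
    for name, t in (("input_id", input_id), ("mask", mask)):
        if t.dim() != 2:
            raise ValueError(f"{name} must be [batch, seq_len], got shape {tuple(t.shape)}")
    if input_id.shape != mask.shape:
        raise ValueError(f"shape mismatch: {tuple(input_id.shape)} vs {tuple(mask.shape)}")


# Inputs built from a flat list of token ids (assumed to match the Rust code).
input_id = torch.tensor([101, 7592, 2088, 102])
mask = torch.ones_like(input_id)

try:
    check_inputs(input_id, mask)
except ValueError as err:
    print(err)  # input_id must be [batch, seq_len], got shape (4,)

# Adding the batch dimension satisfies both the check and the model.
input_id, mask = input_id.unsqueeze(0), mask.unsqueeze(0)
check_inputs(input_id, mask)
print(loaded(input_id, mask))  # tensor([9883])
```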