Problem with debugging TorchScript RuntimeError: Dimension out of range


I've created a simple NLP model in PyTorch and trained it, and it works as expected in Python. I then exported it to TorchScript with jit.trace. Loading it back into Python works fine and the model behaves as expected. But when I try to execute it in Rust with tch-rs (Rust bindings for the C++ API of PyTorch), the following error occurs and I have no idea how to debug it:

Error: Torch("The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File \"code/__torch__/___torch_mangle_469.py\", line 17, in forward
    dropout = self.dropout
    bert = self.bert
    _0 = (dropout).forward((bert).forward(input_id, mask, ), )
                            ~~~~~~~~~~~~~ <--- HERE
    _1 = (relu).forward((linear).forward(_0, ), )
    return _1
  File \"code/__torch__/transformers/models/bert/modeling_bert/___torch_mangle_465.py\", line 19, in forward
    batch_size = ops.prim.NumToTensor(torch.size(input_id, 0))
    _0 = int(batch_size)
    seq_length = ops.prim.NumToTensor(torch.size(input_id, 1))
                                      ~~~~~~~~~~ <--- HERE
    _1 = int(seq_length)
    _2 = int(seq_length)

Traceback of TorchScript, original code (most recent call last):
/user/.conda/envs/tch/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py(954): forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1176): _slow_forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1192): _call_impl
/var/folders/zs/vmmy3w4n0ns1c0kj91skmfnm0000gn/T/ipykernel_10987/868892765.py(17): forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1176): _slow_forward
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/nn/modules/module.py(1192): _call_impl
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/jit/_trace.py(957): trace_module
/user/.conda/envs/tch/lib/python3.10/site-packages/torch/jit/_trace.py(753): trace
/var/folders/zs/vmmy3w4n0ns1c0kj91skmfnm0000gn/T/ipykernel_10987/749605851.py(1): <module>
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3430): run_code
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3341): run_ast_nodes
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3168): run_cell_async
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2970): _run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2941): run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/zmqshell.py(531): run_cell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/ipkernel.py(380): do_execute
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(700): execute_request
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(383): dispatch_shell
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(496): process_one
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
/user/.conda/envs/tch/lib/python3.10/asyncio/events.py(80): _run
/user/.conda/envs/tch/lib/python3.10/asyncio/base_events.py(1868): _run_once
/user/.conda/envs/tch/lib/python3.10/asyncio/base_events.py(597): run_forever
/user/.conda/envs/tch/lib/python3.10/site-packages/tornado/platform/asyncio.py(212): start
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel/kernelapp.py(701): start
/user/.conda/envs/tch/lib/python3.10/site-packages/traitlets/config/application.py(990): launch_instance
/user/.conda/envs/tch/lib/python3.10/site-packages/ipykernel_launcher.py(12): <module>
/user/.conda/envs/tch/lib/python3.10/runpy.py(75): _run_code
/user/.conda/envs/tch/lib/python3.10/runpy.py(191): _run_module_as_main
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
")

Here is the simple model that I'm trying to execute:

from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self, dropout=0.5):
        super().__init__()

        # Pretrained BERT encoder; its pooled output has 768 features
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        # Classification head: 768 features in, 5 classes out
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        # With return_dict=False, BertModel returns a tuple
        # (sequence_output, pooled_output); only the pooled output is used
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

I'm new to ML and I can't find any docs on how to debug TorchScript runtime errors, so I'd appreciate any help in solving this problem.
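
A note on reading the last line of the trace, offered as a hedged sketch rather than a confirmed diagnosis: for a tensor with n dimensions, PyTorch accepts dimension indices in the range [-n, n-1], so the range printed in a "Dimension out of range" message hints at the rank of the tensor that was actually passed in. The helper names below (valid_dim_range, rank_from_error) are hypothetical and written only for illustration; they use plain Python and are not part of PyTorch or tch-rs.

```python
import re

def valid_dim_range(ndim):
    """Return (lowest, highest) dimension index accepted for a tensor of rank ndim.

    Assumes ndim >= 1; an n-dimensional tensor accepts indices in [-n, n-1].
    """
    return -ndim, ndim - 1

def rank_from_error(message):
    """Infer the tensor rank implied by a 'Dimension out of range' message.

    Hypothetical helper: it only parses the text of the message and returns
    None when the message does not have the expected shape.
    """
    m = re.search(r"range of \[(-?\d+), (-?\d+)\], but got (-?\d+)", message)
    if m is None:
        return None
    low, high, got = (int(g) for g in m.groups())
    # The highest valid index is rank - 1, so rank = high + 1.
    return {"rank": high + 1, "requested_dim": got}

# Last line of the traceback quoted in the question.
msg = ("RuntimeError: Dimension out of range "
       "(expected to be in range of [-1, 0], but got 1)")

info = rank_from_error(msg)
print(info)
print("valid range for a rank-2 tensor:", valid_dim_range(2))
```

Read this way, the failing call torch.size(input_id, 1) most likely received a 1-D input_id, while the traced BERT forward reads both a batch size (dimension 0) and a sequence length (dimension 1). If that reading is right, the tensors built on the Rust side are probably missing the batch dimension, and comparing their shapes with the example inputs given to jit.trace (and adding the leading dimension, for example with unsqueeze(0)) would be a reasonable first thing to check.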
