Step 1: I first traced a RoBERTa model and saved it:

import torch
from transformers import RobertaForMaskedLM

batch_size = 4
# Dummy inputs for tracing: sequences of 128 token ids plus matching attention masks.
batched_indexed_tokens = [[101, 64]*64]*batch_size
batched_attention_masks = [[1, 1]*64]*batch_size
tokens_tensor = torch.tensor(batched_indexed_tokens)
attention_masks_tensor = torch.tensor(batched_attention_masks)

mlm_model_ts = RobertaForMaskedLM.from_pretrained('roberta-large', torchscript=True)
traced_mlm_model = torch.jit.trace(mlm_model_ts, [tokens_tensor, attention_masks_tensor])
torch.jit.save(traced_mlm_model, 'roberta-large.pt')
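As a quick sanity check on the trace, the traced module can be run on the same dummy tensors (with torchscript=True the model returns a tuple, logits first):

logits = traced_mlm_model(tokens_tensor, attention_masks_tensor)[0]
print(logits.shape)  # torch.Size([4, 128, 50265]) for roberta-large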
Step 2: Then I load the traced model and compile it into a TensorRT model:

import torch_tensorrt

traced_mlm_model = torch.jit.load("roberta-large.pt", map_location=torch.device("cuda"))
trt_model = torch_tensorrt.compile(
    traced_mlm_model.cuda(),
    inputs=[
        torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # input_ids
        torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # attention_mask
    ],
    enabled_precisions={torch.half},   # allow FP16 kernels
    workspace_size=2000000000,
    truncate_long_and_double=True,     # demote int64/float64 to int32/float32
)
torch.jit.save(trt_model, 'roberta-large_trt.trt')
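A quick smoke test of the compiled module works too (my sketch; the inputs have to be int32 CUDA tensors to match the Input specs above, and I assume the tuple output structure survives compilation):

trt_tokens = tokens_tensor.to(torch.int32).cuda()
trt_masks = attention_masks_tensor.to(torch.int32).cuda()
trt_logits = trt_model(trt_tokens, trt_masks)[0]
print(trt_logits.shape)  # should match the traced model's logits shape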
Step 3: I load the TensorRT model:
trt_model = torch.jit.load("roberta-large_trt.trt").cuda()
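Since the engine was compiled for a fixed [4, 128] int32 shape, calling it directly means padding real inputs out to that shape. A rough sketch (the tokenizer line is my assumption; it is simply the tokenizer matching roberta-large):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
# Pad to the static length 128 and replicate the sentence to fill the batch of 4.
enc = tokenizer(["Paris is the <mask> of France."] * batch_size,
                padding="max_length", max_length=128, return_tensors="pt")
input_ids = enc["input_ids"].to(torch.int32).cuda()
attention_mask = enc["attention_mask"].to(torch.int32).cuda()
logits = trt_model(input_ids, attention_mask)[0]
# Decode the top candidate at the <mask> position of the first row.
mask_pos = (enc["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()
print(tokenizer.decode(logits[0, mask_pos].argmax().item()))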
What I actually want, though, is to run inference on the model loaded in step 3 through the HF pipeline:

from transformers import pipeline

classifier = pipeline("fill-mask", model=trt_model, tokenizer=tokenizer)
classifier("Paris is the <mask> of France.")
But I get an error:
File "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/__init__.py", line 880, in pipeline
model_config = model.config
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_script.py", line 811, in __getattr__
return super().__getattr__(attr)
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_script.py", line 526, in __getattr__
return super().__getattr__(attr)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1695, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'RecursiveScriptModule' object has no attribute 'config'
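The traceback shows pipeline() reading model.config, an attribute torch.jit.load does not restore on a RecursiveScriptModule, which is easy to confirm:

print(hasattr(trt_model, "config"))  # False: only scripted attributes survive serialization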
When I instead pass the model name "roberta-large" to pipeline (so it loads the stock HF model itself), everything works fine.
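For completeness, this is the variant that does work (plain HF model, no TensorRT involved):

classifier = pipeline("fill-mask", model="roberta-large")
classifier("Paris is the <mask> of France.")  # returns a list of candidate fills with scores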