TVM compilation fails with mismatched matrix shapes

41 Views Asked by At

I have defined the following BERT-based model using PyTorch:

 class BERTGRUSentiment(nn.Module):
     """Sentiment classifier: BERT token embeddings -> GRU -> linear head.

     Args:
         bert: encoder module. Must expose ``config.to_dict()['hidden_size']``
             and, when called on token ids, return a sequence whose first
             element is the per-token embedding tensor.
         hidden_dim: GRU hidden size per direction.
         output_dim: number of output values produced by the linear head.
         n_layers: number of stacked GRU layers.
         bidirectional: whether the GRU runs in both directions.
         dropout: dropout probability applied to the pooled hidden state, and
             between GRU layers when ``n_layers >= 2``.
     """

     def __init__(self,
                  bert,
                  hidden_dim,
                  output_dim,
                  n_layers,
                  bidirectional,
                  dropout=0):
         super().__init__()

         self.bert = bert

         embedding_dim = bert.config.to_dict()['hidden_size']

         # Use the ``bidirectional`` argument, so the GRU's output width
         # always matches the input width that ``self.out`` is built for.
         self.rnn = nn.GRU(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           batch_first=True,
                           # GRU dropout only acts between stacked layers.
                           dropout=0 if n_layers < 2 else dropout)

         self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)

         self.dropout = nn.Dropout(dropout)

     def forward(self, text):
         """Return a [batch size, output_dim] tensor for token ids ``text``.

         ``text`` is expected to be [batch size, sent len] token ids.
         """
         # No gradients are computed through the BERT call.
         with torch.no_grad():
             embedded = self.bert(text)[0]

         # embedded = [batch size, sent len, emb dim]

         _, hidden = self.rnn(embedded)

         # hidden = [n layers * n directions, batch size, hid dim]
         # (batch_first does not apply to the returned hidden state).
         # Keep only the last layer: for a bidirectional GRU its forward and
         # backward final states are the last two entries.
         if self.rnn.bidirectional:
             hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
         else:
             hidden = hidden[-1, :, :]

         # hidden = [batch size, hid dim * n directions]

         hidden = self.dropout(hidden)

         output = self.out(hidden)

         # output = [batch size, out dim]

         return output

I am able to train and run this model without any issues.

However, when I try to compile the model with Apache TVM as follows:



  import sys

  import torch
  import tvm
  from transformers import BertModel
  from tvm import relay

  import model

  # Deep graphs such as BERT may exceed Python's default recursion limit
  # while being converted.
  sys.setrecursionlimit(1000000)

  device = "cuda"

  HIDDEN_DIM = 256
  OUTPUT_DIM = 1
  N_LAYERS = 2
  BIDIRECTIONAL = True
  DROPOUT = 0

  bert = BertModel.from_pretrained('bert-base-uncased')

  # Give the wrapper its own name instead of rebinding ``bert``.
  sentiment_model = model.BERTGRUSentiment(bert, HIDDEN_DIM, OUTPUT_DIM,
                                           N_LAYERS, BIDIRECTIONAL, DROPOUT)
  sentiment_model.eval()
  sentiment_model.to(device)

  print(f"cuda: {next(sentiment_model.parameters()).is_cuda}")
  print(f"training: {sentiment_model.training}")

  # Tracing needs the example input on the same device as the model.
  example = model.preprocess(model.tokenizer, "it was ok").unsqueeze(0).to(device)

  scripted_model = torch.jit.trace(sentiment_model, example).eval()

  print(scripted_model.graph)

  # Describe the input to Relay exactly as it was traced.
  # [1, max_position_embeddings] (= [1, 512]) need not match the traced
  # example's length. Without an explicit dtype, Relay also assumes float32,
  # although the input holds integer token ids.
  input_name = "text"
  input_shape = list(example.shape)
  input_dtype = str(example.dtype).replace("torch.", "")
  shape_list = [(input_name, (input_shape, input_dtype))]

  print("starting")

  mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

  print(mod)

  target = tvm.target.cuda()

  with tvm.transform.PassContext(opt_level=3):
      lib = relay.build(mod, target=target, params=params)

  lib.export_library("compiled.so")
 52 

I get the following error:

Traceback (most recent call last):
  File "compile.py", line 42, in <module>
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 5008, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 4272, in convert_operators
    _get_input_types(op_node, outputs, default_dtype=self.default_dtype),
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 1783, in linear
    [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 1976, in matmul
    raise AssertionError(msg)
AssertionError: Tensors being multiplied do not have compatible shapes.

I have added logging to the TVM source files, and it seems that the shapes in question are (2, 2, 256) and (1, 512). The error also seems to occur near the GRU layer. Are there any additional steps I can take to debug or eliminate this problem?

0

There are 0 best solutions below