I am trying to use torch multiprocessing to parallelize the predictions from two separate Hugging Face DistilBERT classification models. The program appears to deadlock at the prediction step. I am using Python 3.6.5, torch 1.5.0 and Hugging Face transformers version 2.11.0. The output from running the code is:
Tree enc done
Begin tree prediction<------(Comment: Both begin tree
End tree predictions<------- and end tree predictions)
0.03125429153442383
Dn prediction
Dn enc done
Begin dn predictions<------(Comment: Both begin dn
End dn predictions<------- and end dn predictions)
0.029727697372436523
----------Done sequential predictions-------------
--------Start Parallel predictions--------------
Tree prediction
Tree enc done
Begin tree prediction. <------(Comment: Process is deadlocked after this)
Dn prediction
Dn enc done
Begin dn predictions. <-------(Comment: Process is deadlocked after this)
During the parallel predictions, both child processes appear to deadlock: neither "End tree predictions" nor "End dn predictions" is ever printed. I am not sure why this is happening. The code is:
import torch
import torch.multiprocessing as mp
import time
import transformers
from transformers import DistilBertForSequenceClassification
# Load the BERT tokenizer.
from transformers import BertTokenizer
# Tokenizer used for every sentence, whichever model consumes the result.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Two separately loaded two-label DistilBERT classifiers. `eval()` returns the
# module itself, so each model is switched to inference mode as it is built.
tree_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
).eval()
dn_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
).eval()

# Move the parameters of both models into shared memory before any worker
# process is created.
tree_model.share_memory()
dn_model.share_memory()
def predict(sentences=(), tokenizer=tokenizer, models=(tree_model, dn_model, None)):
    """Tokenize `sentences` and return the logits of one classifier.

    Args:
        sentences: Iterable of sentences to classify. Must not be empty.
        tokenizer: Object exposing `encode_plus`; each sentence is padded to
            16 tokens and returned as PyTorch tensors.
        models: Sequence whose first entry is the "tree" model and whose
            second entry is the "dn" model; either may be None. Entries past
            the second are ignored. When both models are given, the tree
            model is the one that runs.

    Returns:
        The first element of the model output, detached and moved to the CPU
        (presumably logits of shape (len(sentences), num_labels) -- confirm
        against the model in use).

    Raises:
        ValueError: If `sentences` is empty or neither model is provided.
    """
    MAX_SENTENCE_LENGTH = 16
    start = time.time()

    tree_model, dn_model = models[0], models[1]
    # Explicit None checks: the truth value of a module is not a reliable
    # "was a model passed?" signal (container modules define __len__).
    has_tree = tree_model is not None
    has_dn = dn_model is not None
    if not (has_tree or has_dn):
        raise ValueError("predict() needs a model in models[0] or models[1]")

    sentences = list(sentences)
    if not sentences:
        # torch.cat() on an empty list fails with a far less helpful message.
        raise ValueError("predict() needs at least one sentence")

    if has_tree:
        print("Tree prediction")
    if has_dn:
        print("Dn prediction")

    encoded = [
        tokenizer.encode_plus(
            sent,
            add_special_tokens=True,
            max_length=MAX_SENTENCE_LENGTH,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        for sent in sentences
    ]

    if has_tree:
        print("Tree enc done")
    if has_dn:
        print("Dn enc done")

    # Stack the per-sentence tensors into one batch; the attention mask
    # distinguishes real tokens from padding.
    input_ids = torch.cat([item['input_ids'] for item in encoded], dim=0)
    attention_masks = torch.cat([item['attention_mask'] for item in encoded], dim=0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if has_tree:
            print("Begin tree prediction")
            outputs = tree_model(input_ids, attention_mask=attention_masks)
            print("End tree predictions")
        else:
            print("Begin dn predictions")
            outputs = dn_model(input_ids, attention_mask=attention_masks)
            print("End dn predictions")

    logits = outputs[0].detach().cpu()
    print(time.time() - start)
    return logits
def get_tree_prediction(sentence, tokenizer=tokenizer, models=(tree_model, dn_model, None)):
    """Classify a single sentence by delegating to `predict`.

    The sentence is wrapped in a one-element list; `tokenizer` and `models`
    are forwarded unchanged, and whatever `predict` returns is returned.
    """
    batch = [sentence]
    return predict(batch, tokenizer, models)
def get_dn_prediction(sentence, tokenizer=tokenizer, models=(tree_model, dn_model, None)):
    """Classify a single sentence by delegating to `predict`.

    The sentence is wrapped in a one-element list; `tokenizer` and `models`
    are forwarded unchanged, and whatever `predict` returns is returned.

    NOTE: `predict` runs `models[0]` whenever it is set, so with the default
    `models` this call uses the tree model. Pass `(None, dn_model, None)` to
    get predictions from the dn model.
    """
    batch = [sentence]
    return predict(batch, tokenizer, models)
if __name__ == '__main__':
    # Deadlock workaround. The sequential calls below run a forward pass in
    # this (parent) process, which starts PyTorch's intra-op thread pool.
    # Where multiprocessing creates children by forking, a child inherits the
    # pool's bookkeeping but none of its worker threads, so its first forward
    # pass can wait forever -- the most likely cause of the hang right after
    # "Begin tree prediction" / "Begin dn predictions". Restricting intra-op
    # work to a single thread before any forward pass avoids that state.
    # (Selecting the "spawn" start method is the usual alternative, at the
    # cost of re-importing this module in every child.)
    torch.set_num_threads(1)

    sentence = "hello world"

    # Sequential baseline: one prediction per model in the parent process.
    get_tree_prediction(sentence, tokenizer, (tree_model, None, None))
    get_dn_prediction(sentence, tokenizer, (None, dn_model, None))
    print("----------Done sequential predictions-------------")

    print('\n--------Start Parallel predictions--------------')
    processes = []

    tr_p = mp.Process(
        target=get_tree_prediction,
        args=(sentence, tokenizer, (tree_model, None, None)),
    )
    tr_p.start()
    processes.append(tr_p)

    dn_p = mp.Process(
        target=get_dn_prediction,
        args=(sentence, tokenizer, (None, dn_model, None)),
    )
    dn_p.start()
    processes.append(dn_p)

    # Wait for both workers to finish before the script exits.
    for p in processes:
        p.join()