huggingface distillbert classification using multiprocessing

1.7k Views Asked by At

I am trying to use torch multiprocessing to parallelize the predictions from two separate huggingface distillbert classification models. It seems to be deadlocked at the prediction step. I am using python 3.6.5, torch 1.5.0 and huggingface transformers version 2.11.0. The output from running the code is

Tree enc done
Begin tree prediction<------(Comment: Both begin tree
End tree predictions<-------  and end tree predictions)
0.03125429153442383
Dn prediction
Dn enc done
Begin dn predictions<------(Comment: Both begin dn
End dn predictions<-------  and end dn predictions)
0.029727697372436523
----------Done sequential predictions-------------

--------Start Parallel predictions--------------
Tree prediction
Tree enc done
Begin tree prediction. <------(Comment: Process is deadlocked after this)
Dn prediction
Dn enc done
Begin dn predictions. <-------(Comment: Process is deadlocked after this)

During parallel predictions it seems to be deadlocking and not printing out "End tree predictions" and "End dn predictions". Not sure why this is happening. The code is

import torch
import torch.multiprocessing as mp
import time
import transformers
from transformers import  DistilBertForSequenceClassification

# Load the BERT tokenizer.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

tree_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels = 2,
    output_attentions = False, 
    output_hidden_states = False
)
tree_model.eval()

dn_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels = 2,
    output_attentions = False, 
    output_hidden_states = False, 
)
dn_model.eval()


tree_model.share_memory()
dn_model.share_memory()


def predict(sentences =[], tokenizer=tokenizer,models=(tree_model,dn_model,None)):
  MAX_SENTENCE_LENGTH = 16
  start = time.time()
  input_ids = []
  attention_masks = []
  predictions = []

  tree_model = models[0]
  dn_model = models[1]

  if models[0]:
      print("Tree prediction")
  if models[1]:
    print("Dn prediction")
  for sent in sentences:
    encoded_dict = tokenizer.encode_plus(
                        sent,                      
                        add_special_tokens = True, 
                        max_length = MAX_SENTENCE_LENGTH,
                        pad_to_max_length = True,
                        return_attention_mask = True,   
                        return_tensors = 'pt',     
                   )

    # Add the encoded sentence to the list.    
    input_ids.append(encoded_dict['input_ids'])

    # And its attention mask (simply differentiates padding from non-padding).
    attention_masks.append(encoded_dict['attention_mask'])

  if tree_model:
      print("Tree enc done")
  if dn_model:
    print("Dn enc done")

  # Convert the lists into tensors.
  new_input_ids = torch.cat(input_ids, dim=0)
  new_attention_masks = torch.cat(attention_masks, dim=0)

  with torch.no_grad():
      # Forward pass, calculate logit predictions
    if tree_model:
      print("Begin tree prediction")
      outputs = tree_model(new_input_ids, 
                      attention_mask=new_attention_masks) 
      print("End tree predictions")
    else:
      print("Begin dn predictions")
      outputs = dn_model(new_input_ids, 
                      attention_mask=new_attention_masks)
      print("End dn predictions")
  logits = outputs[0]
  logits = logits.detach().cpu()
  print(time.time()-start)
  predictions = logits
  return predictions



def get_tree_prediction(sentence, tokenizer=tokenizer,models=(tree_model,dn_model, None)):
    return predict(sentences =[sentence], tokenizer=tokenizer,models=models)

def get_dn_prediction(sentence, tokenizer=tokenizer,models=(tree_model,dn_model, None)):
  return predict(sentences =[sentence], tokenizer=tokenizer,models=models)


if __name__ == '__main__':
    sentence = "hello world"
    processes = []
    get_tree_prediction(sentence, tokenizer, (tree_model,None,None))
    get_dn_prediction(sentence, tokenizer, (None,dn_model,None))
    print("----------Done sequential predictions-------------")

    print('\n--------Start Parallel predictions--------------')
    tr_p = mp.Process(target=get_tree_prediction, args=(sentence, tokenizer,
                                                         (tree_model,None,None)))

    tr_p.start()
    processes.append(tr_p)

    dn_p = mp.Process(target=get_dn_prediction, args=(sentence, tokenizer,
                                                       (None,dn_model,None)))
    dn_p.start()
    processes.append(dn_p)

    for p in processes:
        p.join()
0

There are 0 best solutions below