How to add a simple custom pytorch-crf layer on top of a TokenClassification model using PyTorch and Trainer


I followed this link, but it's implemented in Keras:

Cannot add CRF layer on top of BERT in keras for NER

Model description

Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It would make the model more robust.

import torch
from torch import nn
from torch.nn.functional import log_softmax as log_soft  # forward() below calls log_soft; assuming it means log_softmax
from torchcrf import CRF
from transformers import (BertConfig, BertTokenizer, BertForTokenClassification,
                          Trainer, TrainingArguments)

model_checkpoint = "dslim/bert-base-NER"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint,add_prefix_space=True)
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
bert_model = BertForTokenClassification.from_pretrained(model_checkpoint,id2label=id2label,label2id=label2id,ignore_mismatched_sizes=True)


class BERT_CRF(nn.Module):
    
    def __init__(self, bert_model, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)
        
        self.classifier = nn.Linear(4*768, num_labels)

        self.crf = CRF(num_labels, batch_first = True)
    
    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        # this is the line that raises the error (see below)
        sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1)
        sequence_output = self.dropout(sequence_output)

        emission = self.classifier(sequence_output)  # [batch, seq_len, num_labels], e.g. [32, 256, 17]

        if labels is not None:
            labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
            loss = -self.crf(log_soft(emission, 2), labels, mask=attention_mask.type(torch.uint8), reduction='mean')
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return [loss, prediction]
        else:
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return prediction

args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32
    fp16=True
    # bf16=True #Ampere GPU
)

trainer = Trainer(
    model=model,  # model is assumed to be an instance of the BERT_CRF class above
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer)

I get an error on the line sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1).

Since outputs = self.bert(input_ids, attention_mask=attention_mask) only gives the logits for token classification, how can I get the hidden states so that I can concatenate the last four of them, i.e. so that outputs[1][-1] works?

Or is there an easier way to implement a BERT-CRF model?
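
For context, what I think I need is something along these lines, i.e. asking the wrapped model for its hidden states inside forward() (just a sketch of what I mean, not something I have working with Trainer):

# sketch only: request all hidden states from the wrapped token-classification model
outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)

# outputs.hidden_states is a tuple of (num_layers + 1) tensors,
# each of shape [batch, seq_len, hidden_size]
sequence_output = torch.cat(outputs.hidden_states[-4:], dim=-1)  # [batch, seq_len, 4*768]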

1 Answer

Andrei Pop (accepted answer)

I know it's 10 months later, but maybe it helps others.

Here is what I used with the Trainer; it works with hyperparameter_search too:

import torch
from torch import nn
from torchcrf import CRF
from transformers import AutoModel, BertConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput


class BERT_CRF_Config(PretrainedConfig):
    model_type = "BERT_CRF"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_name = "BERT_CRF"
        self.use_last_n_hidden_states = 1
        self.dropout = 0.5

class BERT_CRF(PreTrainedModel):
    config_class = BERT_CRF_Config

    def __init__(self, config):
        super().__init__(config)

        bert_config = BertConfig.from_pretrained(config.bert_name)

        bert_config.output_attentions = True
        bert_config.output_hidden_states = True

        self.bert = AutoModel.from_pretrained(config.bert_name, config=bert_config)

        self.dropout = nn.Dropout(p=config.dropout)

        self.linear = nn.Linear(
            self.bert.config.hidden_size*config.use_last_n_hidden_states, config.num_labels)
        
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self,  input_ids = None, attention_mask = None, labels = None,
                labels_mask=None,  token_type_ids=None, return_dict = None, **kwargs):

        # convert any non-tensor inputs (e.g. lists from the data collator) to tensors,
        # skipping arguments that were not passed at all
        if input_ids is not None and not torch.is_tensor(input_ids):
            input_ids = torch.tensor(input_ids).to(self.device)

        if token_type_ids is not None and not torch.is_tensor(token_type_ids):
            token_type_ids = torch.tensor(token_type_ids).to(self.device)

        if attention_mask is not None and not torch.is_tensor(attention_mask):
            attention_mask = torch.tensor(attention_mask).to(self.device)

        if labels is not None and not torch.is_tensor(labels):
            labels = torch.tensor(labels).to(self.device)

        if labels_mask is not None and not torch.is_tensor(labels_mask):
            labels_mask = torch.tensor(labels_mask).to(self.device)

        bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, 
                                attention_mask=attention_mask)
        # last_hidden_layer = bert_output['last_hidden_state']
        # logits = self.linear(last_hidden_layer)

        last_hidden_layers = torch.cat(bert_output['hidden_states'][-self.config.use_last_n_hidden_states:], dim=2)
        last_hidden_layers = self.dropout(last_hidden_layers)
        logits = self.linear(last_hidden_layers)

        # pad the variable-length lists returned by crf.decode() into one tensor
        def to_tensor(x):
          x = list(map(lambda y: torch.as_tensor(y), x))
          x = torch.nested.as_nested_tensor(x)
          x = torch.nested.to_padded_tensor(x,padding=0)

          x = torch.clamp(x, min=0)

          return x

        if labels is not None:
          log_likelihood, outputs = (
                                     self.crf(logits, labels, mask=labels_mask.bool()), 
                                     self.crf.decode(logits, mask=labels_mask.bool())
                                    )
          outputs = to_tensor(outputs)
          loss = -log_likelihood
          if not return_dict:
            return loss, outputs
          else:
            return TokenClassifierOutput(
                loss=loss,
                logits=outputs,
                hidden_states=bert_output.hidden_states,
                attentions=bert_output.attentions,
            )
        
        # CRF.decode() takes an optional mask; batch_first is a constructor argument, not a decode argument
        outputs = self.crf.decode(logits, mask=attention_mask.bool() if attention_mask is not None else None)
        outputs = to_tensor(outputs)

        return outputs

    @property
    def device(self):
        return next(self.parameters()).device
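
For a single training run (outside hyperparameter_search), the model can be built directly from the config. The checkpoint name and label count below are placeholders rather than values from the original answer:

BERT_MODEL = "bert-base-cased"      # placeholder encoder checkpoint
NR_LABELS = len(label2id)           # assumes the same label2id mapping as in the question

config = BERT_CRF_Config.from_pretrained(BERT_MODEL, num_labels=NR_LABELS)
config.bert_name = BERT_MODEL             # tells BERT_CRF which encoder to load
config.use_last_n_hidden_states = 4       # concatenate the last 4 hidden states
config.dropout = 0.25

model = BERT_CRF(config)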

And for the hyperparameter search you can use something like this:

def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_categorical("learning_rate", [1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5]),
        "warmup_ratio": trial.suggest_categorical("warmup_ratio", [0, 0.1, 0.2, 0.3]),
        "weight_decay": trial.suggest_categorical("weight_decay", [1e-6, 1e-5, 1e-4]),
        "max_grad_norm": trial.suggest_categorical("max_grad_norm", [8, 9, 10, 11]),
    }

def model_init_crf(trial):
    config = BERT_CRF_Config.from_pretrained(BERT_MODEL, num_labels=NR_LABELS)
    config.bert_name = BERT_MODEL
    config.dropout = trial.suggest_categorical("dropout", [0, 0.10, 0.30, 0.50])
    config.use_last_n_hidden_states = trial.suggest_categorical("last_n_hidden_states",
                                                                list(range(1, config.num_hidden_layers + 1)))

    model = BERT_CRF(config).to('cuda')
    return model

best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=50,
    compute_objective=my_objective,
)
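
One detail the call above relies on: for hyperparameter_search the Trainer must be created with a model_init callback instead of a fixed model, and my_objective is whatever metric aggregation you already use. A minimal sketch of the wiring, with the output directory, eval_data and compute_metrics as assumed placeholders:

args = TrainingArguments(
    "bert_crf_hpo",                   # placeholder output directory
    save_strategy="epoch",
    per_device_train_batch_size=8,
    num_train_epochs=1,
)

trainer = Trainer(
    model=None,                       # no fixed model; each trial builds its own
    model_init=model_init_crf,        # called with the Optuna trial object once per trial
    args=args,
    train_dataset=train_data,
    eval_dataset=eval_data,           # assumed evaluation split, needed to score trials
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,  # assumed metric function (e.g. seqeval-based)
)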