Issue with fine-tuning GPT-2 for an IPA transcription model


I am trying to fine-tune a GPT-2 model to extract a person's name from a sentence, working entirely with phonetic (IPA) transcriptions: the input is the IPA transcription of the full sentence, and the output should be the IPA transcription of the name it contains. I have pasted the code below.
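
For example, the first training pair in the code below maps the IPA transcription of a whole sentence to the IPA transcription of just the name it contains:

{'input': "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i", 'target': "s ʌ m i ɹ z o w ʃ i"}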

#!/usr/bin/env python
# coding: utf-8

# Let's import the relevant things for the model.

# In[48]:


import torch
from torch.utils.data import DataLoader,Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel,GPT2Config, AdamW
from torch.nn.utils.rnn import pad_sequence


# In[49]:


def collate_batch(batch):
    # Sort the batch by input sequence length
    batch = sorted(batch, key=lambda x: len(x['input_ids']))

    # Pad sequences to have the same length within the batch
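    # (pad_sequence pads with 0 by default; GPT-2 has no dedicated pad token, so id 0 is an ordinary vocabulary token)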
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True)
    target_ids = pad_sequence([item['target_ids'] for item in batch], batch_first=True)

    return {'input_ids': input_ids, 'target_ids': target_ids}


# In[50]:


def fine_tune_model(train_dataset, model, tokenizer, epochs=3, learning_rate=1e-5):
    # Set the device (CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Create DataLoader for training dataset
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,  collate_fn=collate_batch)

    # Set up optimizer and loss function
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Fine-tune the model
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            input_ids = batch['input_ids'].squeeze().to(device)
            target_ids = batch['target_ids'].squeeze().to(device)
            
            # Shift target_ids by one position to the right
            shifted_target_ids = torch.roll(target_ids, shifts=1, dims=1)
            # Check if tokenizer.pad_token_id is None and replace it with a valid token ID
            if tokenizer.pad_token_id is None:
                pad_token_id = 0  # Replace with a valid non-padding token ID
            else:
                pad_token_id = tokenizer.pad_token_id

            shifted_target_ids[:, 0] = pad_token_id  # Set the first position to the padding token

            # Forward pass
            outputs = model(input_ids, labels=shifted_target_ids)
            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {average_loss}")

    print("Fine-tuning complete!")


# In[51]:


def extract_phonetic_transcription(model, tokenizer, input_text):
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate output from the model
    with torch.no_grad():
        output_ids = model.generate(input_ids)

    # Decode the output and extract the relevant part
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Assuming the relevant part starts after the first space
    relevant_part = output_text.split(' ', 1)[1]

    return relevant_part


# In[52]:


class PhoneticTranscriptionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Tokenize input and target sequences
        input_text = sample['input']
        target_text = sample['target']

        input_ids = self.tokenizer.encode(input_text, return_tensors='pt').squeeze()
        target_ids = self.tokenizer.encode(target_text, return_tensors='pt').squeeze()

        return {'input_ids': input_ids, 'target_ids': target_ids}


# In[53]:


# Example usage
if __name__ == "__main__":
    
    # Your dataset with pairs of input and target sequences
    training_data = [
        {'input': "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i", 'target': "s ʌ m i ɹ z o w ʃ i"},
        {'input': "m a j n e j m ɪ z b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i", 'target': "b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i"}
        # Add more examples as needed
    ]
    
 

    # Load pre-trained GPT-2 model and tokenizer
    model_name = 'gpt2'
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    
    # Create an instance of the dataset
    train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)

    # Fine-tune the model on your dataset (provide your own implementation)
    # train_dataset = ...  # Your prepared dataset
    fine_tune_model(train_dataset, model, tokenizer)

    # Example input
    input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"

    # Extract relevant part using the model
    relevant_part = extract_phonetic_transcription(model, tokenizer, input_phonetic_transcription)

    print("Input Phonetic Transcription:", input_phonetic_transcription)
    print("Extracted Relevant Part:", relevant_part)


This is the error I see:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[53], line 23
     19 train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)
     21 # Fine-tune the model on your dataset (provide your own implementation)
     22 # train_dataset = ...  # Your prepared dataset
---> 23 fine_tune_model(train_dataset, model, tokenizer)
     25 # Example input
     26 input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"

Cell In[50], line 33, in fine_tune_model(train_dataset, model, tokenizer, epochs, learning_rate)
     30 shifted_target_ids[:, 0] = pad_token_id  # Set the first position to the padding token
     32 # Forward pass
---> 33 outputs = model(input_ids, labels=shifted_target_ids)
     34 loss = outputs.loss
     36 # Backward pass and optimization

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\transformers\models\gpt2\modeling_gpt2.py:1108, in GPT2LMHeadModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1106     # Flatten the tokens
   1107     loss_fct = CrossEntropyLoss()
-> 1108     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
   1110 if not return_dict:
   1111     output = (lm_logits,) + transformer_outputs[1:]

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (72) to match target batch_size (50).
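
From the numbers in the error, I think the problem is that labels must have the same shape as input_ids for GPT2LMHeadModel (the model shifts the labels internally), but here the sentence and the name are tokenized separately, so shifted_target_ids is shorter than input_ids. Below is a minimal, untested sketch of how the dataset could build one sequence per example instead, reusing the tokenizer and imports from the code above; the class name PromptCompletionDataset and the use of eos_token as a separator are just illustrative choices, and the prompt positions are masked with -100 so they are ignored by the loss:

class PromptCompletionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        eos = self.tokenizer.eos_token

        # Tokenize the sentence (prompt) and the name (completion) separately
        prompt_ids = self.tokenizer.encode(sample['input'] + eos)
        name_ids = self.tokenizer.encode(sample['target'] + eos)

        # One sequence per example: the prompt followed by the name
        input_ids = prompt_ids + name_ids
        # Same length as input_ids; positions set to -100 are ignored by the loss,
        # so the model is only trained to predict the name tokens
        labels = [-100] * len(prompt_ids) + name_ids

        return {'input_ids': torch.tensor(input_ids),
                'target_ids': torch.tensor(labels)}

With this layout, collate_batch would also need to pad target_ids with padding_value=-100, and the training loop could simply call model(input_ids, labels=target_ids) without the torch.roll shift, since the model shifts the labels itself. Is this the right way to set up the labels for this task, or is there a better approach?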
