I am trying to fine-tune a GPT-2 model to extract a person's name pronunciation from a sentence: the input is the phonetic transcription of the whole sentence, and the output should be the name extracted from that transcription. I have pasted the code below.
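To make the goal concrete, after fine-tuning I would like to be able to call the helper defined further down in the script roughly like this (the output shown is the behaviour I am hoping for, not something I have observed):

# Hoped-for usage of the extract_phonetic_transcription helper defined below
sentence = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"  # phonetic transcription of the full sentence
name = extract_phonetic_transcription(model, tokenizer, sentence)
print(name)  # expected: "s ʌ m i ɹ z o w ʃ i" (just the name part)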
#!/usr/bin/env python
# coding: utf-8
# Lets import the relevant things for the model.
# In[48]:
import torch
from torch.utils.data import DataLoader,Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel,GPT2Config, AdamW
from torch.nn.utils.rnn import pad_sequence
# In[49]:
def collate_batch(batch):
    # Sort the batch by input sequence length
    batch = sorted(batch, key=lambda x: len(x['input_ids']))
    # Pad sequences to have the same length within the batch
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True)
    target_ids = pad_sequence([item['target_ids'] for item in batch], batch_first=True)
    return {'input_ids': input_ids, 'target_ids': target_ids}
# In[50]:
def fine_tune_model(train_dataset, model, tokenizer, epochs=3, learning_rate=1e-5):
    # Set the device (CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Create DataLoader for training dataset
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
    # Set up optimizer and loss function
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # Fine-tune the model
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].squeeze().to(device)
            target_ids = batch['target_ids'].squeeze().to(device)
            # Shift target_ids by one position to the right
            shifted_target_ids = torch.roll(target_ids, shifts=1, dims=1)
            # Check if tokenizer.pad_token_id is None and replace it with a valid token ID
            if tokenizer.pad_token_id is None:
                pad_token_id = 0  # Replace with a valid non-padding token ID
            else:
                pad_token_id = tokenizer.pad_token_id
            shifted_target_ids[:, 0] = pad_token_id  # Set the first position to the padding token
            # Forward pass
            outputs = model(input_ids, labels=shifted_target_ids)
            loss = outputs.loss
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        average_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {average_loss}")
    print("Fine-tuning complete!")
# In[51]:
def extract_phonetic_transcription(model, tokenizer, input_text):
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    # Generate output from the model
    with torch.no_grad():
        output_ids = model.generate(input_ids)
    # Decode the output and extract the relevant part
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Assuming the relevant part starts after the first space
    relevant_part = output_text.split(' ', 1)[1]
    return relevant_part
# In[52]:
class PhoneticTranscriptionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Tokenize input and target sequences
        input_text = sample['input']
        target_text = sample['target']
        input_ids = self.tokenizer.encode(input_text, return_tensors='pt').squeeze()
        target_ids = self.tokenizer.encode(target_text, return_tensors='pt').squeeze()
        return {'input_ids': input_ids, 'target_ids': target_ids}
# In[53]:
# Example usage
if __name__ == "__main__":
    # Your dataset with pairs of input and target sequences
    training_data = [
        {'input': "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i", 'target': "s ʌ m i ɹ z o w ʃ i"},
        {'input': "m a j n e j m ɪ z b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i", 'target': "b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i"}
        # Add more examples as needed
    ]
    # Load pre-trained GPT-2 model and tokenizer
    model_name = 'gpt2'
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # Create an instance of the dataset
    train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)
    # Fine-tune the model on your dataset (provide your own implementation)
    # train_dataset = ... # Your prepared dataset
    fine_tune_model(train_dataset, model, tokenizer)
    # Example input
    input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"
    # Extract relevant part using the model
    relevant_part = extract_phonetic_transcription(model, tokenizer, input_phonetic_transcription)
    print("Input Phonetic Transcription:", input_phonetic_transcription)
    print("Extracted Relevant Part:", relevant_part)
This is the error I see:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[53], line 23
19 train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)
21 # Fine-tune the model on your dataset (provide your own implementation)
22 # train_dataset = ... # Your prepared dataset
---> 23 fine_tune_model(train_dataset, model, tokenizer)
25 # Example input
26 input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"
Cell In[50], line 33, in fine_tune_model(train_dataset, model, tokenizer, epochs, learning_rate)
30 shifted_target_ids[:, 0] = pad_token_id # Set the first position to the padding token
32 # Forward pass
---> 33 outputs = model(input_ids, labels=shifted_target_ids)
34 loss = outputs.loss
36 # Backward pass and optimization
File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\transformers\models\gpt2\modeling_gpt2.py:1108, in GPT2LMHeadModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1106 # Flatten the tokens
1107 loss_fct = CrossEntropyLoss()
-> 1108 loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1110 if not return_dict:
1111 output = (lm_logits,) + transformer_outputs[1:]
File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174 return F.cross_entropy(input, target, weight=self.weight,
1175 ignore_index=self.ignore_index, reduction=self.reduction,
1176 label_smoothing=self.label_smoothing)
File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3027 if size_average is not None or reduce is not None:
3028 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (72) to match target batch_size (50).
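From the numbers in the traceback I suspect that, with the batch of 2, the padded input_ids come out 37 tokens long (2 x 36 = 72 positions after the model's internal shift) while the padded target_ids come out 26 tokens long (2 x 25 = 50 positions), because the inputs and targets are padded to different lengths and the targets are shorter, so the flattened logits and labels no longer line up. The following standalone sketch, with dummy shapes made up to match my traceback, reproduces what I believe is the same error:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Dummy tensors whose lengths mirror the numbers in the traceback (batch of 2):
# input_ids of length 37 -> 2 * 36 = 72 flattened logit positions after the shift
# labels of length 26    -> 2 * 25 = 50 flattened label positions after the shift
input_ids = torch.randint(0, model.config.vocab_size, (2, 37))
labels = torch.randint(0, model.config.vocab_size, (2, 26))

outputs = model(input_ids, labels=labels)
# ValueError: Expected input batch_size (72) to match target batch_size (50).

So is the fix to pad input_ids and target_ids to the same length, for example by building labels directly from the padded input_ids and masking the prompt and padding positions with -100 (which I understand the loss ignores by default), or is there a different way the labels are supposed to be constructed for GPT2LMHeadModel?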