So I was trying to train a chatbot using transformers for my AI assistant, and I thought the simpletransformers package in Python would help me speed up a lot of my tasks. I found a good dataset on Kaggle (https://www.kaggle.com/datasets/arnavsharmaas/chatbot-dataset-topical-chat) to train the chatbot, loaded the data, did some preprocessing, and transformed it into one column input_text and another target_text, as described in the docs. Then I trained the model with roberta as the encoder type and bert as the decoder, since that is what is selected by default and, as far as I could tell from the docs, it cannot be changed. I trained on just the first 1k samples to check whether the code works. On the first try I gave it one line from the dataset and it simply spammed the word "my"; the result was #mymymymymy.
I restarted my runtime and trained again, and this time it always generated an empty string. I was expecting proper results. Here are the code snippets.

Loading and preprocessing the data:
import pandas as pd

df = pd.read_csv("../input/chatbot-dataset-topical-chat/topical_chat.csv")

# Convert to the two-column format expected by Seq2SeqModel:
# even-indexed messages become input_text, odd-indexed ones target_text
new_df = {"input_text": [], "target_text": []}
for i in range(df.shape[0]):
    if i % 2 == 0:
        new_df["input_text"].append(df["message"][i])
    else:
        new_df["target_text"].append(df["message"][i])
new_df = pd.DataFrame(new_df)
new_df.head()
It works up to here. The code for training the transformer is:
!pip install simpletransformers

from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

model_args = Seq2SeqArgs()
model_args.num_train_epochs = 3
model_args.overwrite_output_dir = True

# roberta-base as the encoder, bert-base-cased as the decoder
model = Seq2SeqModel(
    "roberta",
    "roberta-base",
    "bert-base-cased",
    args=model_args,
)

model.train_model(new_df.head(1000))
Finally, I asked it to predict a sample from the dataframe. As I said, it once spammed a single word, and after the restart it produces an empty string. Can anyone help me, please?
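The prediction call was roughly like this (the exact sample index doesn't matter):

# generate a reply for one input sentence from the dataframe
print(model.predict([new_df["input_text"][0]]))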
If it is a seq2seq mapping problem, I find that BART does a better job than RoBERTa.
Also, preferably use more samples in your training data. Here is a script with BART to get you started:
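This is only a minimal sketch based on the simpletransformers Seq2Seq docs; it assumes facebook/bart-base as the pretrained checkpoint and reuses the new_df dataframe from your preprocessing step, so adjust the paths, split size, and batch size to your setup.

from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

# Reuse the two-column dataframe (input_text / target_text) built earlier;
# keep a small slice aside for evaluation
train_df = new_df.iloc[:-500]
eval_df = new_df.iloc[-500:]

model_args = Seq2SeqArgs()
model_args.num_train_epochs = 3
model_args.overwrite_output_dir = True
model_args.evaluate_during_training = True
model_args.max_seq_length = 128
model_args.train_batch_size = 8

# For BART, a single pretrained encoder-decoder checkpoint is passed
# instead of separate encoder and decoder names
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-base",
    args=model_args,
)

model.train_model(train_df, eval_data=eval_df)

# Quick sanity check on one held-out input
print(model.predict([eval_df["input_text"].iloc[0]]))

If the generations still collapse to a single repeated token or an empty string, train on more than 1k pairs and for more epochs; seq2seq models usually need considerably more data before they produce sensible replies.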