I have a few questions about transfer learning from a pretrained BART model.
My plan, in short:
- create a tf.keras.Model that summarizes text
- reuse the pretrained BART model by feeding its last hidden state into my own layers.
I then want to train the new model on my CSV data, fine-tune it, save it, and load it back with AutoModelForSeq2SeqLM.from_pretrained() so that I can use the generate method.
Questions:
- Can I pass the entire BART model instead of last_hidden_state? (I've seen people do this. Why?)
- How can I make model.fit train on X and Y (features and labels)?
Can someone point me in the right direction? How can the code below be changed to include labels? As it stands, it looks more like a classification model than one that learns from target texts.
If there's a problem with my logic, I apologize; I'm new to this.
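import os

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from transformers import PreTrainedTokenizerFast, TFBartModel

df_train = pd.read_csv('./training_data/train.csv', index_col=False)
X_train = df_train.drop("target", axis=1)
Y_train = df_train["target"]

model_dir = "bart_base_transformers"
trained_model = TFBartModel.from_pretrained(model_dir, from_pt=True)
# skip_special_tokens is a decode() option, not a constructor argument, so it is dropped here
tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(model_dir, "tokenizer.json"))
# Assumes tokenizer.json holds the stock BART vocabulary, where '<pad>' already exists;
# a new '[PAD]' token would get an id outside the pretrained embedding matrix
tokenizer.add_special_tokens({'pad_token': '<pad>'})

# One length for both the tokenizer and the Keras inputs (the original mixed 1200 and 36544).
# Note: stock bart-base supports at most 1024 positions, so 1200 may be too long.
max_len = 1200
xxx_train = tokenizer(X_train["source"].tolist(), max_length=max_len, padding="max_length", truncation=True, return_tensors='tf')

# Freeze the pretrained weights
for layer in trained_model.layers:
    layer.trainable = False

input_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
attention_mask = Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")
# Dictionary keys match the Input names and the tokenizer's output keys
bart_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
model_output = trained_model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = model_output.last_hidden_state
y = Dense(tokenizer.vocab_size)(embeddings)
new_model = Model(inputs=bart_inputs, outputs=y)
# What comes next to fit the model with X_train and Y_train?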
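
On the labels question, here is my current understanding as a minimal sketch, not a confirmed solution. A seq2seq model like BART is trained with teacher forcing, so the tokenized target is used twice: shifted one position to the right it becomes decoder_input_ids, and unshifted (with padding replaced by -100 so the loss can skip it) it becomes labels. The ids below (pad = 1, decoder start = 2) are what I believe the stock BART config uses, and the helper names and example token ids are made up for illustration.

```python
# Assumptions: stock BART ids; check them against your own config.json
PAD_ID = 1            # assumed id of '<pad>'
DECODER_START_ID = 2  # assumed decoder start token id
IGNORE_INDEX = -100   # conventional "skip this position in the loss" marker


def mask_padding(target_ids, pad_id):
    """Replace padding ids with IGNORE_INDEX so they do not count in the loss."""
    return [[IGNORE_INDEX if t == pad_id else t for t in row] for row in target_ids]


def shift_right(labels, pad_id, decoder_start_id):
    """Build decoder inputs: prepend the start token, drop the last position,
    and turn any IGNORE_INDEX back into a real padding id."""
    shifted = []
    for row in labels:
        new_row = [decoder_start_id] + row[:-1]
        shifted.append([pad_id if t == IGNORE_INDEX else t for t in new_row])
    return shifted


# Hypothetical tokenized target: <s> tok tok </s> <pad> <pad>
target_ids = [[0, 713, 16, 2, 1, 1]]

labels = mask_padding(target_ids, PAD_ID)
decoder_input_ids = shift_right(labels, PAD_ID, DECODER_START_ID)

print(labels)             # [[0, 713, 16, 2, -100, -100]]
print(decoder_input_ids)  # [[2, 0, 713, 16, 2, 1]]
```

If this is right, Y_train would be tokenized like the source texts, decoder_input_ids would go into the BART call alongside input_ids and attention_mask, and the output over the vocabulary would be compared against labels with a loss that skips positions marked -100. As far as I can tell, transformers also ships TFBartForConditionalGeneration, which accepts labels directly, and TFAutoModelForSeq2SeqLM as the TensorFlow counterpart of AutoModelForSeq2SeqLM, so that may be the simpler route to generate.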