I was wondering how you can resume training from a checkpoint with a different hyperparameter config when training with the transformers library. Given the example below, no matter what you change in the training args, the changes are overridden by whatever training args were saved in the checkpoint. The transformers library does not provide a way to change training arguments when resuming from a checkpoint. Some settings, such as the eval, batch size and save-steps options, can be overridden by amending the checkpoint's JSON config, but other hyperparameters cannot.
Given a non-PEFT model, you could just save the entire model from the checkpoint, load it back up, and call trainer.train() on it to achieve this behaviour, but with a PEFT setup I'm not sure how to do this.
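For the non-PEFT case I mean something like the following (just a rough sketch, assuming the checkpoint directory contains the full model weights and reusing the datasets and variables from the snippet below; the new learning rate is only an illustration):

import transformers
from transformers import AutoModelForCausalLM

# Load the full model weights saved in the Trainer checkpoint directory
full_model = AutoModelForCausalLM.from_pretrained("/content/latest_checkpoint/")

# Fresh TrainingArguments with whatever hyperparameters you now want
new_args = transformers.TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,  # e.g. a changed learning rate
    per_device_train_batch_size=8,
    max_steps=5000,
)

trainer = transformers.Trainer(
    model=full_model,
    args=new_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Start a fresh run (no resume_from_checkpoint), so the saved optimizer/scheduler
# state and old training args aren't restored
trainer.train()

Is there an equivalent of this for a PEFT setup like the one below?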
from peft import prepare_model_for_kbit_training

# Enable gradient checkpointing on the base model, then prepare it for k-bit training
model.gradient_checkpointing_enable()
ft_model = prepare_model_for_kbit_training(model)
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "w1",
        "w2",
        "w3",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

# Attach the LoRA adapters defined above to the prepared model
ft_model = get_peft_model(ft_model, config)
print_trainable_parameters(ft_model)

ft_model = accelerator.prepare_model(ft_model)
import transformers
from datetime import datetime
tokenizer.pad_token = tokenizer.eos_token
learning_rate = 5e-5
warmup_steps = 100
gradient_accumulation_steps = 2
trainer = transformers.Trainer(
    model=ft_model,
    callbacks=[upload_checkpoint_callback],
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=warmup_steps,
        per_device_train_batch_size=8,
        gradient_checkpointing=True,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_steps=5000,
        learning_rate=learning_rate,
        logging_steps=10,
        fp16=True,
        optim="paged_adamw_8bit",
        logging_dir="/content/logs",
        save_strategy="steps",
        save_steps=10,
        evaluation_strategy="steps",
        eval_steps=10,
        load_best_model_at_end=True,
        report_to="wandb",
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Disable the KV cache (incompatible with gradient checkpointing)
model.config.use_cache = False
trainer.train(resume_from_checkpoint="/content/latest_checkpoint/")