Optuna pruning for validation loss


I introduced the following lines in my deep learning project in order to early stop when the validation loss has not improved for 10 epochs:

if best_valid_loss is None or valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    counter = 0 
else:
    counter += 1 
    if counter == 10: 
        break
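For reference, here is a self-contained version of that loop with the initialization of `best_valid_loss` and `counter` made explicit, and the actual training step replaced by a precomputed list of losses (the function name and the `patience` parameter are illustrative, not from the snippet above):

```python
def train_with_early_stopping(valid_losses, patience=10):
    # Stop once the validation loss has not improved for `patience` epochs.
    best_valid_loss = None
    counter = 0
    for epoch, valid_loss in enumerate(valid_losses):
        if best_valid_loss is None or valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            counter = 0
        else:
            counter += 1
            if counter == patience:
                break
    return epoch, best_valid_loss

# Loss improves until epoch 4, then plateaus; training stops 10 epochs later.
losses = [0.5, 0.4, 0.3, 0.2, 0.1] + [0.2] * 20
```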

Now I want to use Optuna to tune some hyperparameters, but I don't really understand how pruning works in Optuna. Is it possible for Optuna pruners to act the same way as in the code above? I assume I have to use the following:

optuna.pruners.PatientPruner(???, patience=10)

But I don't know which pruner I could use inside PatientPruner. By the way, in Optuna I'm minimizing the validation loss.

Short answer: Yes.

Hi, I'm one of the authors of PatientPruner in Optuna. For vanilla early stopping, wrapped_pruner=None works as expected. For example:

import optuna

def objective(t):
    for step in range(30):
        # Report a loss that reaches its minimum (0.0) at step 5
        # and never improves afterwards.
        if step == 5:
            t.report(0.0, step=step)
        else:
            t.report(step * 0.1, step=step)
        if t.should_prune():
            print("pruned at {}".format(step))
            raise optuna.exceptions.TrialPruned()

    return 1.0

study = optuna.create_study(
    pruner=optuna.pruners.PatientPruner(None, patience=9),
    direction="minimize",
)
study.optimize(objective, n_trials=1)

The output will be pruned at 15: the loss reaches its best value (0.0) at step 5 and never improves afterwards, so once the patience window of 9 steps has passed without improvement, the trial is pruned at step 15.