Get trial number in Keras callback function using AutoKeras


I'm running an AutoKeras training, and I want to compute the global progress of the training, where progress is defined as the total number of epochs run so far, including the epochs of previous trials.

Unfortunately, I don't know how to get the current trial number inside the callback, or how to keep a global epoch count.
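Written out, the progress described above is the following, where $T$ is the number of trials (`max_trials`), $E$ the epochs per trial (`epochs`), $i$ the 0-based trial index and $e$ the 0-based epoch index that Keras passes to `on_epoch_end`. It assumes every trial runs all $E$ epochs; if a trial stops early, the count ends below $TE$.

```latex
\text{progress} = \frac{i\,E + e + 1}{T\,E}, \qquad 0 \le i < T, \quad 0 \le e < E
```

For example, with $T = 5$ and $E = 100$, the last epoch of the second trial ($i = 1$, $e = 99$) gives $(100 + 99 + 1)/500 = 200/500 = 0.4$.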

Here is a snippet of the code:

from tensorflow import keras
import autokeras as ak


class ReportingCallback(keras.callbacks.Callback):
    def __init__(self, trials_total):
        super().__init__()
        self.trials_total = trials_total

    def on_epoch_end(self, epoch, logs=None):
        epochs_per_trial = self.params["epochs"]
        epochs_total = epochs_per_trial * self.trials_total
        i_trial = 0  # ???? placeholder: how do I get the current trial index here?
        # `epoch` is 0-based, so add 1 to count the epoch that just finished.
        epochs_current = (i_trial * epochs_per_trial) + epoch + 1
        print(f"Progress: {epochs_current}/{epochs_total}")


def automl(train_x, train_y):
    max_trials = 5

    clf = ak.StructuredDataClassifier(max_trials=max_trials)

    clf.fit(
        train_x,
        train_y,
        epochs=100,
        callbacks=[ReportingCallback(max_trials)],
    )
