KFold validation_data in Keras model.fit using scikit-learn GridSearchCV


I'm working with Keras, using scikit-learn's GridSearchCV with KFold and the SciKeras wrappers. I would like to pass the validation folds produced by KFold to the model's fit method through the validation_data parameter. I tried several alternatives but couldn't make it work. Here's the code.

NN = KerasClassifier(
  model=get_NN,
  X_len = len(X_train.columns),
  loss="mse",
  optimizer="SGD",
  epochs=300,
  batch_size=4,
  shuffle=True,
  verbose=False,
  # fit__validation_data = # Here I should pass the validation data
  callbacks=[
    tf.keras.callbacks.EarlyStopping(
      monitor="val_loss", min_delta=0.0001, patience=15, restore_best_weights=True
    )
  ]
)

custom_scores_monk = {
    "accuracy": "accuracy",
    "mse": make_scorer(mean_squared_error,greater_is_better=False)
}

NN_MONK1_GRID_DICT = {
  "model__lr" : [0.5], 
  "model__alpha" : [0.8],
  "model__hidden_activation" : ["tanh"],
  "model__neurons" : [4], 
  "model__initializer" : ["glorot"], 
  "model__nesterov" : [True], 
  "model__penalty": [None], 
  "model__lambda_reg": [None],
  "model__seed" : [15]
}

grid = GridSearchCV(NN,
                    param_grid=NN_MONK1_GRID_DICT,
                    scoring=custom_scores_monk,
                    refit="mse",
                    cv=CV,
                    return_train_score=True,
                    n_jobs=-1
        )

Among other alternatives, I tried writing a custom callback that updates the validation data in on_train_begin, but it feels like a dirty practice, and I'm not surprised it doesn't work.

class ValidationCallback(Callback):
  def __init__(self, X, y, validation_split):
    super().__init__()
    self.X = X
    self.y = y
    self.validation_split = validation_split
    self.count = 0

  def on_train_begin(self, logs=None):
    print("Training " + str(self.count))
    indexes = self.validation_split[self.count]
    X_val, y_val = [self.X.iloc[i] for i in indexes], [self.y.iloc[i] for i in indexes]
    self.count = self.count+1
    self.model.fit__validation_data = (X_val, y_val)

I'm also very surprised that there is no ready-made solution for such a common task as KFold cross-validation, especially when using a framework like scikit-learn. In particular, this problem makes it impossible to use 'val_loss' as the monitored value for early stopping, and impossible to plot and compare training and validation learning curves.

Do you have any solutions?

1 Answer

BEST ANSWER (by Mac)

I spent about a week on that and I finally found a way.

Short answer: don't do it. Just handwrite an ad-hoc method for grid search and use it.

Long answer: you can define a subclass of the SciKeras wrapper and override its fit method so that it passes the current validation fold to Keras. To do that, you must:

  1. know the folds that will be used, and thus you must set a random_state in your CV object
    # define a split strategy using a random_state
    CV = StratifiedKFold(n_splits=5, random_state=42, shuffle=True)
    
    # get the validation folds
    val_split = [ test for (train, test) in CV.split(X_train, y_train) ]
    
    val_data = [ 
      (
        [X_train.iloc[i].tolist() for i in indexes], 
        [y_train.iloc[i].tolist() for i in indexes]
      ) for indexes in val_split 
    ]
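
    # Note (my reading of the trick, not stated explicitly above): because
    # random_state is fixed, GridSearchCV given cv=CV will internally generate
    # exactly these folds, so val_data[i] lines up with the i-th training fold
    # that the wrapper's fit method will see.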
  1. define a "static" counter for the folds
    # static fold counter
    def count():
      count.count += 1
      return count.count
    
    def reset_counter():
      count.count = -1
    
    def get_count():
      return count.count
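
    # count.count is only created inside reset_counter(), so (as I read it)
    # reset_counter() must be called once before every grid search run,
    # otherwise the first call to count() raises AttributeError:
    reset_counter()   # sets count.count to -1, so the first count() returns 0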
  3. in the same way, you must define a registry for memorizing the history objects of the various folds
    # static history register
    def histories():
      histories.histories = []
    
    def register(h):
      histories.histories.append(h)
    
    def get_histories():
      return histories.histories
    
    def clear_histories():
      histories()
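
    # Like the counter, the registry must be initialized before each run,
    # otherwise register() fails with AttributeError on the first fold:
    histories()   # or clear_histories()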
  4. define a method for computing the mean of the K histories. It allows early stopping on the validation loss and plotting averaged learning curves.
    # utilities to get the mean of K histories
    
    def add_padding(ls, n):
      ls.extend([ls[-1]] * n)
      return ls
    
    def mean_epochs(l):
      return int(mean([ len(item['loss']) for item in l ]))
    
    def mean_history(_histories):
      m = mean_epochs(_histories)+1
      for history in _histories:
        l = len(history['loss'])
        for field in _histories[0]:
          if l>= m:
            history[field] = history[field][:m]
          else:
            history[field] = add_padding(history[field], (m-l))
      return \
        { field : 
            [ 
              (sum(x)/len(_histories)) for x in zip(
                *[ history[field] for history in _histories ]
              )
            ] for field in _histories[0]
        }
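
    # (mean() above is assumed to be statistics.mean.)
    # Toy example of the averaging, with invented values: two folds with losses
    # [1.0, 0.5] and [0.8, 0.4, 0.3] give mean_epochs(...) = 2, hence m = 3; the
    # first history is padded to [1.0, 0.5, 0.5], the second stays [0.8, 0.4, 0.3],
    # and the averaged loss is [0.9, 0.45, 0.4].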
  5. extend the SciKeras wrapper class, overriding the fit method
    # KerasClassifier Wrapper for kfold
    class KCWrapper(KerasClassifier):
    
      # you can pass the same parameters you passed to the KerasClassifier, after val_data and k
      def __init__(self, val_data, k, *args, **kwargs):
        super(KCWrapper, self).__init__(*args, **kwargs)
        self.val_data = val_data
        self.k = k
      
      def fit(self, X, y, **kwargs):
        h = super().fit(X, y, validation_data=self.val_data[count()], **kwargs)
        register(h.history_)
        # do_NN_plot(h.history_)  # plot single fold curve
        if self.kfold_finished(): # plot mean of k folds curves
          do_NN_plot(mean_history(get_histories()))
        
      def kfold_finished(self):
        return self.k == get_count()+1
  6. instantiate the (wrapper of the wrapper of the) classifier; a sketch of how everything is wired together follows the list
    # Define grids for gridsearchcv
    kerasClassifierParams = {
      "model" : get_NN,
      "X_len" : len(X_train.columns),
      "loss" : "mse",
      "optimizer" : "SGD",
      "epochs" : 300,
      "batch_size" : 4,
      "shuffle" : True,
      "verbose" : False
    }

    NN = KCWrapper(
      val_data,
      5, # 5-Fold
      callbacks=[
        tf.keras.callbacks.EarlyStopping(
          monitor="val_loss", min_delta=0.0001, patience=20, restore_best_weights=True
        )
      ],
      **kerasClassifierParams
    )
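
With everything in place, the search can be wired together roughly as follows. This is a sketch based on the GridSearchCV call from the question, not code from the original answer: the static counter and history registry must be reinitialized before every run, n_jobs must stay at 1 because they are plain module-level state, and (as far as I can tell) refit must be disabled and the grid limited to a single parameter combination, otherwise fit is called more than k times and count() runs past val_data.

    reset_counter()   # first fold will use val_data[0]
    histories()       # empty history registry
    
    grid = GridSearchCV(NN,
                        param_grid=NN_MONK1_GRID_DICT,   # one combination -> exactly 5 fits
                        scoring=custom_scores_monk,
                        refit=False,                     # a refit would index val_data[5]
                        cv=CV,                           # the StratifiedKFold from step 1
                        return_train_score=True,
                        n_jobs=1)                        # parallel workers don't share the counter
    grid.fit(X_train, y_train)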

The provided code also uses a routine for plotting data:

    def do_NN_plot(history):
    
      # Plot accuracy
      plt.plot(history['binary_accuracy'])
      plt.plot(history['val_binary_accuracy'], linestyle="--", color="orange")
      plt.title('model accuracy')
      plt.ylabel('accuracy')
      plt.xlabel('epoch')
      plt.legend(['training', 'validation'], loc='lower right')
      plt.show()
    
      # Plot loss
      plt.plot(history['loss'])
      plt.plot(history['val_loss'], linestyle="--", color="orange")
      plt.title('model MSE')
      plt.ylabel('MSE')
      plt.xlabel('epoch')
      plt.legend(['training', 'validation'], loc='upper right')
      plt.show()

If you're working on a regression task, you can do the same thing with a wrapper of (a wrapper of) a regressor. The fold split, fold counter, history registry, averaging utilities and plotting routine are identical to the classifier case above; only the wrapper class and its parameters change:

    # KerasRegressor Wrapper for kfold
    class KRWrapper(KerasRegressor):
    
      def __init__(self, val_data, k, *args, **kwargs):
        super(KRWrapper, self).__init__(*args, **kwargs)
        self.val_data = val_data
        self.k = k
        
      def fit(self, X, y, **kwargs):
        h = super().fit(X, y, validation_data=self.val_data[count()], **kwargs)
        register(h.history_)
        # do_NN_plot(h.history_)  # plot single fold curve
        if self.kfold_finished(): # plot mean of k folds curves
          do_NN_plot(mean_history(get_histories()))
        
      def kfold_finished(self):
        return self.k == get_count()+1

    # Define grids for gridsearchcv
    kerasRegressorParams = {
      "model" : get_NN,
      "X_len" : len(X_train.columns),
      "loss" : mee_NN,
      "optimizer" : "SGD", # fixed into get_NN
      "batch_size" : 32,
      "epochs" : 2000,
      "shuffle" : True,
      "verbose" : 0
    }
    
    NN = KRWrapper(
      val_data,
      5,
      callbacks=[
        tf.keras.callbacks.EarlyStopping(
          monitor="val_loss", min_delta=0.000001, patience=50, restore_best_weights=True
        )
      ],
      **kerasRegressorParams
    )
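
The regressor is driven exactly like the classifier: reinitialize the counter and the history registry, then run the search sequentially. A minimal sketch; the grid and scorer below are placeholders, not part of the original code:

    reset_counter()
    histories()
    
    grid = GridSearchCV(NN,
                        param_grid={"model__lr": [0.1]},    # placeholder grid
                        scoring="neg_mean_squared_error",   # placeholder scorer
                        refit=False,
                        cv=CV,
                        n_jobs=1)
    grid.fit(X_train, y_train)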

This satisfied my curiosity and stubbornness, but it's a dirty solution (even if it's still a solution :P). As I said at the beginning: just hand-write an ad-hoc grid search routine and use it. The approach above cannot use the built-in parallelization of scikit-learn's GridSearchCV (the fold counter and history registry are shared module-level state), so it's a lot of essentially useless work.

Note: the approach using the callback didn't work because the parameters of the fit method are resolved before the callback is invoked, so the fit__validation_data set in on_train_begin is never evaluated.