Early stopping with PyCaret? Overfitting with CatBoost and XGBoost


I'm comparing the performance of CatBoost, XGBoost and LinearRegression in PyCaret. CatBoost and XGBoost are untuned.

So far I see that CatBoost and XGBoost are overfitting.


For linear regression, the scores are train R2: 0.72, test R2: 0.65.

Is there a way to set early stopping for XGBoost and CatBoost to avoid this overfitting? Or are there other parameters to tune in PyCaret to avoid overfitting?


There are 2 best solutions below

0
On BEST ANSWER

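There are several ways to reduce overfitting:

  • Feature selection (can be enabled in setup): there are several methods with an adjustable threshold, for example RFE (recursive feature elimination) or SHAP-based selection
  • Tune both models, CatBoost and XGBoost (or other tree-based algorithms)
  • Try different values of n_estimators, e.g. 100, 500 or 1000, ideally together with a lower learning rate, since adding trees alone tends to increase overfitting in boosting
  • Run the algorithms several times (e.g. with different random seeds) to check how stable the scores are
  • Change the train/test split, e.g. 80/20, 70/30, etc.
  • Remove highly correlated inputs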
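
As an illustration of the last point, here is a minimal sketch in plain Python of dropping inputs that are highly correlated with one already kept. The helper names (pearson, drop_correlated), the 0.9 threshold and the toy data are assumptions made for this example, not part of PyCaret. PyCaret's setup has a remove_multicollinearity option for the same purpose; this sketch only shows the idea behind it.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (0.0 if either is constant)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return 0.0
    return cov / math.sqrt(var_x * var_y)


def drop_correlated(columns, threshold=0.9):
    """Greedily keep a column only if |r| with every already-kept column is <= threshold.

    `columns` maps feature name -> list of values. Returns the kept names in order.
    """
    kept = []
    for name, values in columns.items():
        if all(abs(pearson(values, columns[k])) <= threshold for k in kept):
            kept.append(name)
    return kept


# Hypothetical toy data: "b" is an exact linear copy of "a", "c" is only weakly related.
features = {
    "a": [1.0, 2.0, 3.0, 4.0, 5.0],
    "b": [2.0, 4.0, 6.0, 8.0, 10.0],
    "c": [5.0, 1.0, 4.0, 2.0, 3.0],
}
kept_features = drop_correlated(features, threshold=0.9)
print(kept_features)
```

The greedy order matters: the first column of a correlated pair is the one that survives, so put the inputs you trust most first.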
0
On

First, how are you comparing models without tuning hyperparameters? Seeing your code would be helpful.

There is an early_stopping parameter in PyCaret, but it's only available in the tune_model function, and as far as I can tell it stops poorly performing hyperparameter trials during the search rather than limiting the number of boosting rounds. If you let PyCaret auto-search hyperparameters for XGBoost and CatBoost, they should overfit much less. This is because the search tunes the regularization hyperparameters (L1 and/or L2 regularization on the leaf weights) and compares scores across the validation folds.

With catboost (or xgboost or lightgbm), you can set the early_stopping_rounds parameter to enable early stopping:

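from catboost import CatBoostRegressor

# The question is a regression task (R2 scores), so use the regressor.
# x_train, y_train, x_test, y_test are assumed to be defined already.
cb = CatBoostRegressor(n_estimators=1000)
# Stop once the eval_set metric has not improved for 10 consecutive rounds.
cb.fit(x_train, y_train, eval_set=(x_test, y_test), early_stopping_rounds=10, plot=True)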

You need to provide the eval_set; otherwise there is nothing to evaluate for early stopping. Ideally use a separate validation split for it rather than the final test set, so the test score stays unbiased. I don't think it's possible at the moment to add early_stopping_rounds as a parameter to any of the relevant PyCaret functions you are probably using.
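
To make clear what early_stopping_rounds does with the eval_set scores, here is a minimal sketch of a patience-based stopping rule in plain Python. The function name best_iteration and the loss values are made up for illustration; the libraries' own implementations differ in details, so treat this as a rough picture of the idea only.

```python
def best_iteration(val_losses, patience=10):
    """Return (best_index, stopped_index) for a patience-based early stopping rule.

    Training stops once the validation loss has not improved for `patience`
    consecutive iterations; lower loss is better.
    """
    best_loss = float("inf")
    best_index = -1
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss = loss
            best_index = i
        elif i - best_index >= patience:
            return best_index, i  # no improvement for `patience` rounds
    return best_index, len(val_losses) - 1  # ran out of iterations


# Hypothetical validation losses: they improve until iteration 3, then get worse.
losses = [1.00, 0.80, 0.70, 0.65, 0.66, 0.68, 0.70, 0.75, 0.80]
best, stopped = best_iteration(losses, patience=3)
print(best, stopped)
```

Without an eval_set there is no val_losses sequence to watch, which is why the rule cannot fire on training data alone: training loss keeps decreasing as trees are added.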