callbacks in sklearn.multiclass.OneVsRestClassifier


I want to use callbacks, eval_set, etc., but I have a problem:

from sklearn.multiclass import OneVsRestClassifier
import lightgbm
verbose = 100
params = {
    "objective": "binary",
    "n_estimators": 500,
    "verbose": 0
}
fit_params = {
    "eval_set": eval_dataset,
    "callbacks": [CustomCallback(verbose)]
}

clf = OneVsRestClassifier(lightgbm.LGBMClassifier(**params))
clf.fit(X_train, y_train,  **fit_params)

How can I pass fit_params through to my estimator? I get:

----------------------------------------------------------------------
---> 13 clf.fit(X_train, y_train,  **fit_params)

TypeError: OneVsRestClassifier.fit() got an unexpected keyword argument 'eval_set'

James Lamb On BEST ANSWER

Per scikit-learn's docs for OneVsRestClassifier, as of v1.4.0 additional **fit_params are only passed through to the underlying estimator's fit() method if you've enabled what scikit-learn calls "metadata routing".

Two required steps are missing from your example:

  • opting in by running sklearn.set_config(enable_metadata_routing=True)
  • explicitly telling scikit-learn to pass through eval_set and callbacks, via .set_fit_request().


Consider this minimal, reproducible example using Python 3.11, lightgbm==4.3.0, and scikit-learn==1.4.1.

import lightgbm as lgb
import sklearn
from sklearn.datasets import make_blobs
from sklearn.multiclass import OneVsRestClassifier

# enable metadata_routing
sklearn.set_config(enable_metadata_routing=True)

# create datasets
X, y = make_blobs(
    n_samples=10_000,
    n_features=10,
    centers=2
)
# dict that lgb.record_evaluation() fills with per-iteration metrics
eval_results = {}

# construct estimator
params = {
    "objective": "binary",
    "n_estimators": 10,
}
fit_params = {
    "eval_set": (X, y),
    "callbacks": [lgb.record_evaluation(eval_results)]
}

clf = OneVsRestClassifier(
    lgb.LGBMClassifier(**params)
    .set_fit_request(callbacks=True, eval_set=True)
)

# train
clf.fit(X, y, **fit_params)

# check eval results, to prove that the callback was used
print(eval_results)

# example output (exact values vary between runs, since make_blobs is not seeded):
# {'valid_0': OrderedDict([('binary_logloss', [0.598138869381609, 0.5203293282602738, 0.45544446427154844, 0.40059849184355334, 0.3537472248673818, 0.31338812592304066, 0.2783839141567028, 0.24785302530927006, 0.22109850424011224, 0.19756016345789282])])}