scikit-learn 1.2.0 brought a lot of changes, including pandas output support for all of the transformers. How can I use it in a custom transformer?
In [1]: Here is my custom transformer which is a standard scaler:
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np

class StandardScalerCustom(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        self.mean = np.mean(X, axis=0)
        self.std = np.std(X, axis=0)
        return self

    def transform(self, X):
        return (X - self.mean) / self.std
In [2]: Created a specific scale pipeline
from sklearn.pipeline import make_pipeline
scale_pipe = make_pipeline(StandardScalerCustom())
In [3]: Added in a full pipeline where it may get mixed with scalers, imputers, encoders etc.
from sklearn.compose import ColumnTransformer
full_pipeline = ColumnTransformer([
    ("imputer", impute_pipe, ['column_1']),
    ("scaler", scale_pipe, ['column_2']),
])
# From documentation
full_pipeline.set_output(transform="pandas")
Got this error:
ValueError: Unable to configure output for StandardScalerCustom() because set_output is not available.
One workaround is to set the global configuration:
from sklearn import set_config
set_config(transform_output="pandas")
But on a case-by-case basis, how can I add a method to the StandardScalerCustom class that fixes the error above?
My guess is that one of the rationales behind the enhancement of set_config() by means of the parameter transform_output was indeed to also enable custom transformers to output pandas DataFrames.
By looking at the underlying code, I've found one hack that allows custom transformers to output pandas DataFrames without the need to explicitly set the global configuration; it is sufficient to implement a dummy .get_feature_names_out() method. This works because .set_output() is only exposed on a transformer when _auto_wrap_is_configured() returns True, which requires .get_feature_names_out() to be available. When full_pipeline.set_output() runs, it calls _safe_set_output() on each step: if the step exposes .set_output(), that method is called with transform="pandas" and the output is configured for that transformer; if it does not, _safe_set_output() raises the ValueError that you're getting. Here's a working example: