I'm modifying a working Flask application that I've been using to serve an ML model, so that it now serves predictions from a new, updated model.
The model is an sklearn Pipeline object that I trained and serialized with pickle. It includes several ColumnTransformer steps (imputation, encoding) and finally a prediction step. Because of the way some of the custom column transformers are written, it's important that the output of each intermediate step is a pandas DataFrame rather than an array; that's how the model was trained and serialized.
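For reference, here's a placeholder for the kind of custom transformer I mean (the class and column names below are made up, not my actual code); it indexes its input by column name, so it breaks if a preceding step hands it a bare numpy array:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

class RatioFeature(BaseEstimator, TransformerMixin):
    """Placeholder custom transformer that requires a DataFrame input."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.copy()
        # String-based column lookups like these are what fail when the
        # previous step emits an ndarray instead of a DataFrame.
        X["ratio"] = X["numerator"] / X["denominator"]
        return X
```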
Here's where it starts to get weird:
- When I load the model and make a prediction in what I understand to be the "application context," the pipeline works as expected and returns a prediction.
- When I use the same loaded model with the same data in a request context, the intermediate pipeline steps don't return pandas outputs, so the pipeline fails as soon as the second step looks for column names that don't exist in the array output.
My questions are:
- Why is the behavior different between these two contexts?
- How can I solve this without reconfiguring the pipeline to use arrays and retraining, which I don't want to do for several reasons?
- Is there a way to apply sklearn.set_config(transform_output="pandas") only in the request context?
Here's my basic code:
import os
import pickle
import logging

import pandas as pd
from flask import Flask, make_response

# Define the app
app = Flask(__name__)
log = logging.getLogger(__name__)

# Load the trained model
with open("model.pkl", "rb") as f:
    model = pickle.load(f)

# Load a test record
test_record = pd.read_csv("test_record.csv")

# Make a prediction using the model and test record. This step works.
try:
    model.predict(test_record)
except Exception as e:
    log.debug(str(e), stack_info=True)

@app.route('/predict')
def predict():
    # Make a prediction using the model and test record. This step *doesn't* work.
    try:
        prediction = model.predict(test_record)
    except Exception as e:
        log.debug(str(e), stack_info=True)
    return make_response('Test Record Prediction: ' + str(prediction), 200)

# Start the Flask app
if __name__ == '__main__':
    if os.environ['ENV'] in {'local', 'local_w_db', 'DEV'}:
        app.run(debug=True)
    else:
        app.run()
Environment Specs:
python==3.11.7
Flask==3.0.1
scikit-learn==1.3.2
scikit-learn-intelex==2023.2.1
scipy==1.11.4
pandas==2.1.4
category_encoders==2.6.3
werkzeug==3.0.1
joblib==1.2.0
I've tried:
- Running the model both within and outside of the request context.
- Inside the request context, I've traced the execution of the predict step far enough to see that the output of the first ColumnTransformer is an array, where it needs to be a pandas DataFrame.
- Outside of the request context, the outputs of all ColumnTransformers and intermediate pipeline steps are pandas DataFrames, as the model was configured.
- Setting the sklearn transform output to always be pandas by running these lines at the beginning of my Flask app code, which didn't make a difference:

import sklearn
sklearn.set_config(transform_output="pandas")
Here is the stack trace logged with the exception on the prediction step; it points to an issue in the threading configuration rather than the sklearn settings:
[2024-01-30 08:58:44,666] DEBUG [app.predict:109] - Specifying the columns using strings is only supported for pandas DataFrames
Stack (most recent call last):
  File "c:\Users\user\.vscode\extensions\ms-python.python-2023.22.1\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydev_bundle\pydev_monkey.py", line 1118, in __call__
    ret = self.original_func(*self.args, **self.kwargs)
  File "..\miniforge3\envs\APP_ENV\Lib\threading.py", line 1002, in _bootstrap
    self._bootstrap_inner()
  File "..\miniforge3\envs\APP_ENV\Lib\threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "..\miniforge3\envs\APP_ENV\Lib\threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "..\miniforge3\envs\APP_ENV\Lib\socketserver.py", line 691, in process_request_thread
    self.finish_request(request, client_address)
  File "..\miniforge3\envs\APP_ENV\Lib\socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "..\miniforge3\envs\APP_ENV\Lib\socketserver.py", line 755, in __init__
    self.handle()
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\werkzeug\serving.py", line 390, in handle
    super().handle()
  File "..\miniforge3\envs\APP_ENV\Lib\http\server.py", line 436, in handle
    self.handle_one_request()
  File "..\miniforge3\envs\APP_ENV\Lib\http\server.py", line 424, in handle_one_request
    method()
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\werkzeug\serving.py", line 362, in run_wsgi
    execute(self.server.app)
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\werkzeug\serving.py", line 323, in execute
    application_iter = app(environ, start_response)
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\flask\app.py", line 1488, in __call__
    return self.wsgi_app(environ, start_response)
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\flask\app.py", line 1463, in wsgi_app
    response = self.full_dispatch_request()
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\flask\app.py", line 870, in full_dispatch_request
    rv = self.dispatch_request()
  File "..\miniforge3\envs\APP_ENV\Lib\site-packages\flask\app.py", line 855, in dispatch_request
    return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args)  # type: ignore[no-any-return]
  File "C:\Users\user\app_directory\app.py", line 109, in predict
    log.debug(str(e), stack_info=True)
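To sanity-check the threading angle, here's a stdlib-only sketch of what I suspect is happening: the trace shows the request being handled on a worker thread, and any thread-local state set in the main thread wouldn't be visible there.

```python
import threading

config = threading.local()
config.transform_output = "pandas"  # set once, in the main thread

seen = {}

def handle_request():
    # A new thread gets a fresh thread-local namespace, so the
    # attribute set in the main thread is missing here.
    seen["transform_output"] = getattr(config, "transform_output", "default")

worker = threading.Thread(target=handle_request)
worker.start()
worker.join()

print(seen["transform_output"])  # prints "default", not "pandas"
```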