Using PipelineModel.load() in a custom MLflow PyFunc class results in an error


I am creating a custom PyFunc class to use with the Databricks Feature Store, since its Model Serving UI and the Feature Store's log_model() method only work with the PythonModel class.

The underlying model is a PipelineModel that performs various binning steps and transformations before the final estimator.

The class is defined as follows:

    import mlflow.pyfunc
    from pyspark.ml import PipelineModel
    from pyspark.ml.functions import vector_to_array
    from pyspark.sql.functions import when, lit  # needed by predict() below

    class custom_model_class(mlflow.pyfunc.PythonModel):

        def __init__(self, model_path, threshold):
            self.model_path = model_path
            self.threshold = threshold
            self.model = None

        def load_context(self, context):
            # This is where the error is raised: PipelineModel.load() needs
            # an active SparkContext on the driver.
            self.model = PipelineModel.load(self.model_path)

        def predict(self, context, model_input):
            # Score with the pipeline, then re-threshold the positive-class
            # probability at the custom cutoff.
            return self.model.transform(model_input).withColumn(
                "prediction_opt_thresh",
                when(vector_to_array("probability")[1] > lit(self.threshold), 1).otherwise(0),
            )

    custom_model = custom_model_class(model_path=pipeline_model_directory, threshold=0.52)

However, I am getting this error:

'RuntimeError: SparkContext should only be created and accessed on the driver.', from <command-2164064949918430>, line 12.

I have tried various other approaches, but most of them hit the same problem:

  1. The JVM returns None, indicating there is no Spark session.
  2. Creating a Spark session manually raises the same error as above.
  3. Using mlflow.pyfunc.load_model() (which does not utilize the Spark session) to load the PipelineModel object works, but it does not return probabilities (see the sketch after this list).
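
For reference, attempt 3 looked roughly like this (a minimal sketch; the model URI and input DataFrame name are hypothetical):

    import mlflow.pyfunc

    # Hypothetical registry URI; in practice this points at the logged pipeline model.
    loaded = mlflow.pyfunc.load_model("models:/example_model/1")

    # Loads without needing the Spark session, but predict() only returns the
    # prediction column, so the "probability" vector needed for the custom
    # threshold is lost.
    preds = loaded.predict(input_df.toPandas())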

How do I resolve this? Is there a way for Feature Store and MLLib to work together?

1 Answer

I managed to solve this. The issue lay in the way we were calling the Feature Store's log_model() function. See the documentation here: https://docs.databricks.com/dev-tools/api/python/latest/feature-store/client.html

You have to pass a 'flavor' argument to this method. If you specify any flavor other than mlflow.spark (even one from the list of approved flavors), the environment in which score_batch() runs the model will not have a Spark context, which is exactly what triggers the error above. This was our original, problematic call:

    fs.log_model(
        model,
        "model",
        flavor=mlflow.pyfunc,  # wrong flavor: this environment has no Spark context
        training_set=training_set,
        registered_model_name="example_model",
    )

To ensure the model runs in the correct type of environment, you need a Spark context, so the flavor has to be mlflow.spark.
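
The fix is a one-argument change, keeping everything else from the call above:

    fs.log_model(
        model,
        "model",
        flavor=mlflow.spark,  # Spark flavor, so score_batch() runs with a Spark context
        training_set=training_set,
        registered_model_name="example_model",
    )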

Hope this helps anyone using the new Feature Store API.