ONNX with custom ops from TensorFlow in Java


To use machine learning in Java, I'm trying to train a model in TensorFlow, save it as an ONNX file and then use that file for inference in Java. While this works fine with simple models, it gets more complicated with pre-processing layers, as they seem to depend on custom operators.

As an example, this Colab (https://www.tensorflow.org/tutorials/keras/text_classification) deals with text classification and uses a TextVectorization layer like this:

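import re
import string

import tensorflow as tf
from tensorflow.keras import layers

@tf.keras.utils.register_keras_serializable()
def custom_standardization2(input_data):
    # Lower-case, replace HTML line breaks with spaces, strip punctuation
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
    return tf.strings.regex_replace(stripped_html, '[%s]' % re.escape(string.punctuation), '')


# max_features and sequence_length are defined earlier in the tutorial
vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization2,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length
)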
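For reference, the standardization step itself is plain string processing. The Java class below is a sketch of what custom_standardization2 computes: lower-casing, replacing `<br />` with a space, and removing the 32 ASCII punctuation characters contained in Python's `string.punctuation` (Java's `\p{Punct}` matches exactly that set). It can help when checking what the converted graph's string ops have to reproduce, or if one preferred to do this step outside the model. The class and method names are hypothetical, treating `tf.strings.lower` as ASCII-only is an assumption based on its default (empty) encoding, and the vocabulary lookup that TextVectorization performs afterwards is not covered.

```java
import java.util.regex.Pattern;

/**
 * Hypothetical plain-Java counterpart of custom_standardization2:
 * ASCII lower-casing, "<br />" -> " ", then removal of ASCII punctuation.
 */
public final class TextStandardizer {

    // Without UNICODE_CHARACTER_CLASS, \p{Punct} matches the same 32 ASCII
    // characters as Python's string.punctuation.
    private static final Pattern PUNCTUATION = Pattern.compile("\\p{Punct}");

    private TextStandardizer() {}

    /** Lower-cases only A-Z (assumed to mirror tf.strings.lower with encoding=''). */
    static String asciiLower(String s) {
        StringBuilder sb = new StringBuilder(s.length());
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            sb.append(c >= 'A' && c <= 'Z' ? (char) (c + ('a' - 'A')) : c);
        }
        return sb.toString();
    }

    public static String standardize(String input) {
        String lowercase = asciiLower(input);
        // '<br />' contains no regex metacharacters, so a literal replace is equivalent
        String strippedHtml = lowercase.replace("<br />", " ");
        return PUNCTUATION.matcher(strippedHtml).replaceAll("");
    }

    public static void main(String[] args) {
        // prints "great movie 1010 would watch again"
        System.out.println(standardize("Great movie!<br />10/10, WOULD watch again."));
    }
}
```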

It is used as the pre-processing layer of the exported model, which is then compiled:

from tensorflow.keras import losses

# model is the classifier trained earlier in the tutorial
export_model = tf.keras.Sequential([
    vectorize_layer,
    model,
    layers.Activation('sigmoid')
])

export_model.compile(loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy'])

To create the ONNX file, I save the model in TensorFlow's SavedModel (protobuf) format and then convert it to ONNX with tf2onnx:

export_model.save("saved_model")

python -m tf2onnx.convert --saved-model saved_model --output saved_model.onnx --extra_opset ai.onnx.contrib:1 --opset 11

With onnxruntime-extensions, the custom ops can now be registered and the model run for inference in Python:

import onnxruntime
from onnxruntime import InferenceSession
from onnxruntime_extensions import get_library_path

# Register the custom ops implemented by onnxruntime-extensions
so = onnxruntime.SessionOptions()
so.register_custom_ops_library(get_library_path())

session = InferenceSession('saved_model.onnx', so)
# example_new holds the raw input strings to classify
res = session.run(None, { 'text_vectorization_2_input': example_new })

This raises the question of whether the same model can be used in Java in a similar way. ONNX Runtime for Java does have a SessionOptions#registerCustomOpLibrary method, so I thought of something like this:

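import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.registerCustomOpLibrary(""); // path to the custom op (extensions) library
OrtSession session = env.createSession("...", options); // path to the .onnx model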

Does anyone know whether the use case described is feasible, or how to use models with pre-processing layers in Java (without using TensorFlow Java)?

UPDATE: I spotted a potential solution. If I understand the comments in this GitHub issue correctly, one possibility is to build the ONNX Runtime Extensions package from source (see this explanation) and reference the generated library file by calling registerCustomOpLibrary in the ONNX Runtime library for Java. However, as I have no experience with tools like CMake, this might become a challenge for me.
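registerCustomOpLibrary expects the path of the platform-specific shared library. A minimal sketch for resolving that path, assuming the source build produces a library with the base name `ortextensions` in some build directory (both the name and the directory layout are assumptions; check the actual build output). The standard `System.mapLibraryName` turns the base name into `libortextensions.so` on Linux, `libortextensions.dylib` on macOS and `ortextensions.dll` on Windows. The class name `ExtensionsLibrary` is hypothetical.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Hypothetical helper: locate a natively built onnxruntime-extensions library
 * so its absolute path can be passed to SessionOptions#registerCustomOpLibrary.
 */
public final class ExtensionsLibrary {

    // ASSUMPTION: base name of the built library; adjust to your build output
    static final String DEFAULT_BASE_NAME = "ortextensions";

    private ExtensionsLibrary() {}

    /** Platform-specific file name, e.g. libortextensions.so / .dylib / ortextensions.dll. */
    static String fileName(String baseName) {
        return System.mapLibraryName(baseName);
    }

    /** Absolute path of the library inside buildDir; fails fast if it is missing. */
    static String resolve(Path buildDir, String baseName) {
        Path lib = buildDir.resolve(fileName(baseName)).toAbsolutePath();
        if (!Files.isRegularFile(lib)) {
            throw new IllegalStateException("Custom op library not found: " + lib);
        }
        return lib.toString();
    }

    public static void main(String[] args) {
        // args[0]: directory containing the compiled library (hypothetical layout)
        Path buildDir = Paths.get(args.length > 0 ? args[0] : ".");
        System.out.println(resolve(buildDir, DEFAULT_BASE_NAME));
    }
}
```

Failing early with a clear message is preferable here, because a wrong path otherwise only surfaces as an OrtException when the library is registered.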

Best answer:

The solution you propose in your update is correct: you need to compile the ONNX Runtime Extensions package from source to get the dll/so/dylib, and then you can load it into ONNX Runtime in Java using the session options. The Python whl doesn't distribute the binary in a format that can be loaded outside of Python, so compiling from source is the only option. I wrote the ONNX Runtime Java API, so if this approach fails, open an issue on GitHub and we'll fix it.
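Putting the pieces together, an end-to-end Java sketch of this approach could look like the following. It requires the onnxruntime Java package on the classpath. Assumptions: the path of the compiled extensions library and of the model are passed as arguments, the input name `text_vectorization_2_input` is taken from the Python call in the question, the model accepts a 1-D string tensor, and output 0 is a float[N][1] tensor of sigmoid scores. The class name and example texts are hypothetical. To be safe, the session is closed before the SessionOptions object the custom op library was registered with (nested try-with-resources).

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Map;

/**
 * Hypothetical sketch: run the converted text classifier from Java with the
 * ai.onnx.contrib custom ops loaded from a natively built extensions library.
 * ASSUMPTIONS: args[0] = built library path, args[1] = path to the .onnx model,
 * 1-D string input named INPUT_NAME, output 0 of type float[N][1].
 */
public final class SentimentInference {

    // Input name as used in the Python inference call
    static final String INPUT_NAME = "text_vectorization_2_input";

    public static void main(String[] args) throws OrtException {
        String extensionsLib = args[0];
        String modelPath = args[1];
        // Illustrative inputs
        String[] texts = {"The movie was great!", "The movie was terrible..."};

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
            // Custom ops must be registered before the session is created
            options.registerCustomOpLibrary(extensionsLib);
            // Resources in this inner block are closed before the options object
            try (OrtSession session = env.createSession(modelPath, options);
                 // String[] -> 1-D string tensor; pass a String[N][1] if the model expects [N, 1]
                 OnnxTensor input = OnnxTensor.createTensor(env, texts);
                 OrtSession.Result result = session.run(Map.of(INPUT_NAME, input))) {
                float[][] scores = (float[][]) result.get(0).getValue();
                for (int i = 0; i < texts.length; i++) {
                    System.out.printf("%.3f  %s%n", scores[i][0], texts[i]);
                }
            }
        }
    }
}
```

If the input name or shape differs after conversion, `session.getInputInfo()` shows what the converted model actually declares.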