How to handle multiple PyTorch models with pytriton + SageMaker

I am trying to adapt pytriton to host multiple models for a multi-model SageMaker setup. In my case, I want it to load all of the models hosted in the SAGEMAKER_MULTI_MODEL_DIR folder.

I could not find any relevant example here for a multi-model use case, so I am trying the code below. Is this the right approach?

import logging

import numpy as np

from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

logger = logging.getLogger("examples.multiple_models_python.server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")

# assume these are custom pytorch models
# loaded from SAGEMAKER_MULTI_MODEL_DIR using a custom function
models = [model1, model2]

@batch
def _infer(input, model):
    # do processing
    return [result]



with Triton() as triton:
    logger.info("Loading models")
    for model in models:
        triton.bind(
            model_name=model.name,
            infer_func=_infer,
            inputs=[
                Tensor(name="multiplicand", dtype=np.float32, shape=(-1,)),
                model
            ],
            outputs=[
                Tensor(name="product", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(max_batch_size=8),
        )

    triton.serve()

However, this does not work because the models do not exist at load time for pytriton. Is there any more documentation on using pytriton in a multi-model setup?
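
For what it's worth, the pattern I am now leaning towards is creating one infer callable per model with a closure, instead of passing the model through inputs. This is an untested sketch, reusing the imports and the models list from above; the tensor names are just illustrative:

import torch  # assuming the models are plain PyTorch modules

def make_infer_fn(model):
    @batch
    def _infer(multiplicand):
        # "multiplicand" arrives as a batched numpy array; run it through this model
        product = model(torch.from_numpy(multiplicand)).detach().cpu().numpy()
        return {"product": product}
    return _infer

with Triton() as triton:
    for model in models:
        triton.bind(
            model_name=model.name,
            infer_func=make_infer_fn(model),  # each model gets its own bound callable
            inputs=[Tensor(name="multiplicand", dtype=np.float32, shape=(-1,))],
            outputs=[Tensor(name="product", dtype=np.float32, shape=(-1,))],
            config=ModelConfig(max_batch_size=8),
        )
    triton.serve()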

There is 1 answer below.

If you are trying to use Triton Inference Server as the model server with SageMaker multi-model endpoints (MME), please reference this example: https://aws.amazon.com/blogs/machine-learning/host-ml-models-on-amazon-sagemaker-using-triton-tensorrt-models/. You need to package each model as a tarball in the appropriate format: model.py, config.pbtxt, and the model artifacts. Uploading all of these tarballs to a common S3 location then enables an MME endpoint, which you specify in the create_model API call.
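
As a rough illustration (the directory and file names here are placeholders, not taken from the blog post), each model might be packaged like this:

import tarfile

# assumed local layout per model (illustrative):
#   model1/
#       config.pbtxt        # Triton model configuration
#       1/
#           model.py        # Python-backend handler
#           model.pt        # serialized PyTorch weights
with tarfile.open("model1.tar.gz", "w:gz") as tar:
    tar.add("model1", arcname="model1")

# repeat for model2.tar.gz, ... and upload them all to the S3 prefix used as model_data_uri

The create_model call then points the MME container at that common S3 location: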

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_data_uri,
    "Mode": "MultiModel",
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)
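
After the usual create_endpoint_config / create_endpoint steps, individual models are selected per request with the TargetModel parameter. A rough sketch, assuming the tarball names from above; endpoint_name and request_body are placeholders, and the body must be serialized per Triton's HTTP inference protocol:

import boto3

smr = boto3.client("sagemaker-runtime")

response = smr.invoke_endpoint(
    EndpointName=endpoint_name,              # placeholder: your MME endpoint name
    ContentType="application/octet-stream",  # Triton accepts its binary+JSON inference payload
    Body=request_body,                       # placeholder: request built per Triton's HTTP protocol
    TargetModel="model1.tar.gz",             # picks which tarball from the S3 prefix to load and serve
)
print(response["Body"].read())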