
Nvidia Triton Inference Server

RBLN SDK with Nvidia Triton Inference Server

NVIDIA Triton Inference Server is open-source software for deploying machine learning models efficiently and at scale in production environments. This tutorial walks through integrating RBLN SDK with NVIDIA Triton Inference Server, using a precompiled ResNet50 model as the example.

Note

This tutorial assumes that you already know how to compile models and run inference with RBLN SDK. If you are not familiar with RBLN SDK, please refer to the PyTorch/TensorFlow tutorials and the API page.

Prerequisites

Before starting, make sure the following prerequisites are in place on your system:

Quick Start with Triton Inference Server Container

If you are running on-premises rather than on Backend.AI, skip to Step 1.

Step 0. Start a session with the Triton server image

When starting your session via Backend.AI, select Triton Server (NGC) as your environment. This automatically sets up the session with the nvcr.io/nvidia/tritonserver:24.01-py3 image.

If you are not using Backend.AI, you can skip this step.

Step 1. Prepare the Nvidia Triton python_backend

First, clone the NVIDIA Triton Inference Server python_backend repository. The r24.01 branch matches the 24.01 server image:

$ git clone https://github.com/triton-inference-server/python_backend -b r24.01

Before proceeding to the next step, place the precompiled resnet50.rbln file in the python_backend/examples/rbln/resnet50/1 directory:

$ mkdir -p python_backend/examples/rbln/resnet50/1
$ mv resnet50.rbln python_backend/examples/rbln/resnet50/1/

Step 2. Write your own TritonPythonModel using RBLN SDK

The Triton python_backend loads a TritonPythonModel class from model.py. The class must implement execute(), and the other methods are optional:

  • auto_complete_config()
  • initialize()
  • execute()
  • finalize()

Please refer to the official Triton python_backend repository for more detailed information about each function.

Below is a simple model.py. Its initialize() method loads the model, and its execute() method runs inference. Save it next to resnet50.rbln as python_backend/examples/rbln/resnet50/1/model.py.

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#
# model.py
#
import json
import os
import rebel # RBLN Runtime
import triton_python_backend_utils as pb_utils

# Number of devices to allocate.
# Available device numbers can be found through `rbln-stat` command.
NUM_OF_DEVICES = 1


class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )

        # Path to rbln compiled model file
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )

        # Create rbln runtime module
        self.module = rebel.Runtime(rbln_path)

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        output0_dtype = self.output0_dtype
        responses = []

        for request in requests:
            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT__0")

            # Run inference
            result = self.module.run(in_0.as_numpy())
            out_tensor_0 = pb_utils.Tensor("OUTPUT__0", result.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0]
            )
            responses.append(inference_response)

        return responses
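initialize() finds the compiled model by convention at <model directory>/<version>/<model name>.rbln. The hypothetical helper below isolates that lookup so you can check it without Triton or an NPU. It also fails early with the full path when the file is missing. Like model.py, it assumes that args["model_repository"] is the model's own directory rather than the repository root.

```python
import os


def resolve_rbln_path(args, must_exist=True):
    """Return <model_repository>/<model_version>/<model_name>.rbln for Triton's `args`.

    Assumes (as model.py does) that args["model_repository"] is the model's own
    directory, e.g. .../examples/rbln/resnet50, not the repository root.
    """
    path = os.path.join(
        args["model_repository"],
        args["model_version"],
        f"{args['model_name']}.rbln",
    )
    if must_exist and not os.path.isfile(path):
        # Fail early with the exact path instead of an opaque runtime error
        raise FileNotFoundError(f"compiled model not found: {path}")
    return path


# Hypothetical args, shaped like the dict Triton passes to initialize()
example_args = {
    "model_repository": "/opt/tritonserver/python_backend/examples/rbln/resnet50",
    "model_version": "1",
    "model_name": "resnet50",
}
print(resolve_rbln_path(example_args, must_exist=False))
```

Inside initialize(), calling rebel.Runtime(resolve_rbln_path(args)) would report a missing file as a FileNotFoundError that names the expected path.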

Save the following config.pbtxt as python_backend/examples/rbln/resnet50/config.pbtxt. The input and output names must match the ones used in model.py (INPUT__0 and OUTPUT__0).

name: "resnet50"
backend: "python"

input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

# Configure instance group
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]

max_batch_size: 1
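Because max_batch_size is greater than 0, Triton treats dims as per-sample dimensions and adds a leading batch dimension. execute() therefore receives INPUT__0 with shape [1, 3, 224, 224], which is presumably the shape resnet50.rbln was compiled for. The hypothetical helper below spells out that rule and builds a matching dummy input, for example for local testing.

```python
import numpy as np


def full_input_shape(dims, max_batch_size, batch_size=1):
    """Shape of the tensor model.py receives for an input declared with `dims`.

    With max_batch_size > 0, Triton prepends a batch dimension to `dims`;
    with max_batch_size == 0, `dims` is already the full shape.
    """
    if max_batch_size == 0:
        return list(dims)
    if not 1 <= batch_size <= max_batch_size:
        raise ValueError(f"batch_size must be in [1, {max_batch_size}], got {batch_size}")
    return [batch_size, *dims]


# Values from the config.pbtxt above: dims [3, 224, 224], max_batch_size 1
shape = full_input_shape([3, 224, 224], max_batch_size=1)
dummy_input = np.zeros(shape, dtype=np.float32)  # TYPE_FP32 corresponds to float32
print(shape)  # [1, 3, 224, 224]
```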

If you have completed the steps so far, python_backend/examples/rbln will have the following structure:

+-- resnet50/
|   +-- config.pbtxt
|   +-- 1/
|       +-- model.py
|       +-- resnet50.rbln
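Before starting the server, you can check this layout with a small script. The sketch below assumes the default names from this tutorial (model resnet50, version 1). missing_files is a hypothetical helper that lists the expected files that are not yet in place.

```python
from pathlib import Path


def missing_files(repo_root, model_name="resnet50", version="1"):
    """Return the expected model-repository files that do not exist yet."""
    model_dir = Path(repo_root) / model_name
    expected = [
        model_dir / "config.pbtxt",
        model_dir / version / "model.py",
        model_dir / version / f"{model_name}.rbln",
    ]
    # Report paths relative to the repository root for readability
    return [str(p.relative_to(repo_root)) for p in expected if not p.is_file()]


if __name__ == "__main__":
    missing = missing_files("python_backend/examples/rbln")
    print("layout OK" if not missing else f"missing: {missing}")
```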

Step 3. Run the inference server in the container

We are now ready to run the inference server. If you are using Backend.AI, follow the Backend.AI section. Otherwise, proceed to the On-premise server section.

Backend.AI

Start the server inside the Backend.AI session

Install the RBLN SDK:

$ pip3 install -i https://pypi.rbln.ai/simple/ rebel-compiler

Start the Triton server:

$ tritonserver --model-repository /opt/tritonserver/python_backend/examples/rbln

The following messages indicate that the server started successfully:
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

On-premise server

If you are not using Backend.AI, follow these steps to start the inference server in the Docker container. (Backend.AI users can skip to Step 4.)

To access the RBLN NPU devices, the inference server container must run in privileged mode. Mount the cloned python_backend repository into the container as shown below:
$ sudo docker run --privileged --shm-size=1g --ulimit memlock=-1 \
   -v /PATH/TO/YOUR/python_backend:/opt/tritonserver/python_backend \
   -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:24.01-py3

Install RBLN SDK inside the container:

$ pip3 install -i https://pypi.rbln.ai/simple/ rebel-compiler

Start the Triton server inside the container:

$ tritonserver --model-repository /opt/tritonserver/python_backend/examples/rbln

The following messages indicate that the server started successfully:
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
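In either setup, a script can poll the server instead of watching the log. Triton's HTTP health endpoints follow the KServe v2 protocol: GET /v2/health/ready and GET /v2/models/<model>/ready return HTTP 200 once the server and the model are ready. The sketch below uses only the standard library. wait_until_ready is a hypothetical helper, and its default URL assumes port 8000 is reachable as in the commands above. The tritonclient package's InferenceServerClient also offers is_server_ready() and is_model_ready() for the same purpose.

```python
import http.client
import time
import urllib.request


def _is_ready(url, timeout):
    """True if GET `url` answers HTTP 200 within `timeout` seconds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (OSError, http.client.HTTPException):
        # URLError, HTTPError (non-2xx, e.g. model still loading) and socket
        # timeouts are OSError subclasses; malformed replies raise HTTPException
        return False


def wait_until_ready(base_url="http://localhost:8000", model="resnet50",
                     timeout=60.0, interval=1.0):
    """Poll server and model readiness; True once both return 200, False on timeout."""
    urls = [f"{base_url}/v2/health/ready", f"{base_url}/v2/models/{model}/ready"]
    deadline = time.monotonic() + timeout
    while True:
        if all(_is_ready(u, interval) for u in urls):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

For example, `wait_until_ready()` could run at the start of client.py so that the first inference request is only sent after the model has loaded.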

Step 4. Request inference via the HTTP API

Before proceeding, install the client dependencies. client.py below also imports torchvision for preprocessing:

$ pip3 install tritonclient==2.41.1 gevent geventhttpclient fire torchvision

Next, download a sample image:

$ wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg

Below is a sample client.py for making a ResNet50 inference request:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import pathlib

import fire
import torchvision
import tritonclient.http as httpclient
from torchvision.io.image import read_image
from tritonclient.utils import np_to_triton_dtype

DEFAULT_URL = "localhost:8000"
# "tabby.jpg" can be downloaded from https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg
DEFAULT_IMG = os.path.join(pathlib.Path(__file__).parent.resolve(), "tabby.jpg")
DEFAULT_REQUESTS = 10
MODEL_NAME = "resnet50"


def infer(
    url: str = DEFAULT_URL,
    img_path: str = DEFAULT_IMG,
    requests: int = DEFAULT_REQUESTS,
    verbose: bool = False,
):
    # prepare input img
    img = read_image(img_path)
    weights = torchvision.models.get_model_weights(MODEL_NAME).DEFAULT
    preprocess = weights.transforms(antialias=True)
    batch = preprocess(img).unsqueeze(0)

    # configure httpclient
    with httpclient.InferenceServerClient(
        url=url, verbose=verbose
    ) as client:
        input0_data = batch.numpy()
        inputs = [
            httpclient.InferInput(
                "INPUT__0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
            )
        ]
        inputs[0].set_data_from_numpy(input0_data)
        outputs = [
            httpclient.InferRequestedOutput("OUTPUT__0"),
        ]
        responses = []
        # inference
        for i in range(requests):
            responses.append(
                client.infer(MODEL_NAME, inputs, request_id=str(i), outputs=outputs)
            )
        # check result
        for response in responses:
            out = response.as_numpy("OUTPUT__0")
            # Index of the highest-scoring class for the single image in the batch
            top_index = int(out[0].argmax())
            top_category = weights.meta["categories"][top_index]
            print(top_category)
            assert top_category == "tabby"


if __name__ == "__main__":
    fire.Fire(infer)

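Save client.py in the same directory as tabby.jpg, because the default image path is resolved relative to the script. Then run it:

$ python3 client.py

The script uses fire, so you can override the arguments of infer() on the command line (for example, --url or --requests). Your output should resemble the following: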

tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby
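The client only checks the top-1 class. ResNet50 returns raw logits, so a softmax is needed to report confidence scores. The NumPy-only sketch below extends the check to the top k classes with probabilities. top_k is a hypothetical helper, not part of tritonclient or torchvision.

```python
import numpy as np


def top_k(logits, k=5):
    """Return (indices, probabilities) of the k highest-scoring classes for one logits vector."""
    logits = np.asarray(logits, dtype=np.float64).reshape(-1)
    # Numerically stable softmax: subtract the max before exponentiating
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    idx = np.argsort(-probs)[:k]
    return idx, probs[idx]


# Possible use in client.py's result loop:
#   idx, p = top_k(response.as_numpy("OUTPUT__0")[0], k=5)
#   print([(weights.meta["categories"][i], round(float(q), 3)) for i, q in zip(idx, p)])
```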

Advanced features

Multiple model instances with multiple RBLN NPU devices

You can configure multiple model instances to distribute inference workloads across multiple RBLN NPU devices. The number of instances is set by the count value of the instance_group field in config.pbtxt.

For example, to create two execution instances of the ResNet50 model from the tutorial above, set count to 2 in python_backend/examples/rbln/resnet50/config.pbtxt:

...

instance_group [
  {
    count: 2  # number of instances
    kind: KIND_MODEL
  }
]

...

In model.py, pass the device index to the runtime through the device parameter:

#
# model.py
#

def initialize(self, args):
    # .......
    self.module = rebel.Runtime(rbln_path, device=instance_idx)

Here, instance_idx is the index of the model instance within the instance_group. You can map this index to an RBLN NPU device according to your hardware configuration. The following model.py shows a simple mapping from instance_idx to an RBLN NPU device:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#
# model.py
#
import json
import os
import rebel
import triton_python_backend_utils as pb_utils

# Number of devices to allocate.
# Available device numbers can be found through `rbln-stat` command.
NUM_OF_DEVICES = 2

class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        instance_group_config = model_config["instance_group"][0]
        instance_count = instance_group_config["count"]
        instance_idx = 0
        # Get `instance_idx` for multiple instances.
        # instance_group's count should be bigger than 1 in config.pbtxt.
        if instance_count > 1:
            instance_name_parts = args["model_instance_name"].split("_")
            if not instance_name_parts[-1].isnumeric():
                raise pb_utils.TritonModelException(
                    "model instance name should end with '_<instance_idx>', got {}".format(
                        args["model_instance_name"]
                    )
                )
            instance_idx = int(instance_name_parts[-1])

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )
        # Allocate the instance to a device.
        # Simple example of round-robin assignment across multiple devices.
        self.module = rebel.Runtime(rbln_path, device=instance_idx % NUM_OF_DEVICES)

    def execute(self, requests):
        # ... Same as previous ...
        return responses
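To preview which device each instance will land on before deploying, you can pull the name parsing and the round-robin formula out of initialize() into plain functions. This is a sketch with hypothetical helper names. It assumes that instance names follow the <model_name>_<instance_group_id>_<instance_id> form described in the initialize() docstring. It raises ValueError where model.py raises pb_utils.TritonModelException.

```python
def instance_index(instance_name: str, instance_count: int) -> int:
    """Extract <instance_id> from '<model_name>_<instance_group_id>_<instance_id>'."""
    if instance_count <= 1:
        return 0  # Mirrors model.py: a single instance always uses index 0
    suffix = instance_name.rsplit("_", 1)[-1]
    if not suffix.isnumeric():
        raise ValueError(
            f"model instance name should end with '_<instance_idx>', got {instance_name}"
        )
    return int(suffix)


def device_for_instance(instance_name: str, instance_count: int, num_devices: int) -> int:
    """Round-robin assignment, the same formula as model.py: instance_idx % NUM_OF_DEVICES."""
    return instance_index(instance_name, instance_count) % num_devices


# Preview: 4 instances of "resnet50" spread over 2 devices
for i in range(4):
    name = f"resnet50_0_{i}"
    print(name, "-> device", device_for_instance(name, instance_count=4, num_devices=2))
```

Round robin keeps the per-device load even when count is a multiple of the number of devices. Otherwise, lower-numbered devices receive one extra instance.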

Dynamic Batching

Triton offers dynamic batching, which lets the server automatically group multiple incoming inference requests into a single batch. Instead of handling each request individually, the server combines queued requests for the same model into batches on the fly.
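The grouping idea can be illustrated with a deliberately simplified model. Requests are taken in arrival order and added to the current batch while the combined batch size stays within max_batch_size. This is not Triton's actual scheduler, which also weighs queue delay, preferred batch sizes and priorities. group_requests is a hypothetical name. Each inner list corresponds to the requests that one execute() call would receive.

```python
from typing import List


def group_requests(request_sizes: List[int], max_batch_size: int) -> List[List[int]]:
    """Greedily group queued requests (given by their batch sizes) into batches."""
    batches, current, total = [], [], 0
    for size in request_sizes:
        if size > max_batch_size:
            raise ValueError(f"request batch {size} exceeds max_batch_size {max_batch_size}")
        if total + size > max_batch_size:
            # Adding this request would overflow: close the current batch
            batches.append(current)
            current, total = [], 0
        current.append(size)
        total += size
    if current:
        batches.append(current)
    return batches


print(group_requests([1, 1, 1, 1, 1, 2], max_batch_size=4))  # [[1, 1, 1, 1], [1, 2]]
```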

To use dynamic batching with RBLN SDK, compile the model with one input shape per supported batch size. Below is an example for batch sizes 1 to 4:

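import rebel  # RBLN compiler
import torchvision

# Pretrained ResNet50 in inference mode
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).eval()

size = 224  # Width and height of image
batches = [1, 2, 3, 4]  # Supported batch sizes

# Create input information for each batch size; input_infos[i] covers batches[i]
input_infos = []
for batch in batches:
    input_info = [("x", [batch, 3, size, size], "float32")]
    input_infos.append(input_info)

# Compile the model with the pre-defined input information and save it
compiled_model = rebel.compile_from_torch(model, input_info=input_infos)
compiled_model.save("resnet50.rbln")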
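The model.py below selects a runtime with self.runners[batch_size - 1]. That works only because batches is [1, 2, 3, 4], which is contiguous, starts at 1 and is in order. The hypothetical helpers below make the mapping from batch size to input_info_index explicit. With other compiled sizes such as [1, 2, 4, 8], an uncompiled batch size then raises a clear error instead of silently picking the wrong runtime.

```python
def runner_index_map(batches):
    """Map each compiled batch size to the input_info_index it was compiled at.

    input_infos[i] was built from batches[i], so the runtime created with
    input_info_index=i serves batch size batches[i].
    """
    if len(set(batches)) != len(batches):
        raise ValueError("batch sizes must be unique")
    return {batch: i for i, batch in enumerate(batches)}


def runner_index(batch_size, index_map):
    """Look up the runtime index for a batch, failing clearly if it was not compiled."""
    try:
        return index_map[batch_size]
    except KeyError:
        raise ValueError(
            f"no runtime compiled for batch size {batch_size}; "
            f"compiled sizes: {sorted(index_map)}"
        ) from None


index_map = runner_index_map([1, 2, 3, 4])
print(runner_index(3, index_map))  # 2, the same as batch_size - 1 in model.py
```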

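In the config.pbtxt file, set max_batch_size to the largest compiled batch size. Also enable the dynamic batcher, because Triton only batches requests dynamically when the dynamic_batching field is present:

#
# config.pbtxt
#

max_batch_size: 4
dynamic_batching { }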

In model.py, create one runtime per compiled batch size. For each group of requests, run inference with the runtime that matches the combined batch size:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os

import numpy as np
import rebel

import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    def _get_index(self, name):
        parts = name.split("__")
        return int(parts[1])

    def initialize(self, args):
        self.model_config = model_config = json.loads(args["model_config"])

        # Configure input_dict
        self.input_dict = {}
        for config_input in model_config["input"]:
            index = self._get_index(config_input["name"])
            self.input_dict[index] = [
                config_input["name"],
                config_input["data_type"],
                config_input["dims"],
            ]

        # Configure output_dict
        self.output_dict = {}
        for config_output in model_config["output"]:
            index = self._get_index(config_output["name"])
            self.output_dict[index] = [
                config_output["name"],
                config_output["data_type"],
                config_output["dims"],
            ]

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )
        self.model_name = args["model_name"]

        # Load compiled model
        compiled_model = rebel.RBLNCompiledModel(
            rbln_path
        )

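        # Create one runner per compiled batch size; runner i serves batch size i + 1
        # (assumes the model was compiled for batch sizes 1..max_batch_size, in order)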
        self.runners = []
        for i in range(model_config["max_batch_size"]):
            self.runners.append(
                compiled_model.create_runtime(input_info_index=i)
            )

    def execute(self, requests):
        num_requests = len(requests)
        if num_requests > 1:
            print("dynamic batch size : ", num_requests)

        # Preprocess: concatenate each input across all requests along the batch axis
        responses = []
        inputs = []
        request_batch_sizes = []  # batch size contributed by each request
        for i in self.input_dict.keys():
            name, dt, _ = self.input_dict[i]
            tensors = [
                pb_utils.get_input_tensor_by_name(request, name).as_numpy()
                for request in requests
            ]
            if not request_batch_sizes:
                request_batch_sizes = [t.shape[0] for t in tensors]
            inputs.append(np.concatenate(tensors, axis=0))

        batch_size = inputs[0].shape[0]
        if batch_size > 1:
            print(f"running inference with batch size : {batch_size}")

        # Run inference with the runtime compiled for this batch size
        # (self.runners[k] was created for batch size k + 1)
        batched_results = self.runners[batch_size - 1](*inputs)

        # Postprocess: split each output back into one chunk per request
        split_points = np.cumsum(request_batch_sizes)[:-1]
        chunky_batched_results = []
        for i in self.output_dict.keys():
            batch = (
                batched_results[i]
                if isinstance(batched_results, (tuple, list))
                else batched_results
            )
            chunky_batched_results.append(np.split(batch, split_points, axis=0))

        # Send response
        for i in range(num_requests):
            output_tensors = []
            for j in self.output_dict.keys():
                name, dt, _ = self.output_dict[j]
                result = chunky_batched_results[j][i]
                output_tensor = pb_utils.Tensor(
                    name, result.astype(pb_utils.triton_string_to_numpy(dt))
                )
                output_tensors.append(output_tensor)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=output_tensors
            )
            responses.append(inference_response)

        return responses

    def finalize(self):
        print("Cleaning up...")
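The pre- and post-processing in execute() comes down to two steps: concatenate the per-request inputs along the batch axis, then split the output back by each request's batch size. The NumPy-only sketch below isolates those steps so you can test them without Triton or an NPU. concat_requests and split_outputs are hypothetical names, and a multiplication stands in for the RBLN runtime.

```python
import numpy as np


def concat_requests(tensors):
    """Concatenate per-request arrays along axis 0; also return each request's batch size."""
    sizes = [t.shape[0] for t in tensors]
    return np.concatenate(tensors, axis=0), sizes


def split_outputs(batched_output, sizes):
    """Split a batched output back into one chunk per request, in request order."""
    split_points = np.cumsum(sizes)[:-1]  # boundaries between consecutive requests
    return np.split(batched_output, split_points, axis=0)


# Three hypothetical requests with batch sizes 1, 2 and 1
reqs = [np.full((n, 3), i, dtype=np.float32) for i, n in enumerate([1, 2, 1])]
batched, sizes = concat_requests(reqs)
fake_output = batched * 10  # stand-in for the RBLN runtime
chunks = split_outputs(fake_output, sizes)
print(batched.shape, [c.shape for c in chunks])  # (4, 3) [(1, 3), (2, 3), (1, 3)]
```

Splitting at cumulative offsets yields exactly one chunk per request. This matches Triton's requirement that execute() returns one response per request, in order.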