
Nvidia Triton Inference Server

Using the RBLN SDK with the Nvidia Triton Inference Server

The NVIDIA Triton Inference Server is open-source software designed for serving machine learning models effectively in production environments. This page shows how to use the RBLN SDK to serve a compiled ResNet50 model on the Nvidia Triton Inference Server.

Note

This tutorial assumes that you are familiar with model compilation and inference based on the RBLN SDK. If you are not yet comfortable with the RBLN SDK, please refer to the PyTorch/TensorFlow tutorials and the Python API pages.

Prerequisites

Before you begin, make sure the following items are ready on your system:
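
  • An RBLN NPU and its driver installed on the host (device status can be checked with the rbln-stat command)
  • The RBLN SDK (rebel-compiler, installable from https://pypi.rbln.ai/simple/)
  • A compiled resnet50.rbln model file (see the PyTorch/TensorFlow tutorials)
  • Docker, if you are running on-premises rather than on Backend.AI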

Using the Triton Inference Server Container

If you are running on-premises rather than on Backend.AI, skip ahead to Step 1.

Step 0. Start a Session with the Triton Server Image

When starting a session through Backend.AI, select Triton Server (NGC) as the startup environment. The session environment is then automatically set to nvcr.io/nvidia/tritonserver:24.01-py3.

If you are not a Backend.AI user, you can skip this step.

Step 1. Prepare the Nvidia Triton python_backend

Clone the Nvidia Triton Inference Server python_backend repository with the following command:

$ git clone https://github.com/triton-inference-server/python_backend -b r24.01

Before moving on to the next step, place the prepared resnet50.rbln file in the python_backend/examples/rbln/resnet50/1 directory:

$ mkdir -p python_backend/examples/rbln/resnet50/1
$ mv resnet50.rbln python_backend/examples/rbln/resnet50/1/
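
If you have not yet compiled ResNet50, the sketch below shows roughly how resnet50.rbln can be produced. It is a minimal sketch assuming the torchvision ResNet50 as the source model and the rebel.compile_from_torch/save APIs (compile_from_torch also appears in the dynamic batching section later on this page); see the PyTorch tutorial for the authoritative steps:

import rebel  # RBLN compiler/runtime
import torchvision

# Assumption: torchvision ResNet50 as the source model.
model = torchvision.models.resnet50(weights="DEFAULT").eval()

# Single input "x" of shape [1, 3, 224, 224], matching config.pbtxt below.
input_info = [("x", [1, 3, 224, 224], "float32")]

# Compile for the RBLN NPU and save the artifact used in this tutorial.
compiled_model = rebel.compile_from_torch(model, input_info=input_info)
compiled_model.save("resnet50.rbln")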

Step 2. Write a TritonPythonModel Using the RBLN SDK

To use the Triton python_backend, you must write a TritonPythonModel class that implements the following member methods:

  • auto_complete_config()
  • initialize()
  • execute()
  • finalize()

For details on each method, refer to the official Triton python_backend repository.

Below is an example model.py that defines an initialize() method for loading the model and an execute() method for running inference. Save this code as python_backend/examples/rbln/resnet50/1/model.py, alongside the resnet50.rbln file.

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#
# model.py
#
import json
import os
import rebel  # RBLN runtime
import triton_python_backend_utils as pb_utils

# Number of devices to allocate.
# Available device numbers can be found through `rbln-stat` command.
NUM_OF_DEVICES = 1


class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )

        # Path to rbln compiled model file
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )

        # Create rbln runtime module
        self.module = rebel.Runtime(rbln_path)

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        output0_dtype = self.output0_dtype
        responses = []

        for request in requests:
            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT__0")

            # Run inference
            result = self.module.run(in_0.as_numpy())
            out_tensor_0 = pb_utils.Tensor("OUTPUT__0", result.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0]
            )
            responses.append(inference_response)

        return responses

The config.pbtxt file below must also be saved at python_backend/examples/rbln/resnet50/config.pbtxt.

name: "resnet50"
backend: "python"

input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

# Configure instance group
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]

max_batch_size: 1
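
Note that because max_batch_size is greater than 0, Triton implicitly prepends a batch dimension to the declared dims, so the full input shape seen by the model is [1, 3, 224, 224], matching the shape the model was compiled with.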

If you have followed all the steps above correctly, you will have the following directory structure:

+--resnet50/
|      +-- config.pbtxt
|      +-- 1/
|      |   +-- model.py
|      |   +-- resnet50.rbln

Step 3. Run the Inference Server in a Container

You are now ready to run the inference server.

Starting with the Docker Container Provided by Backend.AI

Install the RBLN SDK:

$ pip3 install -i https://pypi.rbln.ai/simple/ rebel-compiler

Start the Triton server:

$ tritonserver --model-repository /opt/tritonserver/python_backend/examples/rbln

Once the server starts successfully, you will see the following messages:

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
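
From another shell, you can confirm the server is ready to accept requests; Triton exposes the standard KServe health endpoint on the HTTP port, which returns HTTP 200 once the model is loaded:

$ curl -s -o /dev/null -w "%{http_code}\n" localhost:8000/v2/health/ready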

Starting with Your Own Docker Container Without Backend.AI

If you are not using Backend.AI, you can start the inference server in your own Docker container with the steps below. (Backend.AI users can skip ahead to Step 4.)

To access RBLN NPU devices inside the container, the inference server container must run in privileged mode. The python_backend directory prepared in the previous steps must also be mounted:

$ sudo docker run --privileged --shm-size=1g --ulimit memlock=-1 \
   -v /PATH/TO/YOUR/python_backend:/opt/tritonserver/python_backend \
   -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:24.01-py3

Install the RBLN SDK inside the container:

$ pip3 install -i https://pypi.rbln.ai/simple/ rebel-compiler

Start the Triton server inside the container:

$ tritonserver --model-repository /opt/tritonserver/python_backend/examples/rbln

You will see the following messages indicating that the server has started successfully:

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

Step 4. Inference Requests over the HTTP API

Install the required packages:

$ pip3 install tritonclient==2.41.1 gevent geventhttpclient fire

Then download the sample image:

$ wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg

Below is an example client.py that sends inference requests to the ResNet50 model:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import pathlib

import fire
import torchvision
import tritonclient.http as httpclient
from torchvision.io.image import read_image
from tritonclient.utils import np_to_triton_dtype

DEFAULT_URL = "localhost:8000"
# "tabby.jpg" can be downloaded from https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg
DEFAULT_IMG = os.path.join(pathlib.Path(__file__).parent.resolve(), "tabby.jpg")
DEFAULT_REQUESTS = 10
MODEL_NAME = "resnet50"


def infer(
    url: str = DEFAULT_URL,
    img_path: str = DEFAULT_IMG,
    requests: int = DEFAULT_REQUESTS,
    verbose: bool = False,
):
    # prepare input img
    img = read_image(img_path)
    weights = torchvision.models.get_model_weights(MODEL_NAME).DEFAULT
    preprocess = weights.transforms(antialias=True)
    batch = preprocess(img).unsqueeze(0)

    # configure httpclient
    with httpclient.InferenceServerClient(
        url=url, verbose=verbose
    ) as client:
        input0_data = batch.numpy()
        inputs = [
            httpclient.InferInput(
                "INPUT__0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
            )
        ]
        inputs[0].set_data_from_numpy(input0_data)
        outputs = [
            httpclient.InferRequestedOutput("OUTPUT__0"),
        ]
        responses = []
        # inference
        for i in range(requests):
            responses.append(
                client.infer(MODEL_NAME, inputs, request_id=str(i), outputs=outputs)
            )
        # check result
        for response in responses:
            out = response.as_numpy("OUTPUT__0")
            top_index_rebel = (-out[0,]).argsort(axis=-1)[:1].flatten()
            top_category = [weights.meta["categories"][x] for x in top_index_rebel][0]
            print(top_category)
            assert top_category == "tabby"


if __name__ == "__main__":
    fire.Fire(infer)
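
Run the client against the server; because the script uses fire.Fire, each function parameter is exposed as a command-line flag:

$ python3 client.py --url localhost:8000 --requests 10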

If the script runs successfully, you will see the following output:

tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby
tabby

Advanced Features

Multiple Model Instances Across Multiple Devices

You can distribute the inference workload across multiple RBLN NPU devices by configuring multiple model instances. This is done by setting the count value of the instance_group field in config.pbtxt.

For example, building on the ResNet50 tutorial above, increase count to 2 in python_backend/examples/rbln/resnet50/config.pbtxt to create two execution instances of the model:

...

instance_group [
  {
    count: 2  # number of instances
    kind: KIND_MODEL
  }
]

...

In model.py, set the device index of each runtime instance via the device parameter:

#
# model.py
#

def initialize(self, args):
    # .......
    self.module = rebel.Runtime(rbln_path, device=instance_idx)

Here instance_idx is the index of the instance within the instance_group. You can use this index to map each instance to an appropriate RBLN NPU device for your hardware configuration. The code below is a simple example that maps instance_idx to a corresponding RBLN NPU device:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#
# model.py
#
import json
import os
import rebel
import triton_python_backend_utils as pb_utils

# Number of devices to allocate.
# Available device numbers can be found through `rbln-stat` command.
NUM_OF_DEVICES = 2

class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        instance_group_config = model_config["instance_group"][0]
        instance_count = instance_group_config["count"]
        instance_idx = 0
        # Get `instance_idx` for multiple instances.
        # instance_group's count should be bigger than 1 in config.pbtxt.
        if instance_count > 1:
            instance_name_parts = args["model_instance_name"].split("_")
            if not instance_name_parts[-1].isnumeric():
                raise pb_utils.TritonModelException(
                    "model instance name should end with '_<instance_idx>', got {}".format(
                        args["model_instance_name"]
                    )
                )
            instance_idx = int(instance_name_parts[-1])

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )
        # Allocate this instance to a device.
        # Simple example of round-robin assignment across multiple devices.
        self.module = rebel.Runtime(rbln_path, device=instance_idx % NUM_OF_DEVICES)

    def execute(self, requests):
        # ... Same as previous ...
        return responses

Dynamic Batching

Triton provides dynamic batching, which lets the server automatically group multiple inference requests together. Instead of processing each inference request individually, the server dynamically combines related requests on the fly to form batches.

To enable dynamic batching with the RBLN SDK, the model must be compiled with multiple input shapes. Below is example code for batch sizes 1 through 4:

size = 224  # image width and height
batches = [1, 2, 3, 4]  # supported batch sizes
input_infos = []

# Create input info for each batch size
for batch in batches:
    input_info = [("x", [batch, 3, size, size], "float32")]
    input_infos.append(input_info)

# Compile the model against the predefined input infos
compiled_model = rebel.compile_from_torch(model, input_info=input_infos)

In config.pbtxt, specify the maximum batch size of the compiled model:

#
# config.pbtxt
#

max_batch_size: 4
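
Note that max_batch_size alone only declares an upper bound. To have Triton actually group incoming requests, the dynamic batcher must also be enabled by adding a dynamic_batching block to config.pbtxt; the queue delay value below is an illustrative choice, not a recommendation:

#
# config.pbtxt
#

dynamic_batching {
  max_queue_delay_microseconds: 100
}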

In model.py, create a runtime for each supported batch size, and define how inference runs on the provided input data using those runtimes:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os

import numpy as np
import rebel

import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    def _get_index(self, name):
        parts = name.split("__")
        return int(parts[1])

    def initialize(self, args):
        self.model_config = model_config = json.loads(args["model_config"])

        # Build input_dict
        self.input_dict = {}
        for config_input in model_config["input"]:
            index = self._get_index(config_input["name"])
            self.input_dict[index] = [
                config_input["name"],
                config_input["data_type"],
                config_input["dims"],
            ]

        # Build output_dict
        self.output_dict = {}
        for config_output in model_config["output"]:
            index = self._get_index(config_output["name"])
            self.output_dict[index] = [
                config_output["name"],
                config_output["data_type"],
                config_output["dims"],
            ]

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
        rbln_path = os.path.join(
            args["model_repository"],
            args["model_version"],
            f"{args['model_name']}.rbln",
        )
        self.model_name = args["model_name"]

        # Load the compiled model
        compiled_model = rebel.RBLNCompiledModel(rbln_path)

        # Create a runtime for each supported batch size
        self.runners = []
        for i in range(model_config["max_batch_size"]):
            self.runners.append(
                compiled_model.create_runtime(input_info_index=i)
            )

    def execute(self, requests):
        if len(requests) > 1:
            print("dynamic batch size : ", len(requests))

        # Preprocess the input data
        responses = []
        inputs = []
        num_requests = len(requests)
        request_batch_sizes = []
        for i in self.input_dict.keys():
            name, dt, _ = self.input_dict[i]
            first_tensor = pb_utils.get_input_tensor_by_name(
                requests[0], name
            ).as_numpy()
            request_batch_sizes.append(first_tensor.shape[0])
            batched_tensor = first_tensor
            for j in range(1, num_requests):
                tensor = pb_utils.get_input_tensor_by_name(requests[j], name).as_numpy()
                request_batch_sizes.append(request_batch_sizes[-1] + tensor.shape[0])
                batched_tensor = np.concatenate((batched_tensor, tensor), axis=0)

            inputs.append(batched_tensor)

        batch_size = batched_tensor.shape[0]
        if batch_size > 1:
            print(f"running inference with batch size : {batch_size}")

        # Run inference on the RBLN model
        batched_results = self.runners[batch_size - 1](batched_tensor)

        # Postprocess the output data
        chunky_batched_results = []
        for i in self.output_dict.keys():
            batch = (
                batched_results[i]
                if isinstance(batched_results, tuple)
                else batched_results
            )
            chunky_batched_results.append(
                np.array_split(batch, request_batch_sizes, axis=0)
            )

        # Send back the results
        for i in range(num_requests):
            output_tensors = []
            for j in self.output_dict.keys():
                name, dt, _ = self.output_dict[j]
                result = chunky_batched_results[j][i]
                output_tensor = pb_utils.Tensor(
                    name, result.astype(pb_utils.triton_string_to_numpy(dt))
                )
                output_tensors.append(output_tensor)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=output_tensors
            )
            responses.append(inference_response)

        return responses

    def finalize(self):
        print("Cleaning up...")