콘텐츠로 이동

Triton 추론 서버를 사용하여 대규모 언어 모델(LLM) 서빙하기

이 페이지에서는 가장 유명한 대규모 언어 모델(LLM) 중 하나인 Llama2-7B 모델을 Nvidia Triton 추론 서버를 이용해 서빙하는 방법을 소개합니다.

Note

이 튜토리얼은 사용자가 Nvidia Triton Inference ServerLlama2-7B를 완료했다는 가정하에 작성되었습니다.

Triton 추론 서버 이용해 서빙하기

Nvidia Triton 추론 서버 페이지의 1단계에서 설명된 것과 같이 triton-inference-server의 python_backend 리포지토리를 복제합니다.

$ git clone https://github.com/triton-inference-server/python_backend -b r24.01

1단계. Llama2-7B 모델 준비하기

Llama2-7B에서 생성된 디렉토리를 python_backend/examples/rbln/llama-2-7b-chat-hf/1로 위치시킵니다:

$ mkdir -p python_backend/examples/rbln/llama-2-7b-chat-hf/1
$ cp -r rbln-Llama-2-7b-chat-hf python_backend/examples/rbln/llama-2-7b-chat-hf/1/

이 과정 이후에 디렉토리의 구조가 다음과 같이 되어야 합니다.

+--python_backend/
|   +-- examples/
|   |   +-- rbln/
|   |   |   +-- llama-2-7b-chat-hf/
|   |   |   |   +-- 1/
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (and others)
|   |   +-- (and others)
|   +-- (and others)

2단계. Llama2-7B TritonPythonModel 작성

이제 python_backend/examples/rbln/llama-2-7b-chat-hf/config.pbtxt 위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 파일은 모델의 입력과 출력 시그니쳐를 정의하고 모델의 몇가지 특성을 규정합니다.

config.pbtxt
name: "llama-2-7b-chat-hf"
backend: "python"

input [  # (1)
  {
    name: "INPUT__0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [  # (2)
  {
    name: "OUTPUT__0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

instance_group [
    {
      count: 1
      kind: KIND_MODEL
    }
]

max_batch_size: 1

model_transaction_policy {
  decoupled: True  # (3)
}
  1. 모델의 입력 시그니쳐를 기술하고 있습니다. 이 모델은 INPUT__0이라는 이름의 입력 하나를 받고, 그 입력은 문자열임을 나타냅니다.
  2. 모델의 출력 시그니쳐를 기술하고 있습니다. 이 모델은 OUTPUT__0이라는 이름의 출력을 하나 만들어내며, 그 출력은 문자열입니다.
  3. streaming 추론을 활성화하기 위해서는 model_transaction_policy.decoupledTrue 로 설정해주어야 합니다.

다음으로 python_backend/examples/rbln/llama-2-7b-chat-hf/1/model.py 위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 스크립트는 RBLN SDK를 사용하여 static batching 방식으로 LLM 모델을 실행합니다. 또한, 이 스크립트는 decoupled model을 지원하기 위해 클라이언트와 통신에 gRPC를 지원합니다.

model.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os

import numpy as np
import triton_python_backend_utils as pb_utils
from optimum.rbln import BatchTextIteratorStreamer, RBLNLlamaForCausalLM
from transformers import AutoTokenizer
from threading import Thread

DEFAULT_PROMPT = "what is the first letter in alphabet?"

class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        self.max_batch_size = model_config["max_batch_size"]

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
        model_dir = os.path.join(
            args["model_repository"],
            args["model_version"],
            "rbln-Llama-2-7b-chat-hf",
        )

        self.model = RBLNLlamaForCausalLM.from_pretrained(
            model_id=model_dir,
            export=False,
        )
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", pad_token="[PAD]", padding_side="left")
        self.streamer = BatchTextIteratorStreamer(
            tokenizer=self.tokenizer, batch_size=self.max_batch_size, skip_prompt=True, skip_special_tokens=True
        )

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        inputs = []
        num_requests = len(requests)
        batch_sentences = [DEFAULT_PROMPT] * self.max_batch_size
        for i in range(num_requests):
            sentence = pb_utils.get_input_tensor_by_name(requests[i], "INPUT__0").as_numpy()[0][0]
            sentence = str(sentence.decode("utf-8")).strip()
            batch_sentences[i] = sentence
            print(sentence)

        output0_dtype = self.output0_dtype
        inputs = self.tokenizer(batch_sentences, return_tensors="pt", padding=True)

        generation_kwargs = dict(
            **inputs,
            streamer=self.streamer,
            do_sample=False,
            max_length=self.model.max_seq_len,
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in self.streamer:
            for i in range(num_requests):
                out_data = np.array([new_text[i].encode("utf-8")])
                out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
                inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
                response_sender = requests[i].get_response_sender()
                response_sender.send(inference_response)

        for i in range(num_requests):
            response_sender = requests[i].get_response_sender()
            out_data = np.array(["".encode("utf-8")])
            out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
            response_sender.send(
                inference_response,
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
            )
        return None

    def finalize(self):
        print("Cleaning up...")

이 과정까지 성공적으로 마친 뒤에 디렉토리 구조는 다음과 같이 되어야 합니다.

+--python_backend/
|   +-- examples/
|   |   +-- rbln/
|   |   |   +-- llama-2-7b-chat-hf/
|   |   |   |   +-- config.pbtxt   ============== (new file)
|   |   |   |   +-- 1/
|   |   |   |   |   +-- model.py   ============== (new file)
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (and others)
|   |   +-- (and others)
|   +-- (and others)

3단계. 컨테이너에서 추론 서버 실행

Triton Inference Server의 3단계를 다시 반복합니다. 이때 추가로 컨테이너 안에 optimum-rbln를 설치합니다.

$ pip3 install -i https://pypi.rbln.ai/simple/ optimum-rbln

Step 4. gRPC 클라이언트 추론 요청

실행된 Triton 추론 서버에 gRPC를 이용해 요청을 보내기 위해 tritonclientgrpcio 패키지를 설치합니다.

$ pip3 install tritonclient==2.41.1 grpcio

다음의 파이썬 스크립트는 gRPC를 이용해 Triton 추론 서버에 요청을 보내고 받는 방법을 보여줍니다.

simple_client.py
import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient

async def try_request():
  url = "<host and port number of the triton inference server>"  # e.g. "localhost:8001"
  client = grpcclient.InferenceServerClient(url=url, verbose=False)

  model_name = "llama-2-7b-chat-hf"

  def create_request(prompt, request_id):
    prompt_data = np.array([prompt.encode("utf-8")])

    input = grpcclient.InferInput("INPUT__0", [1, 1], "BYTES")
    input.set_data_from_numpy(prompt_data.reshape(1, 1))
    inputs = [input]

    output = grpcclient.InferRequestedOutput("OUTPUT__0")
    outputs = [output]

    return {
      "model_name": model_name,
      "inputs": inputs,
      "outputs": outputs,
      "request_id": request_id
    }

  prompt = "What is the first letter of English alphabets?"

  async def requests_gen():
    yield create_request(prompt, "req-0")

  response_stream = client.stream_infer(requests_gen())

  async for response in response_stream:
    result, error = response
    if error:
      print("Error occurred!")
    else:
      output = result.as_numpy("OUTPUT__0")
      for i in output:
          print(i.decode("utf-8"), end="", flush=True)

asyncio.run(try_request())

Continuous Batching

대형 언어 모델(LLMs) 서빙 시 NPU 성능을 최대로 활용을 위해서는 continuous batching 이라는 서빙 최적화 기법이 필요합니다. 이어지는 Continuous Batching 사용하여 LLM 서빙 문서에서는 continuous batching을 구현한 vLLM을 이용해 Llama2-7B 모델을 실행하는 방법을 소개합니다.