콘텐츠로 이동

대규모 언어 모델(LLM) 서빙하기

이 페이지에서는 가장 유명한 대규모 언어 모델(LLM) 중 하나인 Llama2-7B 모델을 서빙하는 3가지 방법을 소개합니다. 한가지 방법은 Nvidia Triton 추론 서버를 이용하는 것이고, 두번째 방법은 vLLM과 Nvidia Triton 추론 서버를 이용하는 방법이고, 마지막 방법은 vLLM이 제공하는 OpenAI API 호환 서버를 이용하는 것입니다.

Note

이 튜토리얼은 사용자가 Nvidia Triton Inference ServerLlama2-7B를 완료했다는 가정하에 작성되었습니다.

Triton 추론 서버 이용해 서빙하기

Nvidia Triton 추론 서버 페이지의 1단계에서 설명된 것과 같이 triton-inference-server의 python_backend 리포지토리를 복제합니다.

$ git clone https://github.com/triton-inference-server/python_backend -b r24.01

1단계. Llama2-7B 모델 준비하기

Llama2-7B에서 생성된 디렉토리를 python_backend/examples/rbln/llama-2-7b-chat-hf/1로 위치시킵니다:

$ mkdir -p python_backend/examples/rbln/llama-2-7b-chat-hf/1
$ cp -r rbln-Llama-2-7b-chat-hf python_backend/examples/rbln/llama-2-7b-chat-hf/1/

이 과정 이후에 디렉토리의 구조가 다음과 같이 되어야 합니다.

+--python_backend/
|   +-- examples/
|   |   +-- rbln/
|   |   |   +-- llama-2-7b-chat-hf/
|   |   |   |   +-- 1/
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (and others)
|   |   +-- (and others)
|   +-- (and others)

2단계. Llama2-7B TritonPythonModel 작성

이제 python_backend/examples/rbln/llama-2-7b-chat-hf/config.pbtxt 위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 파일은 모델의 입력과 출력 시그니쳐를 정의하고 모델의 몇가지 특성을 규정합니다.

config.pbtxt
name: "llama-2-7b-chat-hf"
backend: "python"

input [  # (1)
  {
    name: "INPUT__0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [  # (2)
  {
    name: "OUTPUT__0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

instance_group [
    {
      count: 1
      kind: KIND_MODEL
    }
]

max_batch_size: 1

model_transaction_policy {
  decoupled: True  # (3)
}
  1. 모델의 입력 시그니쳐를 기술하고 있습니다. 이 모델은 INPUT__0이라는 이름의 입력 하나를 받고, 그 입력은 문자열임을 나타냅니다.
  2. 모델의 출력 시그니쳐를 기술하고 있습니다. 이 모델은 OUTPUT__0이라는 이름의 출력을 하나 만들어내며, 그 출력은 문자열입니다.
  3. streaming 추론을 활성화하기 위해서는 model_transaction_policy.decoupledTrue 로 설정해주어야 합니다.

다음으로 python_backend/examples/rbln/llama-2-7b-chat-hf/1/model.py 위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 스크립트는 RBLN SDK를 사용하여 static batching 방식으로 LLM 모델을 실행합니다. 또한, 이 스크립트는 decoupled model을 지원하기 위해 클라이언트와 통신에 gRPC를 지원합니다.

model.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os

import numpy as np
import triton_python_backend_utils as pb_utils
from optimum.rbln import BatchTextIteratorStreamer, RBLNLlamaForCausalLM
from transformers import AutoTokenizer
from threading import Thread

DEFAULT_PROMPT = "what is the first letter in alphabet?"

class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        self.model_config = model_config = json.loads(args["model_config"])
        self.max_batch_size = model_config["max_batch_size"]

        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
        model_dir = os.path.join(
            args["model_repository"],
            args["model_version"],
            "rbln-Llama-2-7b-chat-hf",
        )

        self.model = RBLNLlamaForCausalLM.from_pretrained(
            model_id=model_dir,
            export=False,
        )
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", pad_token="[PAD]", padding_side="left")
        self.streamer = BatchTextIteratorStreamer(
            tokenizer=self.tokenizer, batch_size=self.max_batch_size, skip_prompt=True, skip_special_tokens=True
        )

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        inputs = []
        num_requests = len(requests)
        batch_sentences = [DEFAULT_PROMPT] * self.max_batch_size
        for i in range(num_requests):
            sentence = pb_utils.get_input_tensor_by_name(requests[i], "INPUT__0").as_numpy()[0][0]
            sentence = str(sentence.decode("utf-8")).strip()
            batch_sentences[i] = sentence
            print(sentence)

        output0_dtype = self.output0_dtype
        inputs = self.tokenizer(batch_sentences, return_tensors="pt", padding=True)

        generation_kwargs = dict(
            **inputs,
            streamer=self.streamer,
            do_sample=False,
            max_length=self.model.max_seq_len,
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in self.streamer:
            for i in range(num_requests):
                out_data = np.array([new_text[i].encode("utf-8")])
                out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
                inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
                response_sender = requests[i].get_response_sender()
                response_sender.send(inference_response)

        for i in range(num_requests):
            response_sender = requests[i].get_response_sender()
            out_data = np.array(["".encode("utf-8")])
            out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
            response_sender.send(
                inference_response,
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
            )
        return None

    def finalize(self):
        print("Cleaning up...")

이 과정까지 성공적으로 마친 뒤에 디렉토리 구조는 다음과 같이 되어야 합니다.

+--python_backend/
|   +-- examples/
|   |   +-- rbln/
|   |   |   +-- llama-2-7b-chat-hf/
|   |   |   |   +-- config.pbtxt   ============== (new file)
|   |   |   |   +-- 1/
|   |   |   |   |   +-- model.py   ============== (new file)
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (and others)
|   |   +-- (and others)
|   +-- (and others)

3단계. 컨테이너에서 추론 서버 실행

Triton Inference Server의 3단계를 다시 반복합니다. 이때 추가로 컨테이너 안에 optimum-rbln를 설치합니다.

$ pip3 install -i https://pypi.rbln.ai/simple/ optimum-rbln

Step 4. gRPC 클라이언트 추론 요청

실행된 triton 추론 서버에 gRPC를 이용해 요청을 보내기 위해 tritonclientgrpcio 패키지를 설치합니다.

$ pip3 install tritonclient==2.41.1 grpcio

다음의 파이썬 스크립트는 gRPC를 이용해 triton 추론 서버에 요청을 보내고 받는 방법을 보여줍니다.

simple_client.py
import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient

async def try_request():
  url = "<host and port number of the triton inference server>"  # e.g. "localhost:8001"
  client = grpcclient.InferenceServerClient(url=url, verbose=False)

  model_name = "llama-2-7b-chat-hf"

  def create_request(prompt, request_id):
    prompt_data = np.array([prompt.encode("utf-8")])

    input = grpcclient.InferInput("INPUT__0", [1, 1], "BYTES")
    input.set_data_from_numpy(prompt_data.reshape(1, 1))
    inputs = [input]

    output = grpcclient.InferRequestedOutput("OUTPUT__0")
    outputs = [output]

    return {
      "model_name": model_name,
      "inputs": inputs,
      "outputs": outputs,
      "request_id": request_id
    }

  prompt = "What is the first letter of English alphabets?"

  async def requests_gen():
    yield create_request(prompt, "req-0")

  response_stream = client.stream_infer(requests_gen())

  async for response in response_stream:
    result, error = response
    if error:
      print("Error occurred!")
    else:
      output = result.as_numpy("OUTPUT__0")
      for i in output:
          print(i.decode("utf-8"), end="", flush=True)

asyncio.run(try_request())

vllm-rbln 을 사용한 Continuous batching 지원

대형 언어 모델(LLMs) 서빙 시 NPU 성능을 최대로 활용을 위해서는 continuous batching 이라는 서빙 최적화 기법이 필요합니다. 이 튜토리얼에서는 vllm-rbln을 사용하여 continuous batching을 구현하고 LLM 서빙을 최적화하는 방법을 안내합니다. vllm-rblnvLLM 라이브러리의 확장으로, vLLMoptimum-rbln과 함께 작동할 수 있도록 수정된 버전입니다.

사전준비

1 단계. vllm 옵션을 사용하여 Llama2-7B 컴파일하기

먼저, optimum-rbln을 사용하여 Llama2-7B 모델을 vllm 옵션으로 컴파일합니다.

from optimum.rbln import RBLNLlamaForCausalLM

# HuggingFace PyTorch Llama2 모델을 RBLN 컴파일된 모델로 내보내기
model_id = "meta-llama/Llama-2-7b-chat-hf"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_max_seq_len=4096,
    rbln_tensor_parallel_size=4,  # Rebellions Scalable Design (RSD)를 위한 ATOM+ 개수
    rbln_batch_size=4,            # 지속적 배칭을 위해 batch_size > 1 권장
    rbln_batching="vllm",         # 지속적 배칭을 위해 `vllm` 옵션으로 컴파일
)

compiled_model.save_pretrained("rbln-Llama-2-7b-chat-hf")
rbln_batchingvllm으로 설정하고 서빙 목적에 맞는 적절한 배치 크기를 선택하십시오. 여기에서는 4로 설정합니다.

2 단계. 환경 설정

Backend.AI를 사용하는 경우

Backend.AI를 통해 세션을 시작합니다. 실행 환경을 Triton Server (ngc-triton)를 선택 후 24.01 / vllm / x86_64 / python-py3 버전을 사용합니다.

온프레미스에서 실행하는 경우

Backend.AI를 사용하지 않고 온프레미스에서 실행하는 경우, 3단계로 건너뜁니다.

3 단계. Nvidia Triton vllm_backend 준비 및 Llama2-7B 모델 구성 수정

A. Nvidia Triton Inference Server의 vllm_backend 저장소를 클론합니다:

$ git clone https://github.com/triton-inference-server/vllm_backend.git -b r24.01

B. 미리 컴파일된 모델이 들어 있는 rbln-Llama-2-7b-chat-hf 디렉토리를 복제한 vllm_backend/samples/model_repository/vllm_model/1 로 옮깁니다:

$ cp -R /PATH/TO/YOUR/rbln-Llama-2-7b /PATH/TO/YOUR/CLONED/vllm_backend/samples/model_repository/vllm_model/1

이 시점에 vllm_backend 폴더의 구조는 다음과 같아야 합니다:

+-- vllm_backend/
|   +-- samples/
|   |   +-- model_repository/
|   |   |   +-- vllm_model/
|   |   |   |   +-- config.pbtxt
|   |   |   |   +-- 1/
|   |   |   |   |   +-- model.json
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (and others)
|   |   +-- (and others)
|   +-- (and others)

C. 서빙되는 모델의 model.json 수정

vllm_backend/samples/model_repository/vllm_model/1/model.json 파일을 다음과 같이 수정합니다:

1
2
3
4
5
6
7
8
9
{
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "device": "rbln",
    "max_num_seqs": 4,
    "compiled_model_dir": "/ABSOLUTE/PATH/TO/rbln-Llama-2-7b-chat-hf",
    "max_num_batched_tokens": 4096,
    "max_model_len": 4096,
    "block_size": 4096
}
- model: HuggingFace Transformers 모델의 이름 또는 경로. HuggingFace에 등록된 모델이 아니면 모델의 절대 경로로 설정해야 합니다 참고. - device: vLLM 실행을 위한 디바이스. rbln 으로 설정해주세요. - max_num_seqs: 최대 시퀀스 수. 이는 컴파일된 batch_size와 반드시 일치해야 합니다. - compiled_model_dir: RBLN(optimum-rbln)으로 컴파일된 모델 디렉토리의 절대 경로 - RBLN 장치를 대상으로 할 때 max_model_len, block_size, max_num_batched_tokens 인수는 최대 시퀀스 길이와 동일해야 합니다.

Step 4. 추론 서버 실행

이제 추론 서버를 실행할 준비가 되었습니다.

Backend.AI를 사용하고 있다면 A. Backend.AI 섹션을 참조해주세요. Backend.AI 사용자가 아니라면 B. Backend.AI 없이 자체 Docker 컨테이너로 시작하기 단계로 건너뛰세요.

A. Backend.AI

필요한 패키지를 설치합니다:

$ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.5.2" "optimum-rbln>=0.1.4" vllm-rbln
Triton 서버를 시작합니다:
$ tritonserver --model-repository PATH/TO/YOUR/vllm_backend/samples/model_repository

아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:

1
2
3
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

B. Backend.AI 없이 자체 Docker 컨테이너로 시작하기

Backend.AI를 사용하지 않는 경우, 아래의 단계들로 자체 Docker 컨테이너에서 추론 서버를 시작할 수 있습니다. (Backend.AI 사용자는 Step 5로 건너뛰세요.)

컨테이너에서 RBLN NPU 장치를 사용하기 위해 privileged 모드로 추론 서버 컨테이너를 실행해야 합니다. 또한 이 전 단계에서 준비한 vllm_backend 디렉토리를 마운트해야합니다:

1
2
3
sudo docker run --privileged --shm-size=1g --ulimit memlock=-1 \
   -v /PATH/TO/YOUR/vllm_backend:/opt/tritonserver/vllm_backend \
   -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3

컨테이너 상에서 필요한 패키지들을 설치합니다:

$ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.5.2" "optimum-rbln>=0.1.4" vllm-rbln
컨테이너 상에서 Triton 서버를 시작합니다:
$ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository

아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:

1
2
3
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

Step 5. gRPC 클라이언트 추론 요청

triton 추론 서버의 vLLM 백엔드는 자체적으로 model.py를 정의하고 있기 때문에, 이전 섹션에서 정의한 model.py와는 입출력 시그니쳐가 다릅니다. 이전 섹션에서 정의한 model.py은 입력을 INPUT__0, 출력을 OUTPUT__0이라고 불렀지만 vLLM에서는 입력은 text_input, 출력은 text_output이라고 부릅니다. 따라서 클라이언트도 여기에 맞춰 수정해야 합니다. 자세한 내용은 vLLM model.py 페이지를 참고해주세요.

다음 코드는 vLLM 백엔드를 호출하기 위한 클라이언트 코드입니다. 이 코드 역시 이전 섹션의 클라이언트와 마찬가지로 실행하기 위해서 tritonclientgrpcio 패키지가 필요합니다.

simple_vllm_client.py
import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient

async def try_request():
  url = "<host and port number of the triton inference server>"  # e.g. "localhost:8001"
  client = grpcclient.InferenceServerClient(url=url, verbose=False)

  model_name = "vllm_model"

  def create_request(prompt, request_id):
    input = grpcclient.InferInput("text_input", [1], "BYTES")
    prompt_data = np.array([prompt.encode("utf-8")])
    input.set_data_from_numpy(prompt_data)

    stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
    stream_setting.set_data_from_numpy(np.array([True]))

    inputs = [input, stream_setting]

    output = grpcclient.InferRequestedOutput("text_output")
    outputs = [output]

    return {
      "model_name": model_name,
      "inputs": inputs,
      "outputs": outputs,
      "request_id": request_id
    }

  prompt = "What is the first letter of English alphabets?"

  async def requests_gen():
    yield create_request(prompt, "req-0")

  response_stream = client.stream_infer(requests_gen())

  async for response in response_stream:
    result, error = response
    if error:
      print("Error occurred!")
    else:
      output = result.as_numpy("text_output")
      for i in output:
          print(i.decode("utf-8"), end="", flush=True)

asyncio.run(try_request())

다른 샘플링 매개변수(temperature, top_p, top_k, max_tokens, early_stopping 등)를 변경해야 하는 경우 client.py를 참조해주세요.

OpenAI 호환 API 서버

vLLM은 OpenAI 호환 API 서버를 제공합니다. 이 서버는 OpenAI API의 completionschat를 지원합니다.

먼저, vllm-rbln이 설치되어 있는지 확인하세요. 그리고 다음과 같이 vllm.entrypoints.openai.api_server 모듈을 실행하면 API 서버가 시작됩니다.

1
2
3
4
5
6
7
8
$ python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-2-7b-chat-hf \
  --device rbln \
  --max-num-seqs 4 \
  --compiled-model-dir </ABSOLUTE/PATH/TO/rbln-Llama-2-7b-chat-hf> \
  --max-num-batched-tokens 4096 \
  --max-model-len 4096 \
  --block-size 4096

이 실행 명령에 주어지는 인자들의 내용은 이전 섹션의 model.json의 내용과 동일합니다.

인증 기능을 활성화하려면 --api-key <Random string to be used as API key> 플래그를 추가합니다.

API 서버가 실행되고 나면 OpenAI의 파이썬 및 node.js 클라이언트를 이용해 API를 호출하거나 다음과 같이 curl 명령을 이용해 API를 실행할 수 있습니다.

$ curl http://<host and port number of the server>/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer <API key, if specified when running the server>" \
  -d '{
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "messages": [
      {
        "role": "system",
        "content": "You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "Hello!"
      }
    ],
    "stream": true
  }'

API에 대한 더 자세한 사항은 OpenAI 문서를 참고하세요.