Triton 추론 서버를 사용하여 대규모 언어 모델(LLM) 서빙하기
이 페이지에서는 가장 유명한 대규모 언어 모델(LLM) 중 하나인 Llama2-7B
모델을 Nvidia Triton 추론 서버를 이용해 서빙하는 방법을 소개합니다.
Triton 추론 서버 이용해 서빙하기
Nvidia Triton 추론 서버 페이지의 1단계에서 설명된 것과 같이 triton-inference-server의 python_backend 리포지토리를 복제합니다.
| $ git clone https://github.com/triton-inference-server/python_backend -b r24.01
|
1단계. Llama2-7B 모델 준비하기
Llama2-7B에서 생성된 디렉토리를 python_backend/examples/rbln/llama-2-7b-chat-hf/1
로 위치시킵니다:
| $ mkdir -p python_backend/examples/rbln/llama-2-7b-chat-hf/1
$ cp -r rbln-Llama-2-7b-chat-hf python_backend/examples/rbln/llama-2-7b-chat-hf/1/
|
이 과정 이후에 디렉토리의 구조가 다음과 같이 되어야 합니다.
| +--python_backend/
| +-- examples/
| | +-- rbln/
| | | +-- llama-2-7b-chat-hf/
| | | | +-- 1/
| | | | | +-- rbln-Llama-2-7b-chat-hf/
| | | | | | +-- compiled_model.rbln
| | | | | | +-- config.json
| | | | | | +-- (and others)
| | +-- (and others)
| +-- (and others)
|
2단계. Llama2-7B TritonPythonModel 작성
이제 python_backend/examples/rbln/llama-2-7b-chat-hf/config.pbtxt
위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 파일은 모델의 입력과 출력 시그니쳐를 정의하고 모델의 몇가지 특성을 규정합니다.
config.pbtxt |
---|
| name: "llama-2-7b-chat-hf"
backend: "python"
input [ # (1)
{
name: "INPUT__0"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [ # (2)
{
name: "OUTPUT__0"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind: KIND_MODEL
}
]
max_batch_size: 1
model_transaction_policy {
decoupled: True # (3)
}
|
- 모델의 입력 시그니쳐를 기술하고 있습니다. 이 모델은
INPUT__0
이라는 이름의 입력 하나를 받고, 그 입력은 문자열임을 나타냅니다.
- 모델의 출력 시그니쳐를 기술하고 있습니다. 이 모델은
OUTPUT__0
이라는 이름의 출력을 하나 만들어내며, 그 출력은 문자열입니다.
- streaming 추론을 활성화하기 위해서는
model_transaction_policy.decoupled
를 True
로 설정해주어야 합니다.
다음으로 python_backend/examples/rbln/llama-2-7b-chat-hf/1/model.py
위치에 새로운 파일을 만들고 다음의 내용을 복사해 넣습니다. 이 스크립트는 RBLN SDK를 사용하여 static batching 방식으로 LLM 모델을 실행합니다. 또한, 이 스크립트는 decoupled model을 지원하기 위해 클라이언트와 통신에 gRPC를 지원합니다.
model.py |
---|
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import numpy as np
import triton_python_backend_utils as pb_utils
from optimum.rbln import BatchTextIteratorStreamer, RBLNLlamaForCausalLM
from transformers import AutoTokenizer
from threading import Thread
DEFAULT_PROMPT = "what is the first letter in alphabet?"
class TritonPythonModel:
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_instance_name: A string containing model instance name in form of <model_name>_<instance_group_id>_<instance_id>
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
self.model_config = model_config = json.loads(args["model_config"])
self.max_batch_size = model_config["max_batch_size"]
output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT__0")
self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
model_dir = os.path.join(
args["model_repository"],
args["model_version"],
"rbln-Llama-2-7b-chat-hf",
)
self.model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_dir,
export=False,
)
self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", pad_token="[PAD]", padding_side="left")
self.streamer = BatchTextIteratorStreamer(
tokenizer=self.tokenizer, batch_size=self.max_batch_size, skip_prompt=True, skip_special_tokens=True
)
def execute(self, requests):
"""`execute` MUST be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference request is made
for this model. Depending on the batching configuration (e.g. Dynamic
Batching) used, `requests` may contain multiple requests. Every
Python model, must create one pb_utils.InferenceResponse for every
pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
inputs = []
num_requests = len(requests)
batch_sentences = [DEFAULT_PROMPT] * self.max_batch_size
for i in range(num_requests):
sentence = pb_utils.get_input_tensor_by_name(requests[i], "INPUT__0").as_numpy()[0][0]
sentence = str(sentence.decode("utf-8")).strip()
batch_sentences[i] = sentence
print(sentence)
output0_dtype = self.output0_dtype
inputs = self.tokenizer(batch_sentences, return_tensors="pt", padding=True)
generation_kwargs = dict(
**inputs,
streamer=self.streamer,
do_sample=False,
max_length=self.model.max_seq_len,
)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in self.streamer:
for i in range(num_requests):
out_data = np.array([new_text[i].encode("utf-8")])
out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
response_sender = requests[i].get_response_sender()
response_sender.send(inference_response)
for i in range(num_requests):
response_sender = requests[i].get_response_sender()
out_data = np.array(["".encode("utf-8")])
out_tensor = pb_utils.Tensor("OUTPUT__0", out_data.astype(output0_dtype))
inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor])
response_sender.send(
inference_response,
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
)
return None
def finalize(self):
print("Cleaning up...")
|
이 과정까지 성공적으로 마친 뒤에 디렉토리 구조는 다음과 같이 되어야 합니다.
| +--python_backend/
| +-- examples/
| | +-- rbln/
| | | +-- llama-2-7b-chat-hf/
| | | | +-- config.pbtxt ============== (new file)
| | | | +-- 1/
| | | | | +-- model.py ============== (new file)
| | | | | +-- rbln-Llama-2-7b-chat-hf/
| | | | | | +-- compiled_model.rbln
| | | | | | +-- config.json
| | | | | | +-- (and others)
| | +-- (and others)
| +-- (and others)
|
3단계. 컨테이너에서 추론 서버 실행
Triton Inference Server의 3단계를 다시 반복합니다. 이때 추가로 컨테이너 안에 optimum-rbln
를 설치합니다.
| $ pip3 install -i https://pypi.rbln.ai/simple/ optimum-rbln
|
Step 4. gRPC 클라이언트 추론 요청
실행된 Triton 추론 서버에 gRPC를 이용해 요청을 보내기 위해 tritonclient
와 grpcio
패키지를 설치합니다.
| $ pip3 install tritonclient==2.41.1 grpcio
|
다음의 파이썬 스크립트는 gRPC를 이용해 Triton 추론 서버에 요청을 보내고 받는 방법을 보여줍니다.
simple_client.py |
---|
| import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient
async def try_request():
url = "<host and port number of the triton inference server>" # e.g. "localhost:8001"
client = grpcclient.InferenceServerClient(url=url, verbose=False)
model_name = "llama-2-7b-chat-hf"
def create_request(prompt, request_id):
prompt_data = np.array([prompt.encode("utf-8")])
input = grpcclient.InferInput("INPUT__0", [1, 1], "BYTES")
input.set_data_from_numpy(prompt_data.reshape(1, 1))
inputs = [input]
output = grpcclient.InferRequestedOutput("OUTPUT__0")
outputs = [output]
return {
"model_name": model_name,
"inputs": inputs,
"outputs": outputs,
"request_id": request_id
}
prompt = "What is the first letter of English alphabets?"
async def requests_gen():
yield create_request(prompt, "req-0")
response_stream = client.stream_infer(requests_gen())
async for response in response_stream:
result, error = response
if error:
print("Error occurred!")
else:
output = result.as_numpy("OUTPUT__0")
for i in output:
print(i.decode("utf-8"), end="", flush=True)
asyncio.run(try_request())
|
Continuous Batching
대형 언어 모델(LLMs) 서빙 시 NPU 성능을 최대로 활용을 위해서는 continuous batching 이라는 서빙 최적화 기법이 필요합니다.
이어지는 Continuous Batching 사용하여 LLM 서빙 문서에서는 continuous batching을 구현한 vLLM을 이용해 Llama2-7B
모델을 실행하는 방법을 소개합니다.