Continuous Batching을 사용하여 대규모 언어 모델(LLM) 서빙하기
대형 언어 모델(LLMs) 서빙 시 NPU 성능을 최대로 활용을 위해서는 continuous batching 이라는 서빙 최적화 기법이 필요합니다.
이 튜토리얼에서는 vllm-rbln
을 사용하여 continuous batching을 구현하고 LLM 서빙을 최적화하는 방법을 안내합니다.
vllm-rbln
은 vLLM 라이브러리의 확장으로, vLLM
이 optimum-rbln
과 함께 작동할 수 있도록 수정된 버전입니다.
필요 패키지 설치
rebel-compiler
, optimum-rbln
, vllm-rbln
패키지의 최신 버전이 설치되어 있어야 합니다. 각 패키지의 최신 버전은 릴리즈 노트에서 확인 할 수 있습니다.
| $ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.5.12" "optimum-rbln>=0.1.12" "vllm-rbln>=0.1.0"
|
Llama2-7B 모델 컴파일하기
먼저, optimum-rbln을 사용하여 Llama2-7B 모델을 컴파일합니다.
| from optimum.rbln import RBLNLlamaForCausalLM
# HuggingFace PyTorch Llama2 모델을 RBLN 컴파일된 모델로 내보내기
model_id = "meta-llama/Llama-2-7b-chat-hf"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True,
rbln_max_seq_len=4096,
rbln_tensor_parallel_size=4, # Rebellions Scalable Design (RSD)를 위한 ATOM+ 개수
rbln_batch_size=4, # Continous batching을 위해 batch_size > 1 권장
)
compiled_model.save_pretrained("rbln-Llama-2-7b-chat-hf")
|
적절한 배치 크기를 선택해야 합니다. 여기에서는 4로 설정합니다.
본 튜토리얼에서는 위에서 컴파일 된 모델을 서빙하는 세 가지 방법을 소개합니다.
첫 번째 방법은 vLLM API를 사용하는 것이고, 두 번째는 Triton Inference Server를 사용하는 방법, 마지막 세 번째는 vLLM이 제공하는 OpenAI API 호환 서버를 사용하는 방법입니다.
vLLM API 사용
vLLM의 API를 사용해 컴파일된 모델을 실행할 수 있습니다.
다음 코드는 앞서 컴파일된 모델을 이용해 vLLM 엔진을 초기화하고 초기화된 엔진으로 인퍼런스를 수행하는 방법을 보여줍니다.
vllm_api_example.py |
---|
| import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
# Please make sure the engine configurations match the parameters used when compiling.
model_id = "meta-llama/Llama-2-7b-chat-hf"
max_seq_len = 4096
batch_size = 4
engine_args = AsyncEngineArgs(
model=model_id,
device="rbln",
max_num_seqs=batch_size,
max_num_batched_tokens=max_seq_len,
max_model_len=max_seq_len,
block_size=max_seq_len,
compiled_model_dir="rbln-Llama-2-7b-chat-hf",
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def stop_tokens():
eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
if eot_id is not None:
return [tokenizer.eos_token_id, eot_id]
else:
return [tokenizer.eos_token_id]
sampling_params = SamplingParams(
temperature=0.0,
skip_special_tokens=True,
stop_token_ids=stop_tokens(),
)
# Runs a single inference for an example
async def run_single(chat, request_id):
results_generator = engine.generate(chat, sampling_params, request_id=request_id)
final_result = None
async for result in results_generator:
# You can use the intermediate `result` here, if needed.
final_result = result
return final_result
conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
print(result)
async def run_multi(chats):
tasks = [asyncio.create_task(run_single(chat, i)) for (i, chat) in enumerate(chats)]
return [await task for task in tasks]
# Runs multiple inferences in parallel
conversations = [
[{"role": "user", "content": "What is the first letter of English alphabets?"}],
[{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
for conversation in conversations
]
results = asyncio.run(run_multi(chats))
print(results)
|
vLLM API에 대한 더 자세한 내용은 vLLM 문서를 참고해주세요.
Triton Inference Server 사용
vLLM을 사용하는 두번째 방법은 Triton Inference Server의 vLLM 백엔드를 사용하는 것입니다.
Backend.AI를 사용하는 경우에는 1단계로, 그렇지 않은 경우에는 1단계를 건너뛰고 2단계부터 시작하십시오.
1 단계. Backend.AI 환경 설정
Backend.AI를 통해 세션을 시작합니다.
실행 환경을 Triton Server (ngc-triton
)를 선택 후 24.01 / vllm / x86_64 / python-py3
버전을 사용합니다.
2 단계. Nvidia Triton vllm_backend
준비 및 Llama2-7B 모델 구성 수정
A. Nvidia Triton Inference Server의 vllm_backend
저장소를 클론합니다:
| $ git clone https://github.com/triton-inference-server/vllm_backend.git -b r24.01
|
B. 미리 컴파일된 모델이 들어 있는 rbln-Llama-2-7b-chat-hf
디렉토리를 vllm_backend/samples/model_repository/vllm_model/1
로 옮깁니다:
| $ cp -R /PATH/TO/YOUR/rbln-Llama-2-7b /PATH/TO/YOUR/CLONED/vllm_backend/samples/model_repository/vllm_model/1
|
이 시점에 vllm_backend 폴더의 구조는 다음과 같아야 합니다:
| +-- vllm_backend/
| +-- samples/
| | +-- model_repository/
| | | +-- vllm_model/
| | | | +-- config.pbtxt
| | | | +-- 1/
| | | | | +-- model.json
| | | | | +-- rbln-Llama-2-7b-chat-hf/
| | | | | | +-- compiled_model.rbln
| | | | | | +-- config.json
| | | | | | +-- (and others)
| | +-- (and others)
| +-- (and others)
|
C. 서빙되는 모델의 model.json
수정
vllm_backend/samples/model_repository/vllm_model/1/model.json
파일을 다음과 같이 수정합니다:
| {
"model": "meta-llama/Llama-2-7b-chat-hf",
"device": "rbln",
"max_num_seqs": 4,
"compiled_model_dir": "/ABSOLUTE/PATH/TO/rbln-Llama-2-7b-chat-hf",
"max_num_batched_tokens": 4096,
"max_model_len": 4096,
"block_size": 4096
}
|
- model
: HuggingFace transformers 모델의 이름 또는 경로. HuggingFace에 등록된 모델이 아니면 모델의 절대 경로로 설정해야 합니다 참고.
- device
: vLLM 실행을 위한 디바이스. rbln
으로 설정해주세요.
- max_num_seqs
: 최대 시퀀스 수. 이는 컴파일된 batch_size
와 반드시 일치해야 합니다.
- compiled_model_dir
: RBLN(optimum-rbln)으로 컴파일된 모델 디렉토리의 절대 경로
- RBLN 장치를 대상으로 할 때 max_model_len
, block_size
, max_num_batched_tokens
인수는 최대 시퀀스 길이와 동일해야 합니다.
3 단계. 추론 서버 실행
이제 추론 서버를 실행할 준비가 되었습니다.
Backend.AI를 사용하고 있다면 A. Backend.AI 섹션을 참조해주세요. Backend.AI 사용자가 아니라면 B. Backend.AI
없이 자체 Docker 컨테이너로 시작하기 단계로 건너뛰세요.
A. Backend.AI
필요한 패키지를 설치합니다:
| $ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.5.2" "optimum-rbln>=0.1.4" vllm-rbln
|
Triton 서버를 시작합니다:
| $ tritonserver --model-repository PATH/TO/YOUR/vllm_backend/samples/model_repository
|
아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:
| Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
|
B. Backend.AI
없이 자체 Docker 컨테이너로 시작하기
Backend.AI
를 사용하지 않는 경우, 아래의 단계들로 자체 Docker 컨테이너에서 추론 서버를 시작할 수 있습니다. (Backend.AI
사용자는 Step 5로 건너뛰세요.)
컨테이너에서 RBLN NPU 장치를 사용하기 위해 privileged 모드로 추론 서버 컨테이너를 실행해야 합니다. 또한 이 전 단계에서 준비한 vllm_backend
디렉토리를 마운트해야합니다:
| sudo docker run --privileged --shm-size=1g --ulimit memlock=-1 \
-v /PATH/TO/YOUR/vllm_backend:/opt/tritonserver/vllm_backend \
-p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3
|
컨테이너 상에서 필요한 패키지들을 설치합니다:
| $ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.5.2" "optimum-rbln>=0.1.4" vllm-rbln
|
컨테이너 상에서 Triton 서버를 시작합니다:
| $ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository
|
아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:
| Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
|
4 단계. gRPC 클라이언트 추론 요청
triton 추론 서버의 vLLM 백엔드는 자체적으로 model.py
를 정의하고 있기 때문에, 이전 섹션에서 정의한 model.py
와는 입출력 시그니쳐가 다릅니다. 이전 섹션에서 정의한 model.py
은 입력을 INPUT__0
, 출력을 OUTPUT__0
이라고 불렀지만 vLLM에서는 입력은 text_input
, 출력은 text_output
이라고 부릅니다. 따라서 클라이언트도 여기에 맞춰 수정해야 합니다. 자세한 내용은 vLLM model.py 페이지를 참고해주세요.
다음 코드는 vLLM 백엔드를 호출하기 위한 클라이언트 코드입니다. 이 코드 역시 이전 섹션의 클라이언트와 마찬가지로 실행하기 위해서 tritonclient
와 grpcio
패키지가 필요합니다.
simple_vllm_client.py |
---|
| import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient
async def try_request():
url = "<host and port number of the triton inference server>" # e.g. "localhost:8001"
client = grpcclient.InferenceServerClient(url=url, verbose=False)
model_name = "vllm_model"
def create_request(prompt, request_id):
input = grpcclient.InferInput("text_input", [1], "BYTES")
prompt_data = np.array([prompt.encode("utf-8")])
input.set_data_from_numpy(prompt_data)
stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
stream_setting.set_data_from_numpy(np.array([True]))
inputs = [input, stream_setting]
output = grpcclient.InferRequestedOutput("text_output")
outputs = [output]
return {
"model_name": model_name,
"inputs": inputs,
"outputs": outputs,
"request_id": request_id
}
prompt = "What is the first letter of English alphabets?"
async def requests_gen():
yield create_request(prompt, "req-0")
response_stream = client.stream_infer(requests_gen())
async for response in response_stream:
result, error = response
if error:
print("Error occurred!")
else:
output = result.as_numpy("text_output")
for i in output:
decoded = i.decode("utf-8")
if decoded.startswith(prompt):
decoded = decoded[len(prompt):]
print(decoded, end="", flush=True)
asyncio.run(try_request())
|
다른 샘플링 매개변수(temperature
, top_p
, top_k
, max_tokens
, early_stopping
등)를 변경해야 하는 경우 client.py를 참조해주세요.
OpenAI 호환 API 서버
vLLM은 OpenAI 호환 API 서버를 제공합니다.
이 서버는 OpenAI API의 completions와 chat를 지원합니다.
먼저, vllm-rbln
이 설치되어 있는지 확인하세요. 그리고 다음과 같이 vllm.entrypoints.openai.api_server
모듈을 실행하면 API 서버가 시작됩니다.
| $ python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-2-7b-chat-hf \
--device rbln \
--max-num-seqs 4 \
--compiled-model-dir </ABSOLUTE/PATH/TO/rbln-Llama-2-7b-chat-hf> \
--max-num-batched-tokens 4096 \
--max-model-len 4096 \
--block-size 4096
|
이 실행 명령에 주어지는 인자들의 내용은 이전 섹션의 model.json
의 내용과 동일합니다.
인증 기능을 활성화하려면 --api-key <API key로 사용될 문자열>
플래그를 추가합니다.
API 서버가 실행되고 나면 OpenAI의 파이썬 및 node.js 클라이언트를 이용해 API를 호출하거나 다음과 같이 curl 명령을 이용해 API를 실행할 수 있습니다.
| $ curl http://<host and port number of the server>/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <API key, if specified when running the server>" \
-d '{
"model": "meta-llama/Llama-2-7b-chat-hf",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello!"
}
],
"stream": true
}'
|
API에 대한 더 자세한 사항은 OpenAI 문서를 참고하세요.