Flash Attention을 이용한 Llama3.1-8B 서빙
Flash Attention은 메모리 사용량을 줄이고 처리 속도를 높여, Llama-3.1-8B와 같은 모델이 긴 문맥을 효율적으로 처리할 수 있도록 도와줍니다. 이 튜토리얼에서는 vllm-rbln
과 Nvidia Triton Inference Server를 사용하여 Flash Attention을 적용한 Llama3.1-8B
모델을 서빙하는 방법을 안내합니다.
이 페이지에서 소개된 모델 컴파일 및 triton vllm_backend 설정에 필요한 커맨드들을 확인하려면 모델 주를 참고해주세요.
참고
이 튜토리얼은 사용자가 RBLN SDK 기반의 모델 컴파일 및 추론에 대해 잘 이해하고 있다는 가정하에 작성되었습니다. RBLN SDK 사용법에 익숙하지 않을 경우 튜토리얼 페이지를 참고 바랍니다.
사전준비
시작하기에 앞서 시스템에 아래 항목들이 준비되어 있는지 확인합니다:
Note
vllm-rbln
패키지는 vllm
패키지와 의존성이 없기 때문에 vllm
패키지를 중복 설치할 경우 vllm-rbln
이 정상적으로 동작하지 않을 수 있습니다. 만약 vllm-rbln
패키지 설치 후 vllm
패키지를 설치했을 경우, vllm-rbln
패키지를 재설치 해주시기 바랍니다.
Llama3.1-8B 컴파일
optimum-rbln를 이용하여 Llama3.1-8B
를 컴파일 합니다.
get_model.py |
---|
| from optimum.rbln import RBLNLlamaForCausalLM
import os
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True, # Export a PyTorch model to RBLN model with Optimum
rbln_batch_size=1, # Batch size
rbln_max_seq_len=131_072, # Maximum sequence length
rbln_tensor_parallel_size=8, # Tensor parallelism
rbln_kvcache_partition_len=16_384, # Length of KV cache partitions for flash attention
)
# Save compiled results to disk
model.save_pretrained(os.path.basename(model_id))
|
Note
적절한 배치 크기를 선택해야 합니다. 여기에서는 1로 설정합니다.
Triton Inference Server 사용
vLLM을 Triton 추론 서버에서 사용하기 위해 vLLM backend를 사용합니다.
Backend.AI를 사용하는 경우에는 1단계로, 그렇지 않은 경우에는 1단계를 건너뛰고 2단계부터 시작하십시오.
1 단계. Backend.AI 환경 설정
Backend.AI를 통해 세션을 시작합니다.
실행 환경을 Triton Server (ngc-triton
)를 선택 후 24.12 / vllm / x86_64 / python-py3
버전을 사용합니다.
2 단계. Nvidia Triton vllm_backend
준비 및 Llama3.1-8B 모델 구성 수정
A. Nvidia Triton Inference Server의 vllm_backend
저장소를 클론합니다:
| $ git clone https://github.com/triton-inference-server/vllm_backend.git -b r24.12
|
B. 미리 컴파일된 모델이 들어 있는 Llama-3.1-8B-Instruct
디렉토리를 Triton 모델 저장소의 해당 버전 디렉토리 vllm_backend/samples/model_repository/vllm_model/1/
로 옮깁니다:
| $ cp -R /PATH/TO/YOUR/Llama-3.1-8B-Instruct \
/PATH/TO/YOUR/CLONED/vllm_backend/samples/model_repository/vllm_model/1/
|
이 시점에 vllm_backend 폴더의 구조는 다음과 같아야 합니다:
| +-- vllm_backend/ # Triton vLLM 백엔드 메인 디렉토리
| +-- samples/ # 애플리케이션 예제 디렉토리
| | +-- model_repository/ # 모델 저장소 (model_repositories)
| | | +-- vllm_model/ # Triton 으로 서빙할 개별 모델
| | | | +-- config.pbtxt # Triton 모델 설정 파일
| | | | +-- 1/ # 버전 디렉토리
| | | | | +-- model.json # vLLM 서빙위한 모델 설정
| | | | | +-- Llama-3.1-8B-Instruct/ # 컴파일된 모델 파일들(rbln files)
| | | | | | +-- decoder.rbln
| | | | | | +-- prefill.rbln
| | | | | | +-- config.json
| | | | | | +-- (기타 모델 파일들)
| | +-- (기타 예제 파일들)
| +-- (기타 백엔드 파일들)
|
Note
- 기존의 비전 모델에서 사용했던 Triton 서빙과는 달리 Nvidia Triton Server의 vLLM 백엔드는 별도의 model.py 파일이 필요하지 않습니다. vLLM 백엔드가 이미 내부적으로 필요한 모델 처리 로직(도커 컨터이너 내
backends/vllm/model.py
)을 포함하고 있기 때문에, model.json 파일만으로 모델 설정이 가능합니다.
config.pbtxt
의 경우에도 저장소의 config.pbtxt
를 그대로 사용해도 무방합니다. config.pbtxt
가 없는 경우, 다음을 참고해서 만들어주시면 됩니다. 위에서 언급했던 vLLM 백엔드 처리 로직에 따라 처리되므로, input, output 입력은 고정해두어야 합니다(4단계. gRPC 클라이언트 추론 요청 참고).
| name: "vllm_model"
backend: "vllm"
input [
{
name: "text_input"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "stream"
data_type: TYPE_BOOL
dims: [ 1 ]
}
]
output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind: KIND_MODEL
}
]
|
C. 서빙되는 모델의 model.json
수정
vllm_backend/samples/model_repository/vllm_model/1/model.json
파일을 다음과 같이 수정합니다:
| {
"model": "/ABSOLUTE/PATH/TO/Llama-3.1-8B-Instruct",
"max_num_seqs": 1,
"max_num_batched_tokens": 131072,
"max_model_len": 131072,
"block_size": 16384,
"device": "rbln"
}
|
model
: 컴파일된 모델의 절대 경로를 설정합니다.
max_num_seqs
: 최대 시퀀스 수. 이는 컴파일된 batch_size
와 반드시 일치해야 합니다.
block_size
: Paged Attention을 위한 블록 크기입니다. Flash Attention을 사용할 때는 블록 크기가 rbln_kvcache_partition_len
과 동일해야 합니다.
device
: vLLM 실행을 위한 디바이스. rbln
으로 설정합니다.
- RBLN NPU를 대상으로 할 때
max_num_batched_tokens
는 max_model_len
과 동일해야합니다
3 단계. 추론 서버 실행
이제 추론 서버를 실행할 준비가 되었습니다. Backend.AI를 사용하고 있다면 A. Backend.AI 섹션을 참고 바랍니다. Backend.AI 사용자가 아니라면 B. Backend.AI
없이 자체 Docker 컨테이너로 시작하기 단계로 건너뛰시기 바랍니다.
A. Backend.AI
필요한 패키지를 설치합니다:
| $ pip3 install --extra-index https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.3" "optimum-rbln>=0.7.3.post2" "vllm-rbln>=0.7.3"
|
Triton 서버를 시작합니다:
| $ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository
|
아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:
| Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
|
B. Backend.AI
없이 자체 Docker 컨테이너로 시작하기
Backend.AI
를 사용하지 않는 경우, 아래의 단계들로 자체 Docker 컨테이너에서 추론 서버를 시작할 수 있습니다. (Backend.AI
사용자는 Step 4로 건너뛰시기 바랍니다.)
컨테이너에서 RBLN NPU 장치를 사용하기 위해 반드시 서버 컨테이너를 privileged 모드로 실행하거나, 필요한 디바이스를 마운트해서 실행해야 합니다. 자세한 내용은 Docker 지원를 참고 바랍니다. 이 튜토리얼에서는 --privileged 모드로 실행합니다. 그리고 이 전 단계에서 준비한 vllm_backend
디렉토리를 마운트해야합니다:
| $ docker run --privileged --shm-size=1g --ulimit memlock=-1 \
-v /PATH/TO/YOUR/vllm_backend:/opt/tritonserver/vllm_backend \
-p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti rebellions/tritonserver:24.12-vllm-python-py3
|
컨테이너 상에서 필요한 패키지들을 설치합니다:
| $ pip3 install --extra-index https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.3" "optimum-rbln>=0.7.3.post2" "vllm-rbln>=0.7.3"
|
컨테이너 상에서 Triton 서버를 시작합니다:
| $ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository
|
아래와 같이 서버가 성공적으로 시작되었음을 의미하는 메시지를 확인할 수 있습니다:
| Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
|
4 단계. gRPC 클라이언트 추론 요청
Triton 추론 서버의 vLLM 백엔드는 자체적으로 model.py
를 정의하고 있기 때문에, Resnet50 튜토리얼에서 정의한 model.py
와는 입출력 시그니쳐가 다릅니다. Resnet50 튜토리얼에서 정의한 model.py
은 입력을 INPUT__0
, 출력을 OUTPUT__0
이라고 불렀지만 vLLM에서는 입력은 text_input
, 출력은 text_output
이라고 부릅니다. 따라서 클라이언트도 여기에 맞춰 수정해야 합니다. 자세한 내용은 vLLM model.py 페이지를 참고 바랍니다.
다음 코드는 vLLM 백엔드를 호출하기 위한 클라이언트 코드입니다. 이 코드 역시 Resnet50 튜토리얼의 클라이언트와 마찬가지로 실행하기 위해서 tritonclient
와 grpcio
패키지가 필요합니다.
Note
grpc 클라이언트를 적절하게 사용하기 위해서는 채팅 템플릿을 적용해야 합니다. 대화는 system
, user
, assistant
역할로 포맷되어야 하며, 또한 샘플링 파라미터sampling_params
에 적절한 값이 다음과 같이 포함되어있어야 합니다.
| sampling_params = {
"temperature": 0.0,
"stop": ["[User]", "[System]", "[Assistant]"], # stop tokens
}
|
아래의 예제 코드를 참고하시기 바랍니다.
simple_vllm_client.py |
---|
| import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient
import json
# Define a simple chat message class
class ChatMessage:
def __init__(self, role, content):
self.role = role
self.content = content
# Apply a simple chat template to the messages
def apply_chat_template(messages):
lines = []
system_msg = ChatMessage(role="system", content="You are a helpful assistant.")
for msg in [system_msg, *messages, ChatMessage(role="assistant", content="")]:
lines.append(f"[{msg.role.capitalize()}]\n{msg.content}")
return "\n".join(lines)
async def try_request():
url = "<host and port number of the triton inference server>" # e.g. "localhost:8001"
client = grpcclient.InferenceServerClient(url=url, verbose=False)
model_name = "vllm_model"
def create_request(messages, request_id):
prompt = apply_chat_template(messages)
print(f"prompt:\n{prompt}\n---") # print prompt
input = grpcclient.InferInput("text_input", [1], "BYTES")
prompt_data = np.array([prompt.encode("utf-8")])
input.set_data_from_numpy(prompt_data)
stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
stream_setting.set_data_from_numpy(np.array([True]))
sampling_params = {
"temperature": 0.0,
"stop": ["[User]", "[System]", "[Assistant]"], # add stop tokens
}
sampling_parameters = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
sampling_parameters.set_data_from_numpy(
np.array([json.dumps(sampling_params).encode("utf-8")], dtype=object)
)
inputs = [input, stream_setting, sampling_parameters]
output = grpcclient.InferRequestedOutput("text_output")
outputs = [output]
return {
"model_name": model_name,
"inputs": inputs,
"outputs": outputs,
"request_id": request_id,
}
messages = [
ChatMessage(
role="user", content="What is the first letter of English alphabets?"
)
]
request_id = "req-0"
async def requests_gen():
yield create_request(messages, request_id)
response_stream = client.stream_infer(requests_gen())
prompt = apply_chat_template(messages)
is_first_response = True
async for response in response_stream:
result, error = response
if error:
print("Error occurred!")
else:
output = result.as_numpy("text_output")
for i in output:
decoded = i.decode("utf-8")
if is_first_response:
if decoded.startswith(prompt):
decoded = decoded[len(prompt) :]
is_first_response = False
print(decoded, end="", flush=True)
print("\n") # end of stream
asyncio.run(try_request())
|
정상적으로 동작할 경우 아래와 유사하게 응답이 출력됩니다.
| The first letter of the English alphabet is 'A'. Would you like to know more about the English alphabet? I can help you with that!
|
다른 샘플링 매개변수(temperature
, top_p
, top_k
, max_tokens
, early_stopping
등)를 변경해야 하는 경우 client.py를 참고 바랍니다.