Llama3-8B
개요
이 페이지에서는 Ray Serve에서 RBLN SDK를 활용하여 Llama3-8B 모델을 서빙하는 방법을 소개합니다.
전체 튜토리얼의 주요 흐름은 다음과 같습니다:
- 환경 및 모델 컴파일 확인
- RBLN NPU를 활용하는 Ray Serve deployment 정의
- Serve CLI를 이용한 모델 서빙 애플리케이션 실행
- 엔드포인트 검증을 위한 추론 요청 전송
Ray Serve 환경 구성 방법에 대해서는 Ray Serve 개요를 먼저 참고 바랍니다. 모델 컴파일 및 배포에 대한 전체 스크립트 기반 예제는 모델 주를 참고 바랍니다.
환경 설정 및 설치 확인
시작하기 전에 시스템 환경이 올바르게 구성되어 있으며, 필요한 모든 필수 패키지가 설치되어 있는지 확인하십시오. 다음 항목이 포함됩니다:
- 시스템 요구 사항:
- Ubuntu 20.04 LTS (Debian bullseye) or higher
- System with RBLN NPUs equipped (e.g., RBLN ATOM™)
- 필수 패키지:
- 설치 명령어:
| pip install -U ray[serve] transformers requests torch --extra-index-url https://download.pytorch.org/whl/cpu
|
Note
이 튜토리얼은 사용자가 RBLN SDK 기반의 모델 컴파일 및 추론에 대해 잘 이해하고 있다는 가정하에 작성되었습니다. RBLN SDK 사용법에 익숙하지 않을 경우 RBLN Optimum, vLLM 튜토리얼을 참고 바랍니다.
사전준비
Llama3-8B 모델 컴파일
optimum-rbln를 이용하여 Llama3-8B를 컴파일합니다. 해당 코드는 리벨리온 모델 주를 사용하였습니다.
Llama3-8B 모델을 컴파일하기 위해 사용된 파라미터는 다음과 같습니다.
export: 모델을 컴파일하려면 True로 설정해야 합니다.
rbln_batch_size: 컴파일을 위한 배치 크기를 정의합니다.
rbln_max_seq_len: 최대 시퀀스 길이를 지정합니다.
rbln_tensor_parallel_size: 추론에 사용할 NPU의 수를 설정합니다.
Note
적절한 배치 크기를 선택해야 합니다. 여기에서는 4로 설정합니다.
| get_model.py |
|---|
| import os
from optimum.rbln import RBLNLlamaForCausalLM
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True,
rbln_batch_size=4,
rbln_max_seq_len=8192,
rbln_tensor_parallel_size=4,
)
# Save compiled results to disk
model.save_pretrained(os.path.basename(model_id))
|
Deployment
Deployment 개요
| Step |
Description |
| 1. Deployment 구현 |
Ray를 RBLN NPU와 함께 사용하도록 설정하고, 컴파일된 모델을 로드하여 런타임을 초기화하고 엔드포인트를 노출하는 Ray Serve deployment를 정의합니다. |
| 2. 실행 |
Ray Serve CLI(serve run)를 사용하여 deployment를 실행합니다. 애플리케이션 이름, 디바이스 세트 또는 원격 Ray 클러스터를 옵션으로 구성할 수 있습니다. |
| 3. 추론 요청 |
Serve 엔드포인트로 HTTP 요청을 보내고 응답을 검사하여 deployment를 검증합니다. |
아래 섹션은 위의 단계를 순서대로 설명합니다.
1.1 리소스 할당
Ray Serve에서 리소스를 할당하는 방식은 각 서빙 작업(Actor, Deployment)별로 resources 파라미터를 통해 NPU 자원을 할당할 수 있습니다.
아래 Actor는 @ray.remote(resources={"RBLN": 4})로 RBLN NPU 리소스를 요청하는 방법을 보여주며, 배포에 필요한 NPU 수만큼 값을 조절할 수 있습니다. RBLNActor 클래스는 할당된 RBLN NPU ID를 Serve Deployment로 전달하는데 사용합니다. 이에 대한 자세한 내용은 Ray에서 RBLN NPU 사용을 참고 바랍니다.
| @ray.remote(resources={"RBLN": 4})
class RBLNActor:
def getDeviceId(self):
return ray.get_runtime_context().get_accelerator_ids()["RBLN"]
|
1.2 Deployment 구현
Ray Serve deployment는 @serve.deployment 데코레이터를 활용하여 클래스를 하나의 Deployment(배포 단위)로 정의합니다. 이 데코레이터를 적용하면 해당 클래스가 Ray Serve의 서비스 엔드포인트로 등록되어, 각 Deployment의 라이프사이클(배포, 확장, 업데이트 등)과 관리를 Ray Serve가 담당하게 됩니다.
| llama3-8b.py |
|---|
| import json
import os
from unittest.mock import MagicMock
import ray
from fastapi import FastAPI, HTTPException
from ray import serve
from starlette.requests import Request
from starlette.responses import StreamingResponse
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
app = FastAPI()
ray.init(resources={"RBLN": 4})
@ray.remote(resources={"RBLN": 4})
class RBLNActor:
def getDeviceId(self):
return ray.get_runtime_context().get_accelerator_ids()["RBLN"]
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 16})
@serve.ingress(app)
class Llama3_8B:
def __init__(self, rbln_actor: RBLNActor):
"""
Initialize actor.
:return:
"""
self.engine = None
self.rbln_actor = rbln_actor
self.model_name = "Meta-Llama-3-8B-Instruct"
self.raw_request = None
self.vllm_engine = None
self.openai_serving_models = None
self.completion_service = None
self.chat_completion_service = None
self.ids = ray.get(rbln_actor.getDeviceId.remote())
self.os_environment_vars()
self.initialize()
def os_environment_vars(self):
"""
Redefine the environment variables to be passed to the RBLN runtime and vLLM
:return:
"""
if self.ids is None or len(self.ids) <= 0:
os.environ.pop("RBLN_DEVICES")
os.environ["RBLN_DEVICES"] = ",".join(self.ids)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def initialize(self):
"""
Initialize vLLM engine and prepare request handlers
:return:
"""
vllm_engine_args = AsyncEngineArgs(
model=self.model_name
)
self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_args)
self.openai_serving_models = OpenAIServingModels(
engine_client=self.vllm_engine,
model_config=self.vllm_engine.vllm_config.model_config,
base_model_paths=[
BaseModelPath(name=self.model_name, model_path=os.getcwd())
],
)
self.completion_service = OpenAIServingCompletion(
self.vllm_engine,
self.vllm_engine.vllm_config.model_config,
self.openai_serving_models,
request_logger=None,
)
self.chat_completion_service = OpenAIServingChat(
self.vllm_engine,
self.vllm_engine.vllm_config.model_config,
self.openai_serving_models,
"assistant",
request_logger=None,
chat_template_content_format="auto",
chat_template=None,
)
async def isd():
return False
self.raw_request = MagicMock()
self.raw_request.headers = {}
self.raw_request.is_disconnected = isd
@app.post("/v1/chat/completions")
async def chat_completion(self, http_request: Request):
"""
Handle chat completion request.
:param http_request: The HTTP request object
:return: The chat completion response
"""
try:
json_string: dict = await http_request.json()
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON format request")
request: ChatCompletionRequest = ChatCompletionRequest.model_validate(
json_string
)
g = await self.chat_completion_service.create_chat_completion(
request, self.raw_request
)
if isinstance(g, ErrorResponse):
return [g.model_dump()]
if request.stream:
async def stream_generator():
async for response in g:
yield response
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
return [g.model_dump()]
@app.post("/v1/completions")
async def completion(self, http_request: Request):
"""
Handle completion request.
:param http_request: The HTTP request object
:return: The completion response
"""
try:
json_string: dict = await http_request.json()
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON format request")
request: CompletionRequest = CompletionRequest.model_validate(json_string)
g = await self.completion_service.create_completion(request, self.raw_request)
if isinstance(g, ErrorResponse):
return [g.model_dump()]
if request.stream:
async def stream_generator():
async for response in g:
yield response
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
return [g.model_dump()]
rbln_actor = RBLNActor.remote()
app = Llama3_8B.bind(rbln_actor)
|
2. 실행
Ray Serve CLI(serve run)를 사용하여 애플리케이션을 실행합니다. 파라미터는 module:application 형식으로 지정되며, 여기서 module은 Python 파일명(확장자 .py 제외), application은 Serve 엔트리포인트 객체명입니다.
이 예시에서는 llama3-8b.py에서 app 객체를 정의하므로 아래와 같은 명령어로 deployment를 시작할 수 있습니다. 원격 Ray 클러스터로 연결하거나 RBLN_DEVICES를 이용하여 특정 카드를 지정하는 등 추가 옵션을 함께 사용할 수 있습니다.
| $ serve run llama3-8b:app --name "llama3-8b"
|
예시 출력:
| Application 'llama3-8b' is ready at http://127.0.0.1:8000/.
|
3.1 추론 요청 예시(Completion API)
| echo '{
"model": "Meta-Llama-3-8B-Instruct",
"prompt": "A robot may not injure a human being",
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- http://localhost:8000/v1/completions | jq .
|
예시 출력:
OUTPUT
| [
{
"id": "cmpl-6e39b262d7064dc882c3c3655f396906",
"object": "text_completion",
"created": 1762938935,
"model": "Meta-Llama-3-8B-Instruct",
"choices": [
{
"index": 0,
"text": " or, through inaction, allow a human being to come to harm.\nA",
"logprobs": null,
"finish_reason": "length",
"stop_reason": null,
"token_ids": null,
"prompt_logprobs": null,
"prompt_token_ids": null
}
],
"service_tier": null,
"system_fingerprint": null,
"usage": {
"prompt_tokens": 10,
"total_tokens": 26,
"completion_tokens": 16,
"prompt_tokens_details": null
},
"kv_transfer_params": null
}
]
|
3.2 추론 요청 예시(Chat Completions API)
| echo '{
"model": "Meta-Llama-3-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
],
"temperature": 0.0,
"max_tokens": 50,
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- http://localhost:8000/v1/chat/completions | jq .
|
예시 출력:
OUTPUT
| [
{
"id": "chatcmpl-3bbd9db1dc1243549ab942a74dabd2d8",
"object": "chat.completion",
"created": 1762941547,
"model": "Meta-Llama-3-8B-Instruct",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas, which is the home stadium of the Texas Rangers. However, the series was played with no fans in attendance due to the COVID-19 pandemic.",
"refusal": null,
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": [],
"reasoning_content": null
},
"logprobs": null,
"finish_reason": "stop",
"stop_reason": null,
"token_ids": null
}
],
"service_tier": null,
"system_fingerprint": null,
"usage": {
"prompt_tokens": 59,
"total_tokens": 106,
"completion_tokens": 47,
"prompt_tokens_details": null
},
"prompt_logprobs": null,
"prompt_token_ids": null,
"kv_transfer_params": null
}
]
|
참고