콘텐츠로 이동

Llama3-8B 서빙

TorchServe에서 Custom handler 구조를 이용하여 vLLM 엔진을 사용할 수 있는 vLLM Handler를 제공합니다. 이 vLLM Handler를 이용하여 vllm-rbln과 함께 continuous batching을 이용하여 LLM 모델을 서빙할 수 있습니다. 이 튜토리얼에서는 TorchServe의 vLLM Handler와 vllm-rbln을 사용하여, Llama3-8B 모델을 서빙을 하는 방법을 안내합니다.

TorchServe 환경 구성 방법에 대해서는 TorchServe를 참고 바라며, 이 페이지에서 소개된 모델 컴파일 및 TorchServe 환경 구성을 위한 YAML 파일을 확인하려면 모델주를 참고 바랍니다.

참고

이 튜토리얼은 사용자가 RBLN SDK 기반의 모델 컴파일 및 추론에 대해 잘 이해하고 있다는 가정하에 작성되었습니다. RBLN SDK 사용법에 익숙하지 않을 경우 파이토치/텐서플로우 튜토리얼 및 파이썬 API 페이지를 참고 바랍니다.

사전준비

시작하기에 앞서 TorchServe, vllm-rbln, optimum-rbln이 설치된 환경 및 컴파일된 Llama3-8B 모델이 필요합니다.

Note

Llama3-8B 모델을 사용하기 위해 4개의 RBLN NPU가 필요합니다. 모델별 구동에 필요한 RBLN NPU 갯수는 Optimum RBLN Multi-NPUs Supported Models에서 확인할 수 있습니다.

Note

vllm-rbln 패키지는 vllm 패키지와 의존성이 없기 때문에, vllm 패키지를 중복 설치할 경우 vllm-rbln이 정상적으로 동작하지 않을 수 있습니다. 만약 vllm-rbln 설치 후 vllm을 설치했을 경우, vllm-rbln를 재설치 해주시기 바랍니다.

Llama3-8B 컴파일

먼저 서빙에 사용할 모델을 준비하기 위해서, rbln_model 디렉토리를 생성하고 이동합니다.

$ mkdir rbln_model
$ cd rbln_model

optimum-rbln를 이용하여 Llama3-8B 를 컴파일 합니다. 해당 코드는 리벨리온 모델주를 사용하였습니다.

get_model.py
import os

from optimum.rbln import RBLNLlamaForCausalLM

def main():
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    # Compile and export
    model = RBLNLlamaForCausalLM.from_pretrained(
        model_id=model_id,
        export=True,  # export a PyTorch model to RBLN model with optimum
        rbln_batch_size=4,
        rbln_max_seq_len=8192,  # default "max_position_embeddings"
        rbln_tensor_parallel_size=4,
    )

    # Save compiled results to disk
    model.save_pretrained(os.path.basename(model_id))

if __name__ == "__main__":
    main()

Note

적절한 배치 크기를 선택해야 합니다. 여기에서는 4로 설정합니다.

TorchServe를 이용한 모델 서빙

TorchServe에서 모델 서빙은 모델 아카이브(.mar) 파일 단위로 이루어집니다. .mar 파일에는 모델 서빙에 필요한 모든 정보가 포함됩니다. 본 섹션에서는 .mar 파일 생성 및 생성된 .mar 파일을 이용하여 모델을 서빙하는 방법에 대해 설명합니다.

RBLN vLLM Handler

TorchServe에서 vLLM Engine을 사용하고자 할 때 TorchServe에서 제공하는 vLLM Handler를 활용할 수 있습니다. 그러나 TorchServe의 vLLM Handler에 반영된 vLLM 버전과 개발 환경에 설치된 vLLM 패키지 버전의 호환성에 따라 동작이 상이할 수 있습니다. 이 문서에서는 vllm-rbln의 최신 버전에 호환되는 RBLN vLLM Handler를 아래와 같이 제안합니다.

rbln_vllm_handler.py
import asyncio
import logging
import os
import pathlib
import time
from unittest.mock import MagicMock

from ts.handler_utils.utils import send_intermediate_predict_response
from ts.service import PredictionException
from ts.torch_handler.base_handler import BaseHandler
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels

from vllm import AsyncEngineArgs, AsyncLLMEngine

logger = logging.getLogger(__name__)


class RBLN_VLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model_name = None
        self.model_dir = None
        self.adapters = None
        self.openai_serving_model = None
        self.chat_completion_service = None
        self.completion_service = None
        self.raw_request = None
        self.initialized = False

    def initialize(self, ctx):
        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(ctx.model_yaml_config.get("handler", {}))

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_config)

        if vllm_engine_config.served_model_name:
            served_model_names = vllm_engine_config.served_model_name
        else:
            served_model_names = [vllm_engine_config.model]

        chat_template = ctx.model_yaml_config.get("handler", {}).get("chat_template", None)

        loop = asyncio.get_event_loop()
        model_config = loop.run_until_complete(self.vllm_engine.get_model_config())

        base_model_paths = [
            BaseModelPath(name=name, model_path=self.model_dir) for name in served_model_names
        ]

        self.openai_serving_models = OpenAIServingModels(
            engine_client=self.vllm_engine,
            model_config=model_config,
            base_model_paths=base_model_paths,
        )

        self.completion_service = OpenAIServingCompletion(
            self.vllm_engine,
            model_config,
            self.openai_serving_models,
            request_logger=None,
        )

        self.chat_completion_service = OpenAIServingChat(
            self.vllm_engine,
            model_config,
            self.openai_serving_models,
            "assistant",
            request_logger=None,
            chat_template=chat_template,
        )

        async def isd():
            return False

        self.raw_request = MagicMock()
        self.raw_request.headers = {}
        self.raw_request.is_disconnected = isd

        self.initialized = True

    async def handle(self, data, context):
        start_time = time.time()

        metrics = context.metrics

        data_preprocess = await self.preprocess(data, context)
        output = await self.inference(data_preprocess, context)
        output = await self.postprocess(output)

        stop_time = time.time()
        metrics.add_time("HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms")
        return output

    async def preprocess(self, requests, context):
        assert len(requests) == 1, "Expecting batch_size = 1"
        req_data = requests[0]
        data = req_data.get("data") or req_data.get("body")
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")

        return [data]

    async def inference(self, input_batch, context):
        url_path = context.get_request_header(0, "url_path")

        if url_path == "v1/models":
            models = await self.chat_completion_service.show_available_models()
            return [models.model_dump()]

        directory = {
            "v1/completions": (
                CompletionRequest,
                self.completion_service,
                "create_completion",
            ),
            "v1/chat/completions": (
                ChatCompletionRequest,
                self.chat_completion_service,
                "create_chat_completion",
            ),
        }

        RequestType, service, func = directory.get(url_path, (None, None, None))

        if RequestType is None:
            raise PredictionException(f"Unknown API endpoint: {url_path}", 404)

        request = RequestType.model_validate(input_batch[0])
        g = await getattr(service, func)(
            request,
            self.raw_request,
        )

        if isinstance(g, ErrorResponse):
            return [g.model_dump()]
        if request.stream:
            async for response in g:
                if response != "data: [DONE]\n\n":
                    send_intermediate_predict_response(
                        [response], context.request_ids, "Result", 200, context
                    )
            return [response]
        else:
            return [g.model_dump()]

    async def postprocess(self, inference_outputs):
        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", {})
        if len(model) == 0:
            model_path = handler_config.get("model_path", {})
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = pathlib.Path(self.model_dir).joinpath(model_path)
            if not model.exists():
                logger.debug(
                    f"Model path ({model}) does not exist locally."
                    " Trying to give without model_dir as prefix."
                )
                model = model_path
            else:
                model = model.as_posix()
        logger.debug(f"EngineArgs model: {model}")
        vllm_engine_config = AsyncEngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)
        return vllm_engine_config

    def _set_attr_value(self, obj, config: dict):
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)

모델 서빙 설정 작성

아래와 같이 TorchServe로 Llama3-8B 모델을 서빙하기 위한 설정인 model_config.yaml 파일을 작성합니다. 이 파일은 Llama3-8B 모델을 서빙하는데 필요한 Worker 갯수와 TorchServe의 Frontend 파라미터를 지정하고, vLLM 엔진의 설정을 포함합니다.

설정에 대한 자세한 정보는 TorchServe 문서 - Advanced configuration을 참고하시기 바랍니다.

리벨리온 NPU를 사용하기 위해서 vllm_engine_configdevice: "rbln" 설정이 필요합니다. model 설정은 서빙에 사용할 Llama-3B 모델이 저장된 디렉토리와 정확히 일치해야 합니다.

model_config.yaml
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1             # Set the number of worker to create a single model instance
maxBatchDelay: 100
startupTimeout: 1200      # (in seconds) Give the worker time to load the model weights
asyncCommunication: true  # This ensures we can cummunicate asynchronously with the worker

# Handler parameters
handler:
    vllm_engine_config: # vLLM configuration which gets fed into AsyncVLLMEngine
        max_num_seqs: 4
        max_model_len: 4096
        max_num_batched_tokens: 4096
        device: "rbln"
        model: "Meta-Llama-3-8B-Instruct" # Can be a model identifier for Hugging Face hub or a local path
        served_model_name:
            - "llama3-8b"

torch-model-archiver를 이용한 모델 아카이빙

아래와 같이 모델 아카이브(.mar) 파일이 저장될 경로인 model_store 디렉토리를 생성합니다. 이 디렉토리에 Llama3-8B 모델 아카이브 파일이 저장됩니다.

$ mkdir model_store

이제 모델 아카이브 파일을 만들기 위해 필요한 내용이 모두 준비되었습니다. torch-model-archiver 도구를 이용해 모델 아카이브 파일을 만들 수 있습니다.

1
2
3
4
5
6
7
8
$ torch-model-archiver \
        --model-name llama3-8b \
        --version 1.0 \
        --handler ./rbln_vllm_handler.py \
        --config-file ./model_config.yaml \
        --archive-format no-archive \
        --export-path model_store/ \
        --extra-files rbln_model/

사용된 옵션은 아래와 같습니다.

  • --model-name: 서빙할 모델이름으로 llama3-8b 으로 설정합니다.
  • --version: TorchServe로 서빙할 모델에 대한 버전입니다.
  • --handler: 요청 모델에 대한 Handler Script를 지정하는 옵션이며, 위에서 작성한 rbln_vllm_handler.py 를 지정해줍니다.
  • --config-file: 서빙할 모델의 YAML 설정을 설정하는 옵션으로, 위에서 작성한 model_config.yaml로 설정합니다.
  • --archive-format: 아카이빙 포맷을 설정하는 옵션으로, no-archive로 설정합니다.
  • --export-path: 아카이빙 결과물을 저장할 경로를 설정하는 옵션으로, 위에서 생성한 model_store 디렉토리로 설정합니다.
  • --extra-files: 의존성이 있는 파일들을 추가로 아카이빙에 포함할 리스트를 설정하는 옵션으로 설정한 디렉토리를 제외한 내부 디렉토리의 구조를 그대로 포함하여 아카이빙합니다.

torch-model-archiver 를 이용한 아카이빙이 정상적으로 완료되면, model_store 디렉토리에 서빙할 모델 이름인 llama3-8b 의 디렉토리가 생성됩니다. no-archive 옵션으로 아카이빙 하였으므로, .mar 파일로 아카이빙되는 대신에 llama3-8b에 저장됩니다. no-archive 옵션을 사용하지 않았을 경우 llama3-8b 디렉토리 대신 llama3-8b.mar 파일이 생성됩니다.

+--(YOUR_PATH)/
|   +-- model_store/
|   |   +-- llama3-8b
|   |   |   +-- MAR-INF
|   |   |   |   +-- MANIFEST.json
|   |   |   +-- Meta-Llama-3-8B-Instruct
|   |   |   |   +-- prefill.rbln
|   |   |   |   +-- decoder.rbln
|   |   |   |   +-- config.json
|   |   |   |   +-- (기타 모델 파일들)
|   |   |   +-- model_config.yaml

torchserve 실행

torchserve를 이용하여 서빙을 시작합니다. 서빙에 사용되는 파라미터는 아래와 같습니다.

$ torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth
  • --start: model-server를 시작합니다.
  • --ncs: --no-config-snapshots 옵션입니다.
  • --model-store: 모델을 로드하거나 기본적으로 로드할 모델들의 경로를 지정합니다.
  • --models: 서빙할 모델을 지정합니다.
  • --disable-token-auth: 단순한 서빙 동작 테스트를 하기 위하여 token authorization을 비활성화 합니다.

TorchServe가 정상적으로 실행되면 백그라운드에서 동작합니다. TorchServe의 동작을 중지하기 위한 명령어는 아래와 같습니다.

$ torchserve --stop

TorchServe는 Management API의 기본 설정으로 8081 포트를 통해 요청을 받습니다.

서빙 되고 있는 모델 리스트는 아래와 같이 Management API를 통해 확인할 수 있습니다.

$ curl -X GET "http://localhost:8081/models"

정상적으로 동작할 경우 Llama3-8B 모델이 서빙되고 있는 것을 확인할 수 있습니다.

1
2
3
4
5
6
7
8
{
  "models": [
    {
      "modelName": "llama3-8b",
      "modelUrl": "llama3-8b"
    }
  ]
}

TorchServe Inference API 기반 추론 요청

torchserve로 시작한 Llama3-8B 서빙을 테스트하기 위해서 TorchServe Inference APIPrediction API를 이용하여 추론을 요청합니다.

curl을 이용한 간단한 테스트

TorchServe Inference API는 기본설정으로 8080 포트를 통해 요청을 받습니다.

curl을 이용하여 아래와 같이 8080 포트로 HTTP 요청를 전송하면 간단히 테스트할 수 있습니다.

1
2
3
4
5
$ echo '{
  "model": "llama3-8b",
  "prompt": "A robot may not injure a human being",
  "stream": 0
}' | curl --header "Content-Type: application/json"   --request POST --data-binary @-   http://localhost:8080/predictions/llama3-8b/1.0/v1/completions

정상 동작할 경우 아래와 같은 응답이 출력됩니다.

{
  "id": "cmpl-2826f5c0dc164a5d91d3ae0d4b71a480",
  "object": "text_completion",
  "created": 1737599538,
  "model": "llama3-8b",
  "choices": [
    {
      "index": 0,
      "text": " or, through inaction, allow a human being to come to harm.\nA",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null,
      "prompt_logprobs": null
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 26,
    "completion_tokens": 16,
    "prompt_tokens_details": null
  }
}

TorchServe 저장소 예제 테스트

TorchServe 저장소는 vLLM을 이용한 서빙 예시를 제공합니다. 자세한 정보는 TorchServe 저장소의 README.md를 참고 바랍니다.

본 예제에서는 OpenAI의 Text completionChat interface를 이용하여Llama3-8B 모델을 테스트하는 방법을 소개합니다.

$ git clone https://github.com/pytorch/serve
$ cd examples/large_models/vllm/llama3
Text Completion
  • 테스트 명령어

    $ python3 ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api
    
  • 결과물

    OUTPUT

    Tasks are completed
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    A robot must obey the orders of human beings except where such orders would conflict with the First Law.
    A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    
    These three laws, developed by science fiction author Isaac Asimov, are a cornerstone of the robot rights movement. They provide a framework for robots to operate within, ensuring that they prioritize human well-being and safety above all else.
    
    In this world, robots have become an integral part of our daily lives. They work alongside humans, helping with everything from men
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    A robot must obey the orders of human beings except where such orders would conflict with the First Law.
    A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    
    As I held the small, sleek device in my hand, I couldn't help but feel a sense of excitement and trepidation. This was it, the moment I had been waiting for. The moment when I would finally be able to see if my theories were correct. If my robot, my creation, was truly the first of its kind.
    
    I
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    A robot must obey the orders given to it by human beings except where such orders would conflict with the First Law.
    A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    — The Three Laws of Robotics, developed by science fiction author Isaac Asimov
    
    When artificial intelligence (AI) is compared to mu ltiple human actors working together, it can be frustrating and challenging to manage the complexity of their relationships. With AI, we don't have a single "actor" with its own motivations, goals, and
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm. A robot must obey the orders given to it by human beings except where such orders would conflict with the First Law. A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    
    These rules, devised by Dr. Isaac Asimov, are the foundation of the Three Laws of Robotics. They were first introduced in his 1942 science fiction short story "Runaround" and have since become a cornerstone of the science fiction genre.
    
    The Three Laws are designed to ensure that robots behave in a way that is safe and
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    Robotics and artificial intelligence (AI) are transforming industries and revolutionizing the way we live and work. However, with the increasing development of autonomous machines, there is a growing need to consider the ethics and moral implications of these technologies.
    
    One of the most important ethical principles in robotics and AI is the Three Laws of Robotics, which were first proposed by science fiction author Isaac Asimov in the 1940s. The three laws are:
    
    1. A robot may not injure a human being or, through inaction, allow a human being to come to
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    A robot must obey the orders of human beings except where such orders would conflict with the First Law.
    A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    These three laws were written by Dr. Isaac Asimov, a renowned science fiction author, in his 1942 short story "Runaround." They are a fundamental part of the science fiction genre and have been widely adopted as a framework for exploring the ethics of artificial intelligence. Today, we'll be exploring the first law: A robot may not inj
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output=, or, through inaction, allow a human being to come to harm. A robot must protect its own existence as long as such protection does not conflict with the First Law. A robot must follow the instructions of its human masters, except where such instructions conflict with the First or Second Law. A robot must not have a significant negative influence on human existence.
    
    The Three Laws of Robotics, as formulated by science fiction author Isaac Asimov, are a set of rules designed to govern the behavior of robots and artificial intelligence (AI) in a way that prioritizes human safety and well-being. The laws are often seen as a way to
    
    payload={'prompt': 'A robot may not injure a human being', 'temperature': 0.8, 'logprobs': 1, 'max_tokens': 128, 'model': 'llama3-8b'}
    , output= or, through inaction, allow a human being to come to harm.
    To prevent harm, a robot must also obey the following three laws:
    
    1. A robot may not injure a human being or, through inaction, allow a human being to come to harm.
    2. A robot must obey the orders given to it by human beings except where such orders would conflict with the First Law.
    3. A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
    
    These laws were formulated by Dr. Isaac Asimov in his science fiction stories, and they have since become a standard
    
Chat Interface
  • 테스트 명령어

    $ python3 ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@chat.json" --prompt-json --openai-api --demo-streaming --api-endpoint "v1/chat/completions"
    
  • 결과물

    OUTPUT

    1
    2
    3
    payload={'model': 'llama3-8b', 'messages': [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Who won the world series in 2020?'}, {'role': 'assistant', 'content': 'The Los Angeles Dodgers won the World Series in 2020.'}, {'role': 'user', 'content': 'Where was it played?'}], 'temperature': 0.0, 'max_tokens': 50, 'stream': True}
    , output=
    The 2020 World Series was played at Globe Life Field in Arlington, Texas, which is the home stadium of the Texas Rangers. However, the series was played with no fans in attendance due to the COVID-19 pandemic.Tasks are completed