Llama3.1-8B with Flash Attention
TorchServe provides a vLLM Handler, a custom handler that supports the vLLM engine. With this handler, vllm-rbln can be leveraged to serve LLM models efficiently. This tutorial guides you through serving the Llama3.1-8B model with Flash Attention using TorchServe's vLLM Handler and vllm-rbln.
For instructions on setting up the TorchServe environment, refer to TorchServe.
To see the YAML files, model compilation script, and TorchServe configuration introduced on this page, visit the Model Zoo.
Note
This tutorial is written with the assumption that the reader already has a good understanding of how to compile and infer models using RBLN SDK. If you are not familiar with RBLN SDK, please refer to the Tutorials.
Prerequisites
The following prerequisites should be prepared for this tutorial.
Compile Llama3.1-8B
To prepare the model for serving, create the rbln_model folder and navigate into it.
| $ mkdir rbln_model
$ cd rbln_model
|
Compile the Llama3.1-8B model using optimum-rbln.
| get_model.py |
|---|
| from optimum.rbln import RBLNLlamaForCausalLM
import os

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,  # Export a PyTorch model to RBLN model with Optimum
    rbln_batch_size=1,  # Batch size
    rbln_max_seq_len=131_072,  # Maximum sequence length
    rbln_tensor_parallel_size=8,  # Tensor parallelism
    rbln_kvcache_partition_len=16_384,  # Length of KV cache partitions for flash attention
)

# Save compiled results to disk
model.save_pretrained(os.path.basename(model_id))
|
Note
You need to select an appropriate batch size. In this case, it is set to 1.
Quick Start with TorchServe
In TorchServe, models are served as Model Archive (.mar) units, which contain all necessary information for serving the model. The following guide explains how to create a .mar file and use it for model serving.
RBLN vLLM Handler
TorchServe provides a vLLM Handler to utilize the vLLM engine. Because that handler code may have dependency issues with the installed vLLM version, we suggest using the RBLN vLLM Handler, which is compatible with the latest version of vllm-rbln, as shown below:
| rbln_vllm_handler.py |
|---|
| import asyncio
import logging
import os
import pathlib
import time
from unittest.mock import MagicMock

from ts.handler_utils.utils import send_intermediate_predict_response
from ts.service import PredictionException
from ts.torch_handler.base_handler import BaseHandler
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm import AsyncEngineArgs, AsyncLLMEngine

logger = logging.getLogger(__name__)


class RBLN_VLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model_name = None
        self.model_dir = None
        self.adapters = None
        self.openai_serving_models = None
        self.chat_completion_service = None
        self.completion_service = None
        self.raw_request = None
        self.initialized = False

    def initialize(self, ctx):
        # Build the vLLM engine configuration from model_config.yaml and start the async engine.
        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(ctx.model_yaml_config.get("handler", {}))
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_config)

        if vllm_engine_config.served_model_name:
            served_model_names = vllm_engine_config.served_model_name
        else:
            served_model_names = [vllm_engine_config.model]

        chat_template = ctx.model_yaml_config.get("handler", {}).get("chat_template", None)

        loop = asyncio.get_event_loop()
        model_config = loop.run_until_complete(self.vllm_engine.get_model_config())

        base_model_paths = [
            BaseModelPath(name=name, model_path=self.model_dir) for name in served_model_names
        ]

        self.openai_serving_models = OpenAIServingModels(
            engine_client=self.vllm_engine,
            model_config=model_config,
            base_model_paths=base_model_paths,
        )

        self.completion_service = OpenAIServingCompletion(
            self.vllm_engine,
            model_config,
            self.openai_serving_models,
            request_logger=None,
        )

        self.chat_completion_service = OpenAIServingChat(
            self.vllm_engine,
            model_config,
            self.openai_serving_models,
            "assistant",
            request_logger=None,
            chat_template=chat_template,
        )

        # The OpenAI serving entrypoints expect a raw HTTP request object;
        # mock the attributes they actually use.
        async def isd():
            return False

        self.raw_request = MagicMock()
        self.raw_request.headers = {}
        self.raw_request.is_disconnected = isd

        self.initialized = True

    async def handle(self, data, context):
        start_time = time.time()

        metrics = context.metrics

        data_preprocess = await self.preprocess(data, context)
        output = await self.inference(data_preprocess, context)
        output = await self.postprocess(output)

        stop_time = time.time()
        metrics.add_time("HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms")

        return output

    async def preprocess(self, requests, context):
        assert len(requests) == 1, "Expecting batch_size = 1"
        req_data = requests[0]
        data = req_data.get("data") or req_data.get("body")
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")
        return [data]

    async def inference(self, input_batch, context):
        url_path = context.get_request_header(0, "url_path")

        if url_path == "v1/models":
            models = await self.chat_completion_service.show_available_models()
            return [models.model_dump()]

        # Route the request to the matching OpenAI-compatible endpoint.
        directory = {
            "v1/completions": (
                CompletionRequest,
                self.completion_service,
                "create_completion",
            ),
            "v1/chat/completions": (
                ChatCompletionRequest,
                self.chat_completion_service,
                "create_chat_completion",
            ),
        }

        RequestType, service, func = directory.get(url_path, (None, None, None))
        if RequestType is None:
            raise PredictionException(f"Unknown API endpoint: {url_path}", 404)

        request = RequestType.model_validate(input_batch[0])
        g = await getattr(service, func)(
            request,
            self.raw_request,
        )

        if isinstance(g, ErrorResponse):
            return [g.model_dump()]
        if request.stream:
            async for response in g:
                if response != "data: [DONE]\n\n":
                    send_intermediate_predict_response(
                        [response], context.request_ids, "Result", 200, context
                    )
            return [response]
        else:
            return [g.model_dump()]

    async def postprocess(self, inference_outputs):
        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        # Use vllm_engine_config.model if given; otherwise resolve model_path
        # relative to the model directory.
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", {})
        if len(model) == 0:
            model_path = handler_config.get("model_path", {})
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = pathlib.Path(self.model_dir).joinpath(model_path)
            if not model.exists():
                logger.debug(
                    f"Model path ({model}) does not exist locally."
                    " Trying to give without model_dir as prefix."
                )
                model = model_path
            else:
                model = model.as_posix()
        logger.debug(f"EngineArgs model: {model}")

        vllm_engine_config = AsyncEngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)

        return vllm_engine_config

    def _set_attr_value(self, obj, config: dict):
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)
|
Write the Model Configuration
Let's create a model_config.yaml file to configure the number of workers and the TorchServe frontend parameters for serving the Llama3.1-8B model. This YAML file also contains the vLLM engine settings for LLM serving.
For more details, refer to TorchServe Document - Advanced configuration.
| model_config.yaml |
|---|
| # TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1            # Set the number of workers to create a single model instance
maxBatchDelay: 100
startupTimeout: 1200     # (in seconds) Give the worker time to load the model weights
asyncCommunication: true # This ensures we can communicate asynchronously with the worker

# Handler parameters
handler:
    vllm_engine_config:  # vLLM configuration which gets fed into AsyncLLMEngine
        model: "Llama-3.1-8B-Instruct"  # Can be a model identifier for Hugging Face hub or a local path
        served_model_name:
            - "llama3.1-8b"
|
model: The path of the compiled model (here, the Llama-3.1-8B-Instruct directory that will be included in the model archive) or a Hugging Face Hub model identifier.
served_model_name: The name under which the model is exposed; requests to the OpenAI-compatible endpoints refer to the model by this name (here, llama3.1-8b).
Model Archiving with torch-model-archiver
The model_store directory holds the model archives that TorchServe serves, including the Llama3.1-8B archive created in this tutorial.
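If the model_store directory does not exist yet, create it first:
| $ mkdir model_store
|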
Now that the setup is complete, run the torch-model-archiver command to create the model archive file.
| $ torch-model-archiver \
--model-name llama3.1-8b \
--version 1.0 \
--handler ./rbln_vllm_handler.py \
--config-file ./model_config.yaml \
--archive-format no-archive \
--export-path model_store/ \
--extra-files rbln_model/
|
The options passed to torch-model-archiver are as follows.
--model-name: Sets the name under which the model will be served; here it is llama3.1-8b.
--version: Specifies the version of the model to be served with TorchServe.
--handler: Specifies the handler script for the model, set as rbln_vllm_handler.py.
--config-file: Specifies the yaml configuration file for the model, set as model_config.yaml.
--archive-format: Specifies the archive format; set to no-archive so the output is kept as a directory instead of being packaged into a .mar file.
--export-path: Specifies the directory where the archived model will be stored, set to the model_store folder created earlier.
--extra-files: Specifies a list of additional dependency files to include in the archive. Multiple files or directories can be specified, separated by commas (,). The internal folder structure of the specified directories is preserved in the archive.
Once the archiving process using torch-model-archiver is complete, a folder named llama3.1-8b will be created in model_store, where the model will be served. Since the no-archive option was used, the archive’s internal files will be stored in this folder instead of being packaged into a .mar file.
| +--(YOUR_PATH)/
| +-- model_store/
| | +-- llama3.1-8b
| | | +-- MAR-INF
| | | | +-- MANIFEST.json
| | | +-- Llama-3.1-8B-Instruct
| | | | +-- prefill.rbln
| | | | +-- decoder.rbln
| | | | +-- config.json
| | | | +-- (other model files)
| | | +-- model_config.yaml
|
Run torchserve
TorchServe can be started by running the following command. For a simple test where token authentication is not required, you can use the --disable-token-auth option.
| $ torchserve --start --ncs --model-store model_store --models llama3.1-8b --disable-token-auth
|
--start: Starts the TorchServe service.
--ncs: Disables the snapshot feature.
--model-store: Specifies the directory containing the models to serve.
--models: Specifies the model(s) to load, in this case llama3.1-8b. Passing all loads every model available in the model_store directory.
--disable-token-auth: Disables authentication for management API endpoints, simplifying testing.
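Loading the model weights can take several minutes (hence the startupTimeout of 1200 seconds in model_config.yaml). With a default installation, TorchServe writes its logs to the logs/ directory under the current working directory, so you can follow the worker's progress there; the exact file names and location depend on your logging configuration.
| $ tail -f logs/ts_log.log
|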
When TorchServe starts successfully, it runs in the background. To stop it, run the following command:
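| $ torchserve --stop
|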
The Management API of TorchServe receives requests on port 8081 by default.
You can check the list of models currently being served with the following Management API call.
| $ curl -X GET "http://localhost:8081/models"
|
If the operation is successful, you can verify that the Llama3.1-8B model is being served.
| {
    "models": [
        {
            "modelName": "llama3.1-8b",
            "modelUrl": "llama3.1-8b"
        }
    ]
}
|
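You can also ask the Management API to describe a specific model, which reports details such as its workers and their status; the exact fields returned depend on your TorchServe version.
| $ curl -X GET "http://localhost:8081/models/llama3.1-8b"
|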
Inference Request with TorchServe Inference API
Simple Request with curl
Now we can test the Llama3.1-8B model served with TorchServe by sending an inference request through the Predictions API of the TorchServe Inference API.
The Inference API of TorchServe receives requests on port 8080 by default.
Make an inference request using the TorchServe Inference API with curl.
| $ curl -sS --header 'Content-Type: application/json' \
  --request POST \
  --data-binary @- http://localhost:8080/predictions/llama3.1-8b/1.0/v1/chat/completions <<EOF | \
  grep -oP '"content":"[^"]*"' | \
  sed 's/"content":"//;s/"$//' | \
  tr -d '\n'; echo
{
    "model": "llama3.1-8b",
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello!"
        }
    ],
    "stream": true
}
EOF
|
If the inference request is successful, a response similar to the following is returned.
| Hello. It's nice to meet you. Is there something I can help you with or would you like to chat?
|
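The handler also routes v1/completions requests (see the endpoint mapping in rbln_vllm_handler.py). The following is a minimal sketch of a non-streaming completion request; the prompt and sampling parameters are illustrative, and the shape of the returned JSON depends on the installed vLLM version.
| $ curl -sS --header 'Content-Type: application/json' \
  --request POST \
  --data '{"model": "llama3.1-8b", "prompt": "Flash attention is", "max_tokens": 32, "stream": false}' \
  http://localhost:8080/predictions/llama3.1-8b/1.0/v1/completions
|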