Llama3.1-8B with Flash Attention
Overview
In this tutorial you will deploy the Llama3.1-8B model with Flash Attention on Ray Serve using RBLN NPUs.
The workflow covers:
- Verifying the environment and compiled model.
- Defining a Ray Serve deployment that targets RBLN hardware.
- Launching the application with the Serve CLI.
- Sending an inference request to validate the endpoint.
If you need help configuring Ray Serve itself, review the Ray Serve overview first. For a complete script-based example (from compilation to deployment), see the model zoo reference.
Setup & Installation
Before you begin, ensure that your system environment is properly configured and that all required packages are installed. This includes:
- System Requirements:
- Ubuntu 20.04 LTS (Debian bullseye) or higher
- A system equipped with RBLN NPUs (e.g., RBLN ATOM™)
- Package Requirements:
- Installation Command:
| pip install -U "ray[serve]" transformers requests torch --extra-index-url https://download.pytorch.org/whl/cpu
|
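The first step of the workflow is verifying the environment. The following is a minimal sketch of one way to check from Python that the required packages are importable. It assumes the module names listed in REQUIRED_MODULES: ray, transformers, requests, and torch come from the install command above, and vllm and optimum.rbln are what the later scripts import. The helper names are hypothetical.

```python
import importlib.util

# Assumption: import names of the packages used in this tutorial.
REQUIRED_MODULES = ["ray", "transformers", "requests", "torch", "vllm", "optimum.rbln"]


def check_modules(names):
    """Return a mapping of module name -> True if it can be found for import."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # find_spec raises when the parent package (e.g. "optimum") is missing
            status[name] = False
    return status


def missing_modules(names):
    """List the modules that are not installed."""
    return [name for name, ok in check_modules(names).items() if not ok]
```

For example, `print(missing_modules(REQUIRED_MODULES))` should print an empty list on a fully prepared machine. find_spec only locates the module without importing it, so the check stays fast even for heavy packages such as torch.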
Note
The following sections assume you already know how to compile and run models with the RBLN SDK. Revisit the RBLN Optimum and vLLM guides if you need a refresher.
Prerequisites
Prepare the Compiled Model
Compile the Llama3.1-8B model using optimum-rbln. This code is based on the Rebellions Model Zoo.
The following parameters are used to compile the Llama3.1-8B model:
- export: Must be True to compile the model.
- rbln_batch_size: Defines the batch size to use for compilation.
- rbln_max_seq_len: Specifies the maximum sequence length.
- rbln_tensor_parallel_size: Sets the number of NPUs to use for inference.
- rbln_kvcache_partition_len: Defines the length of KV cache partitions for flash attention. rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len and greater than rbln_kvcache_partition_len.
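The partition constraint above is easy to break when you change the sequence length. Below is a hypothetical helper, not part of optimum-rbln, that encodes only the rule stated in the list. It is checked against the values used in get_model.py.

```python
def validate_kvcache_partition(max_seq_len: int, partition_len: int) -> int:
    """Check the flash-attention constraint on the compile parameters.

    rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len and
    strictly greater than it. Returns the number of KV cache partitions.
    """
    if partition_len <= 0:
        raise ValueError("rbln_kvcache_partition_len must be positive")
    if max_seq_len <= partition_len:
        raise ValueError("rbln_max_seq_len must be greater than rbln_kvcache_partition_len")
    if max_seq_len % partition_len != 0:
        raise ValueError("rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len")
    return max_seq_len // partition_len


# Values from get_model.py: 131,072 / 16,384 = 8 partitions
print(validate_kvcache_partition(131_072, 16_384))  # 8
```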
Note
Choose a batch size that suits your workload. This example compiles with rbln_batch_size=1.
| get_model.py |
|---|
| import os
from optimum.rbln import RBLNLlamaForCausalLM
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,  # compile the model for RBLN NPUs
    rbln_batch_size=1,
    rbln_max_seq_len=131_072,  # maximum sequence length
    rbln_tensor_parallel_size=8,  # number of NPUs used for inference
    rbln_kvcache_partition_len=16_384,  # flash attention partition; 131_072 / 16_384 = 8
)
# Save compiled results to disk (directory: Llama-3.1-8B-Instruct)
model.save_pretrained(os.path.basename(model_id))
|
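The deployment later refers to the model as Llama-3.1-8B-Instruct and passes that name to vLLM as its model argument. That name should resolve to the local directory written above when serve run is started from the same working directory. The following sketch uses hypothetical helper names. It shows how the output directory name follows from model_id and gives a rough check that the directory holds compiled artifacts. The *.rbln file pattern is an assumption about the optimum-rbln output layout.

```python
import os
from pathlib import Path


def compiled_model_dir(model_id: str) -> str:
    """Directory name that get_model.py passes to save_pretrained()."""
    return os.path.basename(model_id)


def looks_compiled(path) -> bool:
    """Heuristic check that a directory holds compiled RBLN artifacts.

    Assumption: compiled graphs are saved as files with an ".rbln"
    extension; adjust the pattern if your optimum-rbln version differs.
    """
    p = Path(path)
    return p.is_dir() and any(p.glob("*.rbln"))


print(compiled_model_dir("meta-llama/Llama-3.1-8B-Instruct"))  # Llama-3.1-8B-Instruct
```

Run `looks_compiled(compiled_model_dir("meta-llama/Llama-3.1-8B-Instruct"))` from the directory where you plan to invoke serve run.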
Deployment Flow
Deployment Overview
| Step | Description |
|---|---|
| 1. Deployment implementation | Configure Ray to use RBLN NPUs and define the Ray Serve deployment that loads the compiled model, initializes the runtime, and exposes an endpoint. |
| 2. Execution | Launch the deployment with the Ray Serve CLI (serve run), optionally configuring the application name, the RBLN devices to use, or a remote Ray cluster. |
| 3. Inference request | Send an HTTP request to the Serve endpoint and inspect the response to validate the deployment. |
The sections below walk through these steps in order.
1.1 Resource Allocation
Ray exposes custom accelerators through the resources argument, so each task or deployment can request exactly the hardware it needs.
The RBLNActor class below requests eight RBLN NPUs with @ray.remote(resources={"RBLN": 8}). Set the value to the number of NPUs your compiled model uses (here, rbln_tensor_parallel_size=8). Its getDeviceId method returns the IDs of the NPUs that Ray assigned to the actor, and the Serve deployment passes them to the RBLN runtime. See RBLN NPUs with Ray for additional background.
| import ray
@ray.remote(resources={"RBLN": 8})
class RBLNActor:
    def getDeviceId(self):
        # List of RBLN device IDs (strings) that Ray assigned to this actor
        return ray.get_runtime_context().get_accelerator_ids()["RBLN"]
|
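The deployment turns these IDs into the RBLN_DEVICES environment variable before starting vLLM. Below is a minimal sketch of that conversion, written as hypothetical standalone helpers so it can be tested without Ray or NPUs. It clears RBLN_DEVICES with a default value, so a missing variable does not raise KeyError when no IDs were assigned.

```python
import os


def rbln_devices_env(ids):
    """Build the RBLN_DEVICES value from Ray-assigned accelerator IDs.

    get_accelerator_ids()["RBLN"] returns a list of string IDs such as
    ["0", "1"]. Returns None when no IDs were assigned.
    """
    if not ids:
        return None
    return ",".join(str(i) for i in ids)


def apply_rbln_devices(ids, environ=os.environ):
    """Set or clear RBLN_DEVICES in `environ` without raising KeyError."""
    value = rbln_devices_env(ids)
    if value is None:
        environ.pop("RBLN_DEVICES", None)
    else:
        environ["RBLN_DEVICES"] = value
    return value
```

Passing a plain dict as environ makes the behavior easy to check in isolation. Inside a replica, the default os.environ is what the RBLN runtime and vLLM read.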
1.2 Deployment Definition
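Ray Serve deployments are defined by decorating a class or function with @serve.deployment. The decorator registers the class as a Ray Serve deployment, so Ray Serve can manage its lifecycle (deployment, scaling, and updates). In the example below, @serve.ingress additionally attaches a FastAPI app. The methods registered with .post then become the HTTP routes /v1/completions and /v1/chat/completions. Each handler validates the JSON body and forwards it to vLLM's OpenAI-compatible serving classes. The RBLNActor handle is passed to the deployment through .bind(), and the constructor uses it to look up the assigned NPU IDs before starting the vLLM engine.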
| llama3_1-8b.py |
|---|
| import json
import os
from unittest.mock import MagicMock
import ray
from fastapi import FastAPI, HTTPException
from ray import serve
from starlette.requests import Request
from starlette.responses import StreamingResponse
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
# FastAPI app used as the HTTP ingress of the Serve deployment
fastapi_app = FastAPI()
# Register the RBLN NPUs as a custom Ray resource
ray.init(resources={"RBLN": 8})
@ray.remote(resources={"RBLN": 8})
class RBLNActor:
    def getDeviceId(self):
        # IDs (strings) of the RBLN NPUs that Ray assigned to this actor
        return ray.get_runtime_context().get_accelerator_ids()["RBLN"]
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 16})
@serve.ingress(fastapi_app)
class Llama3_1__8B:
    def __init__(self, rbln_actor: RBLNActor):
        """
        Resolve the assigned NPUs, export the runtime environment variables,
        and start the vLLM engine.
        """
        self.engine = None
        self.rbln_actor = rbln_actor
        self.model_name = "Llama-3.1-8B-Instruct"
        self.raw_request = None
        self.vllm_engine = None
        self.openai_serving_models = None
        self.completion_service = None
        self.chat_completion_service = None
        self.ids = ray.get(rbln_actor.getDeviceId.remote())
        self.os_environment_vars()
        self.initialize()
    def os_environment_vars(self):
        """
        Redefine the environment variables passed to the RBLN runtime and vLLM.
        """
        if not self.ids:
            # No NPUs assigned: drop any stale value (no KeyError if it is unset)
            os.environ.pop("RBLN_DEVICES", None)
        else:
            # Restrict the RBLN runtime to the NPUs assigned by Ray
            os.environ["RBLN_DEVICES"] = ",".join(self.ids)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    def initialize(self):
        """
        Initialize the vLLM engine and prepare the OpenAI-compatible request handlers.
        """
        # Name or path of the model; here, the directory written by get_model.py
        vllm_engine_args = AsyncEngineArgs(
            model=self.model_name
        )
        self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_args)
        self.openai_serving_models = OpenAIServingModels(
            engine_client=self.vllm_engine,
            model_config=self.vllm_engine.vllm_config.model_config,
            base_model_paths=[
                BaseModelPath(name=self.model_name, model_path=os.getcwd())
            ],
        )
        self.completion_service = OpenAIServingCompletion(
            self.vllm_engine,
            self.vllm_engine.vllm_config.model_config,
            self.openai_serving_models,
            request_logger=None,
        )
        self.chat_completion_service = OpenAIServingChat(
            self.vllm_engine,
            self.vllm_engine.vllm_config.model_config,
            self.openai_serving_models,
            "assistant",
            request_logger=None,
            chat_template_content_format="auto",
            chat_template=None,
        )
        # Minimal stand-in for the raw HTTP request expected by the serving classes
        async def isd():
            return False
        self.raw_request = MagicMock()
        self.raw_request.headers = {}
        self.raw_request.is_disconnected = isd
    @fastapi_app.post("/v1/chat/completions")
    async def chat_completion(self, http_request: Request):
        """
        Handle a chat completion request.
        :param http_request: The HTTP request object
        :return: The chat completion response
        """
        try:
            body: dict = await http_request.json()
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format request")
        request: ChatCompletionRequest = ChatCompletionRequest.model_validate(body)
        g = await self.chat_completion_service.create_chat_completion(
            request, self.raw_request
        )
        if isinstance(g, ErrorResponse):
            return [g.model_dump()]
        if request.stream:
            # g is an async generator of server-sent event strings
            async def stream_generator():
                async for response in g:
                    yield response
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            return [g.model_dump()]
    @fastapi_app.post("/v1/completions")
    async def completion(self, http_request: Request):
        """
        Handle a completion request.
        :param http_request: The HTTP request object
        :return: The completion response
        """
        try:
            body: dict = await http_request.json()
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format request")
        request: CompletionRequest = CompletionRequest.model_validate(body)
        g = await self.completion_service.create_completion(request, self.raw_request)
        if isinstance(g, ErrorResponse):
            return [g.model_dump()]
        if request.stream:
            # g is an async generator of server-sent event strings
            async def stream_generator():
                async for response in g:
                    yield response
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            return [g.model_dump()]
rbln_actor = RBLNActor.remote()
# Bound application referenced by "serve run llama3_1-8b:app"
app = Llama3_1__8B.bind(rbln_actor)
|
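The handlers do not forward the incoming Starlette request to vLLM. Instead, they pass self.raw_request, a MagicMock with an empty headers mapping and an async is_disconnected that always returns False. The following sketch reproduces that stub in isolation so its behavior is visible. The make_stub_request name is hypothetical.

```python
import asyncio
from unittest.mock import MagicMock


def make_stub_request():
    """Build a stand-in for the raw HTTP request, as the deployment does.

    Only two attributes are configured: an empty `headers` mapping and an
    async `is_disconnected()` that always reports the client as connected.
    """
    async def is_disconnected():
        return False

    stub = MagicMock()
    stub.headers = {}
    stub.is_disconnected = is_disconnected
    return stub


stub = make_stub_request()
print(asyncio.run(stub.is_disconnected()))  # False
```

There are two consequences worth keeping in mind, though this is an interpretation rather than documented behavior. First, any other attribute that vLLM reads from the stub silently returns a new MagicMock instead of failing. Second, because is_disconnected never returns True, vLLM presumably cannot use it to cancel generation when a client goes away.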
2. Execution
Use the Ray Serve CLI (serve run) to launch the application. The argument uses the module:application format, where module is the Python filename (without .py) and application is the exported Serve entry point.
In this sample, llama3_1-8b.py binds the application to app, so the following command starts the deployment. Add extra options when connecting to a remote Ray cluster (for example, --address) or when pinning RBLN_DEVICES to specific cards.
| $ serve run llama3_1-8b:app --name "llama3.1-8b"
|
Example Output:
| Application 'llama3.1-8b' is ready at http://127.0.0.1:8000/.
|
3.1 Inference Request Example (Completions API)
| echo '{
"model": "Llama-3.1-8B-Instruct",
"prompt": "A robot may not injure a human being",
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- http://localhost:8000/v1/completions | jq .
|
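The same request can be sent from Python. Below is a minimal sketch using only the standard library (urllib.request). It assumes the application is running locally on Serve's default port, as in the serve run output above. The helper names are hypothetical.

```python
import json
import urllib.request

# Assumption: the Serve application listens locally on the default port.
BASE_URL = "http://localhost:8000"


def build_completion_payload(prompt, model="Llama-3.1-8B-Instruct", stream=False, **extra):
    """Body for POST /v1/completions; extra keys (e.g. max_tokens) pass through."""
    payload = {"model": model, "prompt": prompt, "stream": stream}
    payload.update(extra)
    return payload


def post_json(path, payload, base_url=BASE_URL, timeout=60):
    """POST a JSON body and decode the JSON response."""
    request = urllib.request.Request(
        base_url + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the deployment running, `post_json("/v1/completions", build_completion_payload("A robot may not injure a human being", max_tokens=64))` should return the same kind of JSON list shown below.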
Example Output:
| [
{
"id": "cmpl-790b3169a1c1428a8912140848b42753",
"object": "text_completion",
"created": 1762993969,
"model": "Llama-3.1-8B-Instruct",
"choices": [
{
"index": 0,
"text": " or, through inaction, allow a human being to come to harm.\nA",
"logprobs": null,
"finish_reason": "length",
"stop_reason": null,
"token_ids": null,
"prompt_logprobs": null,
"prompt_token_ids": null
}
],
"service_tier": null,
"system_fingerprint": null,
"usage": {
"prompt_tokens": 10,
"total_tokens": 26,
"completion_tokens": 16,
"prompt_tokens_details": null
},
"kv_transfer_params": null
}
]
|
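The deployment wraps each non-streaming response in a one-element JSON list (return [g.model_dump()]), which is why the output above starts with [. Clients therefore need to unwrap it before reading choices. Below is a minimal sketch with hypothetical helper names, run against a trimmed copy of the output above. The response also stopped with finish_reason "length" after 16 completion tokens. This appears consistent with a 16-token default for max_tokens on the completions endpoint; add "max_tokens" to the request body for longer output.

```python
def _unwrap(response):
    """Accept either the deployment's one-element list or a bare object."""
    return response[0] if isinstance(response, list) else response


def completion_text(response):
    """Return the generated text from a /v1/completions response."""
    return _unwrap(response)["choices"][0]["text"]


def check_usage(response):
    """True if total_tokens equals prompt_tokens + completion_tokens."""
    usage = _unwrap(response)["usage"]
    return usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]


# Trimmed copy of the example output above
sample = [{
    "choices": [{"index": 0,
                 "text": " or, through inaction, allow a human being to come to harm.\nA",
                 "finish_reason": "length"}],
    "usage": {"prompt_tokens": 10, "total_tokens": 26, "completion_tokens": 16},
}]
print(check_usage(sample))  # True
```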
3.2 Inference Request Example (Chat Completions API)
| echo '{
"model": "Llama-3.1-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
],
"temperature": 0.0,
"max_tokens": 50,
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- http://localhost:8000/v1/chat/completions | jq .
|
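The request above sets "stream": false. With "stream": true, the handler instead returns a text/event-stream response, and curl's -N flag keeps the output unbuffered. Below is a minimal sketch of a client-side parser for that stream. It assumes OpenAI-style events, one "data: <json>" line per chunk and a final "data: [DONE]", which is the format vLLM's OpenAI-compatible serving classes emit. The function names are hypothetical.

```python
import json


def parse_sse_lines(lines):
    """Yield decoded JSON chunks from OpenAI-style server-sent events."""
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line.startswith("data:"):
            continue  # skip blank separators and comment lines
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        yield json.loads(data)


def collect_chat_text(lines):
    """Concatenate the delta.content pieces of a streamed chat completion."""
    parts = []
    for chunk in parse_sse_lines(lines):
        for choice in chunk.get("choices", []):
            content = choice.get("delta", {}).get("content")
            if content:
                parts.append(content)
    return "".join(parts)
```

Iterating a urllib.request.urlopen response yields byte lines, so the parser can consume it directly: `collect_chat_text(response)`.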
Example Output:
| [
{
"id": "chatcmpl-bad227512c58497eb40dcc1287f29b50",
"object": "chat.completion",
"created": 1762994293,
"model": "Llama-3.1-8B-Instruct",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was a neutral site due to the COVID-19 pandemic, and the Dodgers played the Tampa Bay Rays in the series.",
"refusal": null,
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": [],
"reasoning_content": null
},
"logprobs": null,
"finish_reason": "stop",
"stop_reason": null,
"token_ids": null
}
],
"service_tier": null,
"system_fingerprint": null,
"usage": {
"prompt_tokens": 79,
"total_tokens": 122,
"completion_tokens": 43,
"prompt_tokens_details": null
},
"prompt_logprobs": null,
"prompt_token_ids": null,
"kv_transfer_params": null
}
]
|
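The request above is a multi-turn conversation. The Chat Completions API does not keep state between calls, so earlier assistant replies are sent back as messages. Below is a minimal sketch, using a hypothetical helper, of appending the reply from a response like the one above to the history before the next user turn. It unwraps the one-element list that the deployment returns.

```python
def append_assistant_reply(messages, response):
    """Return a new message list with the assistant's reply appended.

    `response` is the JSON returned by /v1/chat/completions; the deployment
    wraps it in a one-element list.
    """
    body = response[0] if isinstance(response, list) else response
    message = body["choices"][0]["message"]
    # Keep only the fields a follow-up request needs
    return messages + [{"role": message["role"], "content": message["content"]}]
```

A follow-up turn would then be `append_assistant_reply(messages, result) + [{"role": "user", "content": "..."}]`, sent as the new "messages" value. Returning a new list leaves the caller's history unchanged.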
References