
Llama3.1-8B with Flash Attention

Overview

In this tutorial you will deploy the Llama3.1-8B model with Flash Attention on Ray Serve using RBLN NPUs.

The workflow covers:

  1. Verifying the environment and compiled model.
  2. Defining a Ray Serve deployment that targets RBLN hardware.
  3. Launching the application with the Serve CLI.
  4. Sending an inference request to validate the endpoint.

If you need help configuring Ray Serve itself, review the Ray Serve overview first. For a complete script-based example (from compilation to deployment), see the model zoo reference.

Setup & Installation

Before you begin, ensure that your system environment is properly configured and that all required packages are installed. This includes:

  • System Requirements:
    • Ubuntu 20.04 LTS (Debian bullseye) or higher
    • System equipped with RBLN NPUs (e.g., RBLN ATOM™)
  • Package Requirements (installation command):
    pip install -U "ray[serve]" transformers requests torch --extra-index-url https://download.pytorch.org/whl/cpu
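
After installing, a quick import check confirms the core packages are available (a minimal sketch; the printed versions will vary with your environment):

# Sanity check that the packages installed above import cleanly.
import ray
import requests
import torch
import transformers
from ray import serve

print(ray.__version__, torch.__version__, transformers.__version__)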
    

Note

The following sections assume you already understand how to compile and execute models with the RBLN SDK. Revisit the RBLN Optimum and the vLLM guide if you need a refresher.

Prerequisites

Prepare the Compiled Model

Compile the Llama3.1-8B model using optimum-rbln. This code is based on the Rebellions Model Zoo.

The following parameters are used to compile the Llama3.1-8B model:

  • export: Must be True to compile the model.
  • rbln_batch_size: Defines the batch size to use for compilation.
  • rbln_max_seq_len: Specifies the maximum sequence length.
  • rbln_tensor_parallel_size: Sets the number of NPUs to use for inference.
  • rbln_kvcache_partition_len: Defines the length of KV cache partitions for flash attention. rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len and greater than rbln_kvcache_partition_len. For example, the script below pairs rbln_max_seq_len=131072 with rbln_kvcache_partition_len=16384 (8 partitions).

Note

Select a batch size appropriate for your serving workload. This example uses 1.

get_model.py
import os

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=131_072,
    rbln_tensor_parallel_size=8,
    rbln_kvcache_partition_len=16_384,
)

# Save compiled results to disk
model.save_pretrained(os.path.basename(model_id))
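
To confirm the saved artifacts load correctly before wiring up Ray Serve, the compiled model can be reloaded from disk. This is a minimal sketch: the directory name matches os.path.basename(model_id) used above, and export=False tells optimum-rbln to load the precompiled model rather than recompile.

from optimum.rbln import RBLNLlamaForCausalLM

# Reload the precompiled model from the directory written by save_pretrained().
model = RBLNLlamaForCausalLM.from_pretrained("Llama-3.1-8B-Instruct", export=False)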

Deployment Flow

Deployment Overview

  • Step 1. Deployment implementation: Configure Ray to use RBLN NPUs and define the Ray Serve deployment that loads the compiled model, initializes the runtime, and exposes an endpoint.
  • Step 2. Execution: Launch the deployment with the Ray Serve CLI (serve run), optionally configuring application names, device sets, or remote Ray clusters.
  • Step 3. Inference request: Send an HTTP request to the Serve endpoint and inspect the response to validate the deployment.

The sections below walk through these steps in order.

1.1 Resource Allocation

Ray exposes custom accelerators through the resources argument, so each task or deployment can request exactly the hardware it needs.

The Actor below shows how to request an RBLN resource with @ray.remote(resources={"RBLN": 8}); increase the value whenever your deployment needs more cards. The companion RBLNActor helper retrieves the assigned device ID and passes it to the Serve deployment. See RBLN NPUs with Ray for additional background.

@ray.remote(resources={"RBLN": 8})
class RBLNActor:
    def getDeviceId(self):
        return ray.get_runtime_context().get_accelerator_ids()["RBLN"]
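
For a quick sanity check, you can instantiate the actor and print the device IDs Ray assigned. This is a minimal sketch; it assumes ray.init(resources={"RBLN": 8}) has already been called, as in the deployment file below.

import ray

# Standalone check: create the actor and read back the RBLN device IDs it was given.
actor = RBLNActor.remote()
device_ids = ray.get(actor.getDeviceId.remote())
print(device_ids)  # a list of RBLN device ID strings, e.g. ["0", "1", ...]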

1.2 Deployment Definition

Ray Serve deployments are defined by annotating a class or function with @serve.deployment. This decorator registers the class as a Ray Serve service endpoint, allowing Ray Serve to manage the lifecycle (deployment, scaling, updates).
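
The sketch below shows the bare shape of this pattern before the full file: the decorator registers the class, serve.ingress attaches the FastAPI routes, and .bind() produces the application object that serve run deploys. The Echo class and /ping route are illustrative only.

from fastapi import FastAPI
from ray import serve

app = FastAPI()

@serve.deployment(num_replicas=1)
@serve.ingress(app)
class Echo:
    @app.get("/ping")
    def ping(self) -> str:
        return "pong"

# The bound application is what `serve run module:attribute` deploys.
deployment_app = Echo.bind()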

llama3_1-8b.py
import json
import os
from unittest.mock import MagicMock

import ray
from fastapi import FastAPI, HTTPException
from ray import serve
from starlette.requests import Request
from starlette.responses import StreamingResponse
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels

app = FastAPI()

ray.init(resources={"RBLN": 8})


@ray.remote(resources={"RBLN": 8})
class RBLNActor:
    def getDeviceId(self):
        return ray.get_runtime_context().get_accelerator_ids()["RBLN"]


@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 16})
@serve.ingress(app)
class Llama3_1__8B:
    def __init__(self, rbln_actor: RBLNActor):
        """
        Initialize actor.
        :return:
        """
        self.engine = None
        self.rbln_actor = rbln_actor
        self.model_name = "Llama-3.1-8B-Instruct"
        self.raw_request = None
        self.vllm_engine = None
        self.openai_serving_models = None
        self.completion_service = None
        self.chat_completion_service = None
        self.ids = ray.get(rbln_actor.getDeviceId.remote())

        self.os_environment_vars()
        self.initialize()

    def os_environment_vars(self):
        """
        Redefine the environment variables to be passed to the RBLN runtime and vLLM
        :return:
        """
        if not self.ids:
            # No RBLN devices were assigned; drop any stale value rather than
            # joining an empty or missing ID list below.
            os.environ.pop("RBLN_DEVICES", None)
        else:
            os.environ["RBLN_DEVICES"] = ",".join(self.ids)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    def initialize(self):
        """
        Initialize vLLM engine and prepare request handlers
        :return:
        """
        vllm_engine_args = AsyncEngineArgs(
            model=self.model_name
        )
        self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_args)

        self.openai_serving_models = OpenAIServingModels(
            engine_client=self.vllm_engine,
            model_config=self.vllm_engine.vllm_config.model_config,
            base_model_paths=[
                BaseModelPath(name=self.model_name, model_path=os.getcwd())
            ],
        )

        self.completion_service = OpenAIServingCompletion(
            self.vllm_engine,
            self.vllm_engine.vllm_config.model_config,
            self.openai_serving_models,
            request_logger=None,
        )

        self.chat_completion_service = OpenAIServingChat(
            self.vllm_engine,
            self.vllm_engine.vllm_config.model_config,
            self.openai_serving_models,
            "assistant",
            request_logger=None,
            chat_template_content_format="auto",
            chat_template=None,
        )

        # vLLM's OpenAI-serving handlers expect a Starlette Request; since
        # requests arrive through the FastAPI routes instead, a MagicMock with
        # an async is_disconnected() stub stands in as the raw request.
        async def is_disconnected():
            return False

        self.raw_request = MagicMock()
        self.raw_request.headers = {}
        self.raw_request.is_disconnected = is_disconnected

    @app.post("/v1/chat/completions")
    async def chat_completion(self, http_request: Request):
        """
        Handle chat completion request.
        :param http_request: The HTTP request object
        :return: The chat completion response
        """
        try:
            json_string: dict = await http_request.json()
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format request")
        request: ChatCompletionRequest = ChatCompletionRequest.model_validate(
            json_string
        )

        g = await self.chat_completion_service.create_chat_completion(
            request, self.raw_request
        )

        if isinstance(g, ErrorResponse):
            return [g.model_dump()]

        if request.stream:

            async def stream_generator():
                async for response in g:
                    yield response

            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            return [g.model_dump()]

    @app.post("/v1/completions")
    async def completion(self, http_request: Request):
        """
        Handle completion request.
        :param http_request: The HTTP request object
        :return: The completion response
        """
        try:
            json_string: dict = await http_request.json()
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format request")
        request: CompletionRequest = CompletionRequest.model_validate(json_string)

        g = await self.completion_service.create_completion(request, self.raw_request)

        if isinstance(g, ErrorResponse):
            return [g.model_dump()]

        if request.stream:

            async def stream_generator():
                async for response in g:
                    yield response

            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            return [g.model_dump()]


rbln_actor = RBLNActor.remote()
app = Llama3_1__8B.bind(rbln_actor)

2. Execution

Use the Ray Serve CLI (serve run) to launch the application. The argument uses the module:application format, where module is the Python filename (without .py) and application is the exported Serve entry point.

In this sample, llama3_1-8b.py defines app, so the following command starts the deployment. Add extra options when connecting to a remote Ray cluster or when pinning RBLN_DEVICES to specific cards.

$ serve run llama3_1-8b:app --name "llama3.1-8b"

Example Output:

Application 'llama3.1-8b' is ready at http://127.0.0.1:8000/.
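
If you need to connect to an existing Ray cluster or pin the deployment to specific cards, the same command accepts extra options and environment variables. A sketch (the cluster address is a placeholder; --address is only needed when the cluster does not run locally):

$ RBLN_DEVICES=0,1,2,3,4,5,6,7 serve run llama3_1-8b:app --name "llama3.1-8b" --address <your-ray-cluster-address>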

3.1 Inference Request Example (Completion API)

echo '{
"model": "Llama-3.1-8B-Instruct",
"prompt": "A robot may not injure a human being",
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- \
    http://localhost:8000/v1/completions | jq .

Example Output:

[
{
    "id": "cmpl-790b3169a1c1428a8912140848b42753",
    "object": "text_completion",
    "created": 1762993969,
    "model": "Llama-3.1-8B-Instruct",
    "choices": [
    {
        "index": 0,
        "text": " or, through inaction, allow a human being to come to harm.\nA",
        "logprobs": null,
        "finish_reason": "length",
        "stop_reason": null,
        "token_ids": null,
        "prompt_logprobs": null,
        "prompt_token_ids": null
    }
    ],
    "service_tier": null,
    "system_fingerprint": null,
    "usage": {
    "prompt_tokens": 10,
    "total_tokens": 26,
    "completion_tokens": 16,
    "prompt_tokens_details": null
    },
    "kv_transfer_params": null
}
]    
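
The same request can be issued from Python with the requests package installed earlier. A minimal sketch mirroring the curl example above:

import requests

payload = {
    "model": "Llama-3.1-8B-Instruct",
    "prompt": "A robot may not injure a human being",
    "stream": False,
}

# The deployment wraps the completion object in a JSON list (see the output above).
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json()[0]["choices"][0]["text"])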

3.2 Inference Request Example (Chat Completions API)

echo '{
"model": "Llama-3.1-8B-Instruct",
"messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who won the world series in 2020?"},
    {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
    {"role": "user", "content": "Where was it played?"}
],
"temperature": 0.0,
"max_tokens": 50,
"stream": false
}' | curl -sN -H "Content-Type: application/json" -X POST --data-binary @- \
    http://localhost:8000/v1/chat/completions | jq .

Example Output:

[
{
    "id": "chatcmpl-bad227512c58497eb40dcc1287f29b50",
    "object": "chat.completion",
    "created": 1762994293,
    "model": "Llama-3.1-8B-Instruct",
    "choices": [
    {
        "index": 0,
        "message": {
        "role": "assistant",
        "content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was a neutral site due to the COVID-19 pandemic, and the Dodgers played the Tampa Bay Rays in the series.",
        "refusal": null,
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [],
        "reasoning_content": null
        },
        "logprobs": null,
        "finish_reason": "stop",
        "stop_reason": null,
        "token_ids": null
    }
    ],
    "service_tier": null,
    "system_fingerprint": null,
    "usage": {
    "prompt_tokens": 79,
    "total_tokens": 122,
    "completion_tokens": 43,
    "prompt_tokens_details": null
    },
    "prompt_logprobs": null,
    "prompt_token_ids": null,
    "kv_transfer_params": null
}
]
