Llama3.1-8B with Flash Attention
Flash Attention enables efficient handling of long contexts in models like Llama3.1-8B by reducing memory usage and improving throughput. This tutorial shows how to serve Llama3.1-8B with Flash Attention using vllm-rbln and Nvidia Triton Inference Server.
The exact commands for compiling the model and initializing the Triton vllm_backend are available in our model zoo.
Note
This tutorial assumes that you are familiar with compiling and running inference using the RBLN SDK. If you are not familiar with the RBLN SDK, refer to the RBLN Optimum tutorials and the API Documentation.
Prerequisites
Compile Llama3.1-8B
You need to compile the Llama3.1-8B model using optimum-rbln.
get_model.py

```python
import os

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,  # Export a PyTorch model to an RBLN model with Optimum
    rbln_batch_size=1,  # Batch size
    rbln_max_seq_len=131_072,  # Maximum sequence length
    rbln_tensor_parallel_size=8,  # Tensor parallelism
    rbln_kvcache_partition_len=16_384,  # Length of KV cache partitions for flash attention
)

# Save compiled results to disk (directory name: "Llama-3.1-8B-Instruct")
model.save_pretrained(os.path.basename(model_id))
```
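To make the relationship between `rbln_max_seq_len` and `rbln_kvcache_partition_len` concrete, the plain-Python arithmetic below shows how the 131,072-token KV cache divides into 16,384-token partitions, and how many of those partitions a shorter context occupies. It assumes the maximum length should be an exact multiple of the partition length; treat that as an assumption to verify against the optimum-rbln documentation. The helper names are hypothetical.

```python
def num_kv_partitions(max_seq_len: int, kvcache_partition_len: int) -> int:
    """Number of partitions the KV cache is split into at the compiled maximum length."""
    # Assumption (not taken from optimum-rbln): max_seq_len is an exact multiple
    # of kvcache_partition_len.
    if max_seq_len % kvcache_partition_len != 0:
        raise ValueError("max_seq_len is expected to be a multiple of kvcache_partition_len")
    return max_seq_len // kvcache_partition_len


def partitions_holding_tokens(context_len: int, kvcache_partition_len: int) -> int:
    """Number of partitions that contain at least one token of a context_len-token sequence."""
    # Integer ceiling division, avoids floating-point rounding
    return (context_len + kvcache_partition_len - 1) // kvcache_partition_len


# Values from get_model.py
print(num_kv_partitions(131_072, 16_384))         # 8
print(partitions_holding_tokens(20_000, 16_384))  # 2
```

With these settings the full context spans 8 partitions, while a 20,000-token conversation fills only the first 2.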
Note
Choose an appropriate batch size for your serving needs. Here, it is set to 1.
Triton Inference Server with vLLM enabled
Nvidia Triton Inference Server provides a vLLM backend.
If you are using Backend.AI, refer to Step 1. If you are using an on-premise server, skip Step 1 and proceed directly to Step 2.
Step 1. Setting Up the Backend.AI Environment
- Start a session via Backend.AI.
- Select Triton Server (ngc-triton) as your environment, with the version 26.01 / vllm / x86_64 / python-py3.
Step 2. Prepare Nvidia Triton vllm_backend and Modify Model Configurations for Llama3.1-8B
A. Clone the Nvidia Triton Inference Server vllm_backend repository:
```bash
$ git clone https://github.com/triton-inference-server/vllm_backend.git -b r26.01
```
B. Place the precompiled Llama-3.1-8B-Instruct directory into the cloned vllm_backend/samples/model_repository/vllm_model/1 directory:
```bash
$ cp -R /PATH/TO/YOUR/Llama-3.1-8B-Instruct \
    /PATH/TO/YOUR/CLONED/vllm_backend/samples/model_repository/vllm_model/1/
```
Your directory should look like the following at this point:
```text
+-- vllm_backend/                          # Main directory for Triton vLLM backend
|   +-- samples/                           # Application example directory
|   |   +-- model_repository/              # Triton model repository
|   |   |   +-- vllm_model/                # Individual model for Triton serving
|   |   |   |   +-- config.pbtxt           # Triton model configuration file
|   |   |   |   +-- 1/                     # Version directory
|   |   |   |   |   +-- model.json         # Model configuration for vLLM serving
|   |   |   |   |   +-- Llama-3.1-8B-Instruct/   # RBLN-compiled model files
|   |   |   |   |   |   +-- decoder.rbln
|   |   |   |   |   |   +-- prefill.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (other model files)
|   |   +-- (other example files)
|   +-- (other backend files)
```
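Before starting the server, it can help to confirm that the files in the tree above are actually in place. The sketch below is a hypothetical helper that only checks for the file names shown in that tree (config.pbtxt, model.json, config.json, prefill.rbln, decoder.rbln); the exact set of files written by optimum-rbln may differ between versions, so adjust the list to match your compiled output.

```python
from pathlib import Path


def missing_repository_files(
    repo: Path,
    model_name: str = "vllm_model",
    version: str = "1",
    compiled_dir: str = "Llama-3.1-8B-Instruct",
) -> list[Path]:
    """Return the expected files (per the directory tree) that do not exist under repo."""
    model_dir = repo / model_name
    compiled = model_dir / version / compiled_dir
    expected = [
        model_dir / "config.pbtxt",
        model_dir / version / "model.json",
        compiled / "config.json",
        compiled / "prefill.rbln",
        compiled / "decoder.rbln",
    ]
    return [path for path in expected if not path.is_file()]


if __name__ == "__main__":
    # Adjust this path to your cloned repository.
    repo = Path("vllm_backend/samples/model_repository")
    missing = missing_repository_files(repo)
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("Model repository layout looks complete.")
```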
Note
- Unlike the vision model backends in other tutorials, the vLLM backend for Nvidia Triton Server does not need a model.py file. All model processing logic is pre-included in the Docker container at backends/vllm/model.py, so you only need model.json for configuration.
- You can either use the default config.pbtxt from the repository or create a new one using the template below. The input and output names and formats must match exactly as shown, since they are required by the vLLM backend (see Step 4. Send a Chat Completion Request).
```protobuf
name: "vllm_model"
backend: "vllm"
input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ 1 ]
  },
  {
    name: "stream"
    data_type: TYPE_BOOL
    dims: [ 1 ]
  }
]
output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]
```
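Because the client in Step 4 addresses tensors by name, a quick check of config.pbtxt can catch a renamed tensor before the server starts. The sketch below is a hypothetical helper that uses regular expressions rather than a real protobuf text-format parser, so it only handles simple layouts like the template above. It checks just `text_input` and `text_output`, the two names the client reads and writes text through.

```python
import re


def tensor_names(pbtxt: str, section: str) -> list[str]:
    """Return the `name` fields inside a top-level `input [...]` or `output [...]` block.

    Regex sketch for simple configs like the template; not a full protobuf parser.
    """
    block = re.search(
        rf"^\s*{section}\s*\[(.*?)^\s*\]", pbtxt, re.MULTILINE | re.DOTALL
    )
    return re.findall(r'name:\s*"([^"]+)"', block.group(1)) if block else []


def check_vllm_config(pbtxt: str) -> list[str]:
    """List problems with the tensor names the vLLM client relies on."""
    problems = []
    if "text_input" not in tensor_names(pbtxt, "input"):
        problems.append('input "text_input" is missing')
    if "text_output" not in tensor_names(pbtxt, "output"):
        problems.append('output "text_output" is missing')
    return problems
```

To use it, read the file with `Path(".../vllm_model/config.pbtxt").read_text()` and print the result of `check_vllm_config`; an empty list means both names are present.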
C. Modify model.json
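Modify vllm_backend/samples/model_repository/vllm_model/1/model.json so that "model" holds the absolute path to the compiled model directory. Use the path as the Triton server sees it: for the on-premise Docker setup in Step 3, where the repository is mounted at /opt/tritonserver/vllm_backend, this is /opt/tritonserver/vllm_backend/samples/model_repository/vllm_model/1/Llama-3.1-8B-Instruct.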
```json
{
    "model": "/ABSOLUTE/PATH/TO/Llama-3.1-8B-Instruct"
}
```
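If you prefer not to type the absolute path by hand, a small script can write model.json with exactly the content shown above. This is a hypothetical helper, not part of vllm_backend. Note that `Path.resolve()` produces the path in the environment where the script runs, so for the on-premise Docker setup, run it inside the container so the path matches what Triton sees.

```python
import json
from pathlib import Path


def write_model_json(version_dir: Path, compiled_model_dir: Path) -> Path:
    """Overwrite <version_dir>/model.json with {"model": <absolute compiled model path>}."""
    config = {"model": str(compiled_model_dir.resolve())}
    target = version_dir / "model.json"
    target.write_text(json.dumps(config, indent=4) + "\n", encoding="utf-8")
    return target


if __name__ == "__main__":
    version_dir = Path("vllm_backend/samples/model_repository/vllm_model/1")
    if version_dir.is_dir():  # only write when run from the expected location
        out = write_model_json(version_dir, version_dir / "Llama-3.1-8B-Instruct")
        print(out.read_text(encoding="utf-8"))
```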
Step 3. Run the Inference Server
We are now ready to run the inference server. If you are using Backend.AI, refer to section A. Backend.AI. Otherwise, proceed to section B. On-premise server.
A. Backend.AI
Before proceeding, install the required dependencies:
```bash
$ pip3 install \
    --extra-index-url https://pypi.rbln.ai/simple \
    rebel-compiler==0.10.2
$ pip install \
    --extra-index-url https://wheels.vllm.ai/0.13.0/cpu \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    vllm-rbln==0.10.2
```
Start the Triton Server:
```bash
$ tritonserver --model-repository /PATH/TO/YOUR/vllm_backend/samples/model_repository
```
The following messages indicate that the server has started successfully:

```text
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
```
B. On-premise server
If you are not using Backend.AI, follow these steps to start the inference server in a Docker container. (Backend.AI users can skip to Step 4.)
To access RBLN NPU devices from the container, the RBLN Container Toolkit must be installed, and the cloned vllm_backend repository must be mounted into the container.
```bash
$ docker run --device rebellions.ai/npu=runtime --privileged --shm-size=1g --ulimit memlock=-1 \
    -v /PATH/TO/YOUR/vllm_backend:/opt/tritonserver/vllm_backend \
    -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 \
    -ti nvcr.io/nvidia/tritonserver:26.01-vllm-python-py3
```
Install the required dependencies in the container:
```bash
$ pip3 install \
    --extra-index-url https://pypi.rbln.ai/simple \
    rebel-compiler==0.10.2
$ pip install \
    --extra-index-url https://wheels.vllm.ai/0.13.0/cpu \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    vllm-rbln==0.10.2
```
Start the Triton Server in the container:
```bash
$ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository
```
The following messages indicate that the server has started successfully:

```text
Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
```
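Loading a large compiled model can take a while after the services start listening, so scripted clients may want to wait for readiness first. Triton implements the KServe v2 HTTP health endpoints on port 8000: `GET /v2/health/ready` for the server and `GET /v2/models/<model>/ready` for a specific model, both returning HTTP 200 when ready. The polling helper below is a hypothetical sketch that uses only the Python standard library; the timeout and interval defaults are arbitrary.

```python
import http.client
import time
import urllib.request

# Ignore proxy environment variables: health checks usually target a local server.
_OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def _returns_200(url: str) -> bool:
    """True if a GET on url succeeds with HTTP 200; False on HTTP or connection errors."""
    try:
        with _OPENER.open(url, timeout=5) as response:
            return response.status == 200
    except (OSError, http.client.HTTPException):
        # URLError/HTTPError and socket errors are OSError subclasses
        return False


def wait_until_ready(
    base_url: str = "http://localhost:8000",
    model_name: str = "vllm_model",
    timeout_s: float = 600.0,
    interval_s: float = 5.0,
) -> bool:
    """Poll Triton's health endpoints until both the server and the model report ready."""
    endpoints = [
        f"{base_url}/v2/health/ready",
        f"{base_url}/v2/models/{model_name}/ready",
    ]
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if all(_returns_200(url) for url in endpoints):
            return True
        time.sleep(interval_s)
    return False
```

For example, call `wait_until_ready()` before sending the first request in Step 4; it returns False if the model is not ready within ten minutes.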
Step 4. Send a Chat Completion Request
The vLLM backend of Triton Inference Server defines its own model.py, whose input/output signature differs from the model.py in the ResNet50 tutorial. There, the input was named INPUT__0 and the output OUTPUT__0; the vLLM backend instead names its input text_input and its output text_output, so the client must use these names. Refer to the vLLM backend's model.py for more details.
The client code for the vLLM backend is shown below. It requires the tritonclient and grpcio packages.
Note
For correct results, apply a chat template to the prompt. The chat template must be formatted with system, user, and assistant roles. You must also configure the sampling parameters as follows:
```python
sampling_params = {
    "temperature": 0.0,
    "stop": ["[User]", "[System]", "[Assistant]"],  # stop tokens
}
```
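These sampling parameters reach the server as a single UTF-8 JSON string inside a one-element BYTES tensor named `sampling_parameters`, which is how the sample client sends them. The helper below (a hypothetical name) isolates that encoding with NumPy so you can inspect the payload without a running server.

```python
import json

import numpy as np


def encode_sampling_parameters(params: dict) -> np.ndarray:
    """Pack sampling parameters as a 1-element object array holding UTF-8 JSON bytes."""
    return np.array([json.dumps(params).encode("utf-8")], dtype=object)


sampling_params = {
    "temperature": 0.0,
    "stop": ["[User]", "[System]", "[Assistant]"],  # stop tokens
}
tensor = encode_sampling_parameters(sampling_params)
print(tensor.shape, tensor.dtype)  # (1,) object
print(tensor[0])
```

The resulting array is what `InferInput("sampling_parameters", [1], "BYTES").set_data_from_numpy(...)` receives in the client.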
The complete sample client is shown below.
simple_vllm_client.py

```python
import asyncio
import json

import numpy as np
import tritonclient.grpc.aio as grpcclient


# Define a simple chat message class
class ChatMessage:
    def __init__(self, role, content):
        self.role = role
        self.content = content


# Apply a simple chat template to the messages
def apply_chat_template(messages):
    lines = []
    system_msg = ChatMessage(role="system", content="You are a helpful assistant.")
    for msg in [system_msg, *messages, ChatMessage(role="assistant", content="")]:
        lines.append(f"[{msg.role.capitalize()}]\n{msg.content}")
    return "\n".join(lines)


async def try_request():
    # Host and port of the Triton gRPC endpoint; change this to match your server
    url = "localhost:8001"
    client = grpcclient.InferenceServerClient(url=url, verbose=False)
    model_name = "vllm_model"

    def create_request(messages, request_id):
        prompt = apply_chat_template(messages)
        print(f"prompt:\n{prompt}\n---")  # print prompt

        # Prompt text -> "text_input" (BYTES tensor of shape [1])
        text_input = grpcclient.InferInput("text_input", [1], "BYTES")
        text_input.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=object))

        # Enable streaming -> "stream" (BOOL tensor of shape [1])
        stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
        stream_setting.set_data_from_numpy(np.array([True]))

        # Sampling parameters are sent as a JSON string
        sampling_params = {
            "temperature": 0.0,
            "stop": ["[User]", "[System]", "[Assistant]"],  # add stop tokens
        }
        sampling_parameters = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
        sampling_parameters.set_data_from_numpy(
            np.array([json.dumps(sampling_params).encode("utf-8")], dtype=object)
        )

        inputs = [text_input, stream_setting, sampling_parameters]
        outputs = [grpcclient.InferRequestedOutput("text_output")]
        return {
            "model_name": model_name,
            "inputs": inputs,
            "outputs": outputs,
            "request_id": request_id,
        }

    messages = [
        ChatMessage(role="user", content="What is the first letter of English alphabets?")
    ]
    request_id = "req-0"

    async def requests_gen():
        yield create_request(messages, request_id)

    prompt = apply_chat_template(messages)
    is_first_response = True
    try:
        response_stream = client.stream_infer(requests_gen())
        async for result, error in response_stream:
            if error:
                print(f"Error occurred: {error}")
                continue
            for i in result.as_numpy("text_output"):
                decoded = i.decode("utf-8")
                # The first streamed response may echo the prompt; strip it
                if is_first_response:
                    if decoded.startswith(prompt):
                        decoded = decoded[len(prompt):]
                    is_first_response = False
                print(decoded, end="", flush=True)
        print("\n")  # end of stream
    finally:
        await client.close()


asyncio.run(try_request())
```
If the request succeeds, you will see output similar to the following:
```text
The first letter of the English alphabet is 'A'. Would you like to know more about the English alphabet? I can help you with that!
```
If you need to change other sampling parameters (such as temperature, top_p, top_k, or max_tokens), refer to vLLM's Python client.
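As an illustration of adjusting those parameters, the sketch below builds the sampling-parameter dict with a hypothetical helper, keeping the stop strings that match the sample chat template. The keys `temperature`, `top_p`, `top_k`, `max_tokens`, and `stop` are vLLM `SamplingParams` field names; the specific values are arbitrary examples, not recommendations.

```python
import json

# Stop strings that match the simple chat template used by the sample client.
STOP_STRINGS = ["[User]", "[System]", "[Assistant]"]


def build_sampling_params(max_tokens: int = 256, **overrides) -> dict:
    """Return a sampling-parameter dict; overrides must be vLLM SamplingParams field names."""
    params = {"temperature": 0.0, "max_tokens": max_tokens, "stop": list(STOP_STRINGS)}
    params.update(overrides)
    return params


# Example: cap the response length and sample with top-p/top-k filtering
params = build_sampling_params(max_tokens=128, temperature=0.7, top_p=0.9, top_k=50)
payload = json.dumps(params)  # send this string as the "sampling_parameters" input
print(payload)
```

Leaving `temperature` at 0.0 keeps greedy, deterministic decoding; raising it makes responses vary between runs.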