
Llama2-7B with Continuous Batching

Serving Large Language Models (LLMs) with maximum hardware utilization requires continuous batching, a popular serving optimization technique.

This tutorial guides you through implementing continuous batching with vllm-rbln to reduce LLM serving costs.

You can find the actual commands required to compile the model and initialize the Triton vllm_backend in our model zoo.

Prerequisites

Note

The vllm-rbln package does not depend on the vllm package, so installing both may cause conflicts. If you installed the vllm package after vllm-rbln, reinstall the vllm-rbln package to ensure proper functionality.
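
If you are unsure which packages are present in your environment, you can list the relevant distributions before continuing. The snippet below is a minimal sketch using Python's importlib.metadata; it is not part of the official tooling.

import importlib.metadata as metadata

# List which of the two distributions are installed; if both appear,
# reinstall vllm-rbln last so it takes precedence.
for dist in ("vllm-rbln", "vllm"):
    try:
        print(f"{dist}: {metadata.version(dist)}")
    except metadata.PackageNotFoundError:
        print(f"{dist}: not installed")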

Compile Llama2-7B

You need to compile the Llama2-7B model using optimum-rbln.

from optimum.rbln import RBLNLlamaForCausalLM

# Export the Hugging Face PyTorch Llama2 model to an RBLN compiled model
model_id = "meta-llama/Llama-2-7b-chat-hf"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_max_seq_len=4096,
    rbln_tensor_parallel_size=4,  # number of ATOM+ for Rebellions Scalable Design (RSD)
    rbln_batch_size=4,            # batch_size > 1 is recommended for continuous batching
)

compiled_model.save_pretrained("rbln-Llama-2-7b-chat-hf")

Choose an appropriate batch size for your serving needs. Here, it is set to 4.
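
Before setting up the serving stack, you can optionally reload the compiled artifacts to confirm the export succeeded. The following is a minimal sketch that assumes an RBLN NPU is available on the host; export=False loads the precompiled model instead of recompiling it.

from optimum.rbln import RBLNLlamaForCausalLM

# Reload the precompiled model from the directory saved above.
# export=False tells optimum-rbln to load the compiled artifacts
# rather than compile the model again.
model = RBLNLlamaForCausalLM.from_pretrained(
    "rbln-Llama-2-7b-chat-hf",
    export=False,
)
print("Loaded compiled model:", type(model).__name__)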

Triton Inference Server with vLLM enabled

vLLM provides the backend for the Triton Inference Server.

If you are using Backend.AI, refer to Step 1. If you are using an on-premise server, skip Step 1 and proceed directly to Step 2.

Step 1. Setting Up the Backend.AI Environment

  1. Start a session via Backend.AI.
  2. Select Triton Server (ngc-triton) as your environment and choose the 24.12 / vllm / x86_64 / python-py3 version.

Step 2. Prepare Nvidia Triton vllm_backend and Modify Model Configurations for Llama2-7B

A. Clone the Nvidia Triton Inference Server vllm_backend repository:

$ git clone https://github.com/triton-inference-server/vllm_backend.git -b r24.12

Note

Nvidia Triton Inference Server's vLLM backend ships with its own model.py, so a separate user-defined model.py is not required.

B. Place the precompiled rbln-Llama-2-7b-chat-hf directory into the cloned vllm_backend/samples/model_repository/vllm_model/1 directory:

$ cp -R /PATH/TO/YOUR/rbln-Llama-2-7b-chat-hf /PATH/TO/YOUR/CLONED/vllm_backend/samples/model_repository/vllm_model/1

Your directory should look like the following at this point:

+-- vllm_backend/                    # Main directory for Triton vLLM backend
|   +-- samples/                     # Application example directory
|   |   +-- model_repository/        # Triton model repository (model_repositories)
|   |   |   +-- vllm_model/          # Individual model for Triton serving
|   |   |   |   +-- config.pbtxt     # Triton model configuration file
|   |   |   |   +-- 1/               # Version directory
|   |   |   |   |   +-- model.json   # Model configuration for vLLM serving
|   |   |   |   |   +-- rbln-Llama-2-7b-chat-hf/  # rbln compiled model files
|   |   |   |   |   |   +-- compiled_model.rbln
|   |   |   |   |   |   +-- config.json
|   |   |   |   |   |   +-- (other model files)
|   |   +-- (other example files)
|   +-- (other backend files)

Note

  1. Unlike other vision model backends, the vLLM backend for the Nvidia Triton Server doesn't need a user-provided model.py file. All model processing logic is pre-included in the Docker container at backends/vllm/model.py, so you only need model.json for configuration.
  2. You can either use the default config.pbtxt from the repository or create a new one using the template below. Note that the input and output formats must match exactly as shown, since they are required by the vLLM backend (see Step 4. Requesting Inference via gRPC API).
    name: "vllm_model"
    backend: "vllm"

    input [
        {
            name: "text_input"
            data_type: TYPE_STRING
            dims: [ 1 ]
        },
        {
            name: "stream"
            data_type: TYPE_BOOL
            dims: [ 1 ]
        }
    ]

    output [
        {
            name: "text_output"
            data_type: TYPE_STRING
            dims: [ 1 ]
        }
    ]

    instance_group [
        {
            count: 1
            kind: KIND_MODEL
        }
    ]

C. Modify model.json

Modify vllm_backend/samples/model_repository/vllm_model/1/model.json.

{
    "model": "/ABSOLUTE/PATH/TO/rbln-Llama-2-7b-chat-hf",
    "device": "rbln",
    "max_num_seqs": 4,
    "max_num_batched_tokens": 4096,
    "max_model_len": 4096,
    "block_size": 4096
}
  • model: Absolute path to the compiled model.
  • device: Device type for vLLM execution. Set this to rbln.
  • max_num_seqs: Maximum number of sequences per iteration. This MUST match the compiled batch size (rbln_batch_size).
  • When targeting an RBLN device, max_model_len, block_size, and max_num_batched_tokens should all be set to the same value as the maximum sequence length (rbln_max_seq_len). A script that generates this file from the compile-time values is sketched below.
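
Because max_num_seqs must track the compiled batch size and the remaining three fields must track the compiled maximum sequence length, it can help to generate model.json from the same values you used at compile time. The following is a minimal sketch using the paths and values from this tutorial; adjust them to your environment.

import json
from pathlib import Path

# Values used at compile time in this tutorial
batch_size = 4        # rbln_batch_size
max_seq_len = 4096    # rbln_max_seq_len

model_json = {
    "model": str(Path("rbln-Llama-2-7b-chat-hf").resolve()),  # absolute path to the compiled model
    "device": "rbln",
    "max_num_seqs": batch_size,            # must match the compiled batch size
    "max_num_batched_tokens": max_seq_len,
    "max_model_len": max_seq_len,
    "block_size": max_seq_len,
}

out_path = Path("vllm_backend/samples/model_repository/vllm_model/1/model.json")
out_path.write_text(json.dumps(model_json, indent=4))
print("Wrote", out_path)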

Step 3. Run the Inference Server

We are now ready to run the inference server. If you are using Backend.AI, please refer to the A. Backend.AI section. If you are not a Backend.AI user, proceed to the B. On-premise server section.

A. Backend.AI

Before proceeding, install the required dependencies:

$ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.1" "optimum-rbln>=0.2.0" "vllm-rbln>=0.2.0"

Start the Triton Server:

$ tritonserver --model-repository PATH/TO/YOUR/vllm_backend/samples/model_repository

You will see the following messages that indicate successful initiation of the server:

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

B. On-premise server

If you are not using Backend.AI, follow these steps to start the inference server inside a Docker container. (Backend.AI users can skip this section and proceed to Step 4.)

To access the RBLN NPU devices, the inference server container must be run in privileged mode. Add a mount option for the cloned vllm_backend repository as below:

sudo docker run --privileged --shm-size=1g --ulimit memlock=-1 \
   -v /PATH/TO/YOUR/vllm_backend:/opt/tritonserver/vllm_backend \
   -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti rebellions/tritonserver:24.12-vllm-python-py3

Install the required dependencies inside the container:

$ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.1" "optimum-rbln>=0.2.0" "vllm-rbln>=0.2.0"

Start the Triton Server inside the container:

$ tritonserver --model-repository /opt/tritonserver/vllm_backend/samples/model_repository

You will see the following messages indicating successful initiation of the server:

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002
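
Before sending inference requests, you can optionally check that the server and the model are ready. The snippet below is a minimal sketch using the tritonclient HTTP API and assumes the server is reachable at localhost:8000.

import tritonclient.http as httpclient

# Query the health endpoints of the Triton server started above.
client = httpclient.InferenceServerClient(url="localhost:8000")
print("Server ready:", client.is_server_ready())
print("Model ready:", client.is_model_ready("vllm_model"))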

Step 4. Requesting Inference via gRPC API

The vLLM backend provides its own model.py, whereas we defined a custom model.py in the ResNet50 tutorial. There, the input parameter was named INPUT__0 and the output OUTPUT__0; the vLLM backend instead names its input text_input and its output text_output, so the client must be adapted accordingly. Please refer to the vLLM backend's model.py for more details.

The following shows the client code for the vLLM backend. The client requires the tritonclient and grpcio packages.

simple_vllm_client.py
import asyncio
import numpy as np
import tritonclient.grpc.aio as grpcclient

async def try_request():
  url = "<host and port number of the triton inference server>"  # e.g. "localhost:8001"
  client = grpcclient.InferenceServerClient(url=url, verbose=False)

  model_name = "vllm_model"

  def create_request(prompt, request_id):
    input = grpcclient.InferInput("text_input", [1], "BYTES")
    prompt_data = np.array([prompt.encode("utf-8")])
    input.set_data_from_numpy(prompt_data)

    stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
    stream_setting.set_data_from_numpy(np.array([True]))

    inputs = [input, stream_setting]

    output = grpcclient.InferRequestedOutput("text_output")
    outputs = [output]

    return {
      "model_name": model_name,
      "inputs": inputs,
      "outputs": outputs,
      "request_id": request_id
    }

  prompt = "What is the first letter of English alphabets?"

  async def requests_gen():
    yield create_request(prompt, "req-0")

  response_stream = client.stream_infer(requests_gen())

  async for response in response_stream:
    result, error = response
    if error:
      print(f"Error occurred: {error}")
    else:
      output = result.as_numpy("text_output")
      for i in output:
          decoded = i.decode("utf-8")
          if decoded.startswith(prompt):
              decoded = decoded[len(prompt):]
          print(decoded, end="", flush=True)

asyncio.run(try_request())

If you need to change other sampling parameters (such as temperature, top_p, top_k, max_tokens, early_stopping, ...), please refer to vLLM's python client.
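
The upstream vLLM backend's sample client passes these sampling parameters as a JSON string through an optional sampling_parameters input. The sketch below adapts create_request from the client above under that assumption; the parameter values shown are illustrative.

import json
import numpy as np
import tritonclient.grpc.aio as grpcclient

def create_request_with_sampling(prompt, request_id, model_name="vllm_model"):
    # Prompt and streaming inputs, same as in simple_vllm_client.py
    text_input = grpcclient.InferInput("text_input", [1], "BYTES")
    text_input.set_data_from_numpy(np.array([prompt.encode("utf-8")]))

    stream_setting = grpcclient.InferInput("stream", [1], "BOOL")
    stream_setting.set_data_from_numpy(np.array([True]))

    # Sampling parameters are sent as a single JSON-encoded string
    # (assumed optional input of the vLLM backend).
    sampling = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 128}
    sampling_input = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
    sampling_input.set_data_from_numpy(np.array([json.dumps(sampling).encode("utf-8")]))

    return {
        "model_name": model_name,
        "inputs": [text_input, stream_setting, sampling_input],
        "outputs": [grpcclient.InferRequestedOutput("text_output")],
        "request_id": request_id,
    }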