
vLLM Native API

With vllm-rbln, you can serve large language models (LLMs) through the vLLM API. In this tutorial, you will learn how to run inference with the Llama3-8B and Llama3.1-8B models through the vLLM API, using Eager Attention and Flash Attention, respectively.

How to install

First, make sure you have the latest versions of the required packages, including rebel-compiler, optimum-rbln, and vllm-rbln. Access to Rebellions' private PyPI server is required; please refer to the Installation Guide for details. You can find the latest package versions in the Release Note.

$ pip3 install --extra-index-url https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.3" "optimum-rbln>=0.7.3.post2" "vllm-rbln>=0.7.3"
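
A quick, optional way to confirm the installed versions is to query the package metadata (this check is not part of the original guide; importlib.metadata ships with Python 3.8+):

# Optional check: print the installed versions of the required packages.
from importlib.metadata import version

for pkg in ("rebel-compiler", "optimum-rbln", "vllm-rbln"):
    print(pkg, version(pkg))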

Standard Model Example: Llama3-8B

Step 1: Compile Llama3-8B

You need to compile the Llama3-8B model using optimum-rbln.

from optimum.rbln import RBLNLlamaForCausalLM
import os

# Export the Hugging Face PyTorch Llama3 model to an RBLN compiled model
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,                  # To compile the model, this argument must be True
    rbln_max_seq_len=8192,        # Maximum sequence length
    rbln_tensor_parallel_size=4,  # Number of ATOM+ devices for Rebellions Scalable Design (RSD)
    rbln_batch_size=4,            # batch_size > 1 is recommended for continuous batching
)

compiled_model.save_pretrained(os.path.basename(model_id))

Note

You can choose an appropriate batch size for your serving needs. Here, it is set to 4.
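
If you need the compiled model again later, you can reload the saved artifacts instead of recompiling. This is a minimal sketch assuming the artifacts were saved to the ./Meta-Llama-3-8B-Instruct directory by save_pretrained() above; passing export=False loads the existing compiled model.

from optimum.rbln import RBLNLlamaForCausalLM

# Reload the compiled model from the local directory saved above.
# export=False loads the existing RBLN artifacts instead of triggering compilation.
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct",  # directory produced by save_pretrained()
    export=False,
)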

Step 2: Use the vLLM API for Inference

You can use the compiled model with the vLLM API. The following code shows how to initialize the vLLM engine with the compiled model and run inference with the engine.

vllm_api_example_llama3_8B.py
import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

# Please make sure the engine configurations match the parameters used when compiling.
model_id = "Meta-Llama-3-8B-Instruct"
max_seq_len = 8192
batch_size = 4

engine_args = AsyncEngineArgs(
  model=model_id,
  device="rbln",
  max_num_seqs=batch_size,
  max_num_batched_tokens=max_seq_len,
  max_model_len=max_seq_len,
  block_size=max_seq_len,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Llama3 chat models emit <|eot_id|> to mark the end of a turn; include it as a stop token along with EOS.
def stop_tokens():
  eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
  if eot_id is not None:
    return [tokenizer.eos_token_id, eot_id]
  else:
    return [tokenizer.eos_token_id]

sampling_params = SamplingParams(
  temperature=0.0,
  skip_special_tokens=True,
  stop_token_ids=stop_tokens(),
)


# Runs a single inference for an example
async def run_single(chat, request_id):
  results_generator = engine.generate(chat, sampling_params, request_id=request_id)
  final_result = None
  async for result in results_generator:
    # You can use the intermediate `result` here, if needed.
    final_result = result
  return final_result


conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
assert len(result.outputs) > 0, "Invalid output."
print(result.outputs[0].text)


# Runs multiple inferences in parallel
async def run_multi(chats):
  tasks = [asyncio.create_task(run_single(chat, str(i))) for (i, chat) in enumerate(chats)]
  return [await task for task in tasks]

conversations = [
  [{"role": "user", "content": "What is the first letter of English alphabets?"}],
  [{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
  tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
  for conversation in conversations
]
results = asyncio.run(run_multi(chats))
for result in results:
  assert len(result.outputs) > 0, "Invalid output."
  print(result.outputs[0].text)

You can find more vLLM API usage examples for encoder-decoder and multi-modal models in the RBLN Model Zoo.

Please refer to the vLLM Docs for more information on the vLLM API.

Advanced Example: Llama3.1-8B with Flash Attention

Flash Attention enables efficient handling of long contexts in models like Llama3.1-8B by reducing memory usage and improving throughput. With optimum-rbln, Flash Attention is enabled by passing the rbln_kvcache_partition_len parameter at compile time.
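
As a rough illustration of how the two parameters relate (an illustration only, using the values from Step 1 below and assuming rbln_max_seq_len is a multiple of the partition length):

# Illustrative only: with the compile parameters used in Step 1 below,
# each sequence's KV cache is split into fixed-size partitions.
max_seq_len = 131_072    # rbln_max_seq_len
partition_len = 16_384   # rbln_kvcache_partition_len
print(max_seq_len // partition_len, "KV cache partitions per sequence")  # -> 8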

Step 1: Compile Llama3.1-8B

from optimum.rbln import RBLNLlamaForCausalLM
import os

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,                        # To compile the model, this argument must be True
    rbln_batch_size=1,                  # Batch size
    rbln_max_seq_len=131_072,           # Maximum sequence length
    rbln_tensor_parallel_size=8,        # Tensor parallelism
    rbln_kvcache_partition_len=16_384,  # Length of KV cache partitions for Flash Attention
)

# Save compiled results to disk
model.save_pretrained(os.path.basename(model_id))

Note

You can choose an appropriate batch size for your serving needs. Here, it is set to 1.

Step 2: Use the vLLM API for Inference

After compiling, you can use the model with the vLLM API:

Note

For Flash Attention, block_size must match rbln_kvcache_partition_len.

vllm_api_example_llama3_1_8B.py
import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


# Please make sure the engine configurations match the parameters used when compiling.
model_id = "Llama-3.1-8B-Instruct"
max_seq_len = 131_072
batch_size = 1
block_size = 16_384  # Must match `rbln_kvcache_partition_len` for Flash Attention.


engine_args = AsyncEngineArgs(
  model=model_id,
  device="rbln",
  max_num_seqs=batch_size,
  max_num_batched_tokens=max_seq_len,
  max_model_len=max_seq_len,
  block_size=block_size,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Llama3 chat models emit <|eot_id|> to mark the end of a turn; include it as a stop token along with EOS.
def stop_tokens():
  eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
  if eot_id is not None:
    return [tokenizer.eos_token_id, eot_id]
  else:
    return [tokenizer.eos_token_id]

sampling_params = SamplingParams(
  temperature=0.0,
  skip_special_tokens=True,
  stop_token_ids=stop_tokens(),
)


# Runs a single inference for an example
async def run_single(chat, request_id):
  results_generator = engine.generate(chat, sampling_params, request_id=request_id)
  final_result = None
  async for result in results_generator:
    # You can use the intermediate `result` here, if needed.
    final_result = result
  return final_result


conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
assert len(result.outputs) > 0, "Invalid output."
print(result.outputs[0].text)


# Runs multiple inferences in parallel
async def run_multi(chats):
  tasks = [asyncio.create_task(run_single(chat, str(i))) for (i, chat) in enumerate(chats)]
  return [await task for task in tasks]

conversations = [
  [{"role": "user", "content": "What is the first letter of English alphabets?"}],
  [{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
  tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
  for conversation in conversations
]
results = asyncio.run(run_multi(chats))
for result in results:
  assert len(result.outputs) > 0, "Invalid output."
  print(result.outputs[0].text)

Please refer to the vLLM Docs for more information on the vLLM API.