
Inference with Dynamic Batch Sizes

In real-world serving scenarios, the number of requests arriving at the same time varies: sometimes there is 1 request, sometimes 3, 5, or 7. Always running the decoder at the maximum batch size (e.g., 8) wastes computation when fewer requests are present, because the unused batch slots are padded and still processed. Instead, you can compile multiple decoders with different batch sizes and let the system automatically choose the smallest compiled batch size that can hold the current number of requests.

The rbln_decoder_batch_sizes parameter specifies multiple decoder batch sizes at compile time. At run time, the model selects the most appropriate decoder based on the actual number of requests, improving both throughput and resource utilization. For example, with decoders compiled for batch sizes 8, 4, and 1, 3 requests are handled by the batch-size-4 decoder and 7 requests by the batch-size-8 decoder.
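The selection rule described above can be sketched as a small Python function. This is a minimal sketch, assuming the runtime picks the smallest compiled decoder batch size that is greater than or equal to the number of active requests; select_decoder_batch_size is a hypothetical helper for illustration, not part of the optimum-rbln or vLLM API.

```python
def select_decoder_batch_size(num_requests: int, decoder_batch_sizes: list[int]) -> int:
    """Hypothetical helper: return the smallest compiled decoder batch size
    that can hold `num_requests` sequences (assumed selection rule)."""
    if num_requests < 1:
        raise ValueError("num_requests must be at least 1")
    # Keep only decoders large enough for every active request.
    candidates = [size for size in decoder_batch_sizes if size >= num_requests]
    if not candidates:
        raise ValueError(
            f"{num_requests} requests exceed the largest compiled batch size "
            f"{max(decoder_batch_sizes)}"
        )
    # Among those, the smallest one wastes the fewest padded slots.
    return min(candidates)


decoder_batch_sizes = [8, 4, 1]
for n in (1, 2, 3, 7):
    print(f"{n} request(s) -> decoder batch size {select_decoder_batch_size(n, decoder_batch_sizes)}")
# 1 request(s) -> decoder batch size 1
# 2 request(s) -> decoder batch size 4
# 3 request(s) -> decoder batch size 4
# 7 request(s) -> decoder batch size 8
```

Note that under this rule 2 requests go to the batch-size-4 decoder rather than the batch-size-1 decoder, even though 1 is numerically as close, because a batch-size-1 decoder cannot hold two sequences.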

Similar Optimization Techniques

This approach is similar to other vLLM optimization techniques:

  • CUDA Graph (cudagraph_capture_sizes): pre-captures CUDA graphs for specific batch sizes
  • Inductor compilation (compile_sizes): pre-compiles kernels for specific input sizes

All techniques share the principle of pre-optimizing for expected input sizes to improve dynamic serving performance.

Compile Model with Multiple Decoder Batch Sizes

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Compile with multiple decoder batch sizes
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,                        # To compile the model, this argument must be True
    rbln_batch_size=8,                  # Maximum batch size for prefill
    rbln_max_seq_len=8192,              # Maximum sequence length
    rbln_tensor_parallel_size=4,        # Tensor parallelism
    rbln_decoder_batch_sizes=[8, 4, 1], # Compile decoders for batch sizes 8, 4, and 1
)

# Save compiled results to disk
model.save_pretrained("rbln-dynamic-Llama-3-1-8B-Instruct")

Note

The rbln_decoder_batch_sizes list is automatically sorted in descending order. All values must be less than or equal to rbln_batch_size. If rbln_batch_size itself is not included in the list, it is added automatically.
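The rules in the note can be expressed as a short validation step. This is a minimal sketch of the documented behavior (sort in descending order, reject values above rbln_batch_size, add rbln_batch_size if it is missing); normalize_decoder_batch_sizes is a hypothetical name, and the actual optimum-rbln implementation may differ, for example in how it handles duplicates or non-positive values.

```python
def normalize_decoder_batch_sizes(decoder_batch_sizes: list[int], batch_size: int) -> list[int]:
    """Hypothetical helper mirroring the documented rules for
    rbln_decoder_batch_sizes (not the library's actual code)."""
    # Every decoder batch size must fit within the maximum batch size.
    too_large = [size for size in decoder_batch_sizes if size > batch_size]
    if too_large:
        raise ValueError(
            f"decoder batch sizes {too_large} exceed rbln_batch_size={batch_size}"
        )
    sizes = list(decoder_batch_sizes)
    # The maximum batch size is added automatically if it was omitted.
    if batch_size not in sizes:
        sizes.append(batch_size)
    # The list ends up sorted in descending order.
    return sorted(sizes, reverse=True)


print(normalize_decoder_batch_sizes([1, 4], batch_size=8))     # [8, 4, 1]
print(normalize_decoder_batch_sizes([8, 4, 1], batch_size=8))  # [8, 4, 1]
```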

Use vLLM API for Efficient Dynamic Batch Inference

This example runs three test cases with different numbers of requests. With decoders compiled for batch sizes [8, 4, 1], small_batch_conversations (1 request) is expected to run on the batch-size-1 decoder and medium_batch_conversations (4 requests) on the batch-size-4 decoder, keeping latency low. large_batch_conversations (8 requests) uses the batch-size-8 decoder to achieve higher throughput and better resource utilization.

small_batch_conversations = [
    [{"role": "user", "content": "What is the first letter of the English alphabet?"}]
]

medium_batch_conversations = [
    [{"role": "user", "content": "Explain quantum computing in simple terms."}],
    [{"role": "user", "content": "What are the benefits of renewable energy?"}],
    [{"role": "user", "content": "Describe the process of photosynthesis."}],
    [{"role": "user", "content": "How does machine learning work?"}],
]

large_batch_conversations = [
    [{"role": "user", "content": "What is the theory of relativity?"}],
    [{"role": "user", "content": "Explain blockchain technology."}],
    [{"role": "user", "content": "Describe climate change effects."}],
    [{"role": "user", "content": "How do neural networks learn?"}],
    [{"role": "user", "content": "What is genetic engineering?"}],
    [{"role": "user", "content": "Explain the water cycle."}],
    [{"role": "user", "content": "How does the internet work?"}],
    [{"role": "user", "content": "What is sustainable development?"}],
]

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Please make sure the engine configurations match the parameters used when compiling.
# This example assumes the model was compiled with rbln_decoder_batch_sizes=[8, 4, 1]
model_id = "rbln-dynamic-Llama-3-1-8B-Instruct"
max_seq_len = 8192
batch_size = 8  # Maximum batch size

llm = LLM(
    model=model_id,
    device="rbln",
    max_num_seqs=batch_size,
    max_num_batched_tokens=max_seq_len,
    max_model_len=max_seq_len,
    block_size=max_seq_len,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

sampling_params = SamplingParams(
    temperature=0.0,
    skip_special_tokens=True,
    stop_token_ids=[tokenizer.eos_token_id],
)

conversations = [
    small_batch_conversations,
    medium_batch_conversations,
    large_batch_conversations,
]

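# Each test case passes a different number of prompts to llm.generate().
for conversation in conversations:
    # Render each conversation into a prompt string using the model's chat template.
    chats = [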
        tokenizer.apply_chat_template(
            conv,
            add_generation_prompt=True,
            tokenize=False,
        ) for conv in conversation
    ]

    outputs = llm.generate(chats, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
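The script above relies on the engine settings matching the compiled model. The sketch below is a minimal pre-flight check, assuming the correspondences used in this example (max_num_seqs with rbln_batch_size, max_model_len with rbln_max_seq_len). The helper find_config_mismatches, the key mapping, and both dictionaries are hypothetical; the values are copied by hand from this page rather than read from the compiled model directory.

```python
# Hypothetical values copied by hand from the compilation step and the engine setup.
compile_config = {
    "rbln_batch_size": 8,
    "rbln_max_seq_len": 8192,
}
engine_config = {
    "max_num_seqs": 8,
    "max_model_len": 8192,
}

# Assumed pairing between vLLM engine arguments and RBLN compile arguments.
ENGINE_TO_COMPILE_KEYS = {
    "max_num_seqs": "rbln_batch_size",
    "max_model_len": "rbln_max_seq_len",
}


def find_config_mismatches(engine: dict, compiled: dict) -> list[str]:
    """Hypothetical helper: describe every engine setting that disagrees
    with the value the model was compiled with."""
    mismatches = []
    for engine_key, compile_key in ENGINE_TO_COMPILE_KEYS.items():
        if engine[engine_key] != compiled[compile_key]:
            mismatches.append(
                f"{engine_key}={engine[engine_key]} but {compile_key}={compiled[compile_key]}"
            )
    return mismatches


problems = find_config_mismatches(engine_config, compile_config)
if problems:
    raise SystemExit("Engine/compile mismatch: " + "; ".join(problems))
print("Engine configuration matches the compiled model.")
```

Running a check like this before constructing the LLM turns a silent misconfiguration into an explicit error message.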

Benefits of Dynamic Batch Compilation

  1. Better Throughput: The system selects the smallest compiled decoder that fits the current number of requests, so small batches are not padded up to the maximum batch size.

  2. Flexible Serving: Handle varying workloads efficiently without being constrained by a single batch size.
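To make the throughput argument concrete, the sketch below counts padded decoder slots for a hypothetical sequence of request counts, comparing a single batch-size-8 decoder with decoders compiled for [8, 4, 1]. It assumes the smallest-sufficient selection rule and treats every padded slot as wasted work; it illustrates the bookkeeping only, is not a measurement of RBLN performance, and the workload numbers are made up.

```python
def select_decoder_batch_size(num_requests: int, decoder_batch_sizes: list[int]) -> int:
    # Assumed rule: smallest compiled batch size that can hold all requests.
    return min(size for size in decoder_batch_sizes if size >= num_requests)


def padded_slots(request_counts: list[int], decoder_batch_sizes: list[int]) -> int:
    """Total number of empty (padded) batch slots across all steps."""
    return sum(
        select_decoder_batch_size(n, decoder_batch_sizes) - n for n in request_counts
    )


# Hypothetical workload: number of active requests at each decoding step.
workload = [1, 3, 5, 7, 8, 2, 1, 4]

print(f"padded slots with [8]: {padded_slots(workload, [8])}")              # 33
print(f"padded slots with [8, 4, 1]: {padded_slots(workload, [8, 4, 1])}")  # 7
```

Whether fewer padded slots translate into higher throughput on real hardware depends on how decoder cost scales with batch size, which this sketch does not model.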