
Llama3.1-8B with Flash Attention

Overview

This tutorial shows how to run a model with Flash Attention on vLLM, using the meta-llama/Llama-3.1-8B-Instruct model as an example.

Flash Attention improves memory efficiency and throughput, which particularly benefits models that handle long contexts. In optimum-rbln, Flash Attention mode is enabled by setting the rbln_kvcache_partition_len parameter at compile time.
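To make the idea of KV cache partitions concrete, the sketch below shows the general Flash Attention technique in plain Python: attention is computed one KV partition at a time with an "online softmax" (a running maximum, running denominator, and running weighted sum), so the full score vector never has to be materialized, yet the result matches ordinary attention. This is a conceptual illustration for a single query vector, not the RBLN kernel; the function names are hypothetical, and partition_len only loosely plays the role that rbln_kvcache_partition_len plays on the NPU.

```python
import math


def full_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) over all positions, then a weighted sum of values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def partitioned_attention(q, keys, values, partition_len):
    """Process the KV cache one partition at a time using an online softmax."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    acc = [0.0] * dim_v       # weighted sum of values, relative to running_max
    for start in range(0, len(keys), partition_len):
        k_part = keys[start:start + partition_len]
        v_part = values[start:start + partition_len]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_part]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_part):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim_v):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, 0.2, 0.3, 0.4]
keys = [[(i * j % 7) / 7 for j in range(4)] for i in range(16)]
values = [[float(i), float(-i)] for i in range(16)]
print(full_attention(q, keys, values))
print(partitioned_attention(q, keys, values, partition_len=4))
```

Both calls print the same vector (up to floating-point rounding). Only one partition's scores are held at a time, which is where the memory savings for long contexts come from.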

Note

Rebellions Scalable Design (RSD) is available on ATOM™+ (RBLN-CA12 and RBLN-CA22) and ATOM™-Max (RBLN-CA25). You can check your RBLN NPU type using the rbln-stat command.

Note

Llama 3.1 is licensed under the LLAMA Community License, Copyright (c) Meta Platforms, Inc. All Rights Reserved.

Setup & Installation

Before you begin, ensure that your system environment is properly configured and that all required packages are installed. This includes the RBLN packages used in this guide:

  • rebel-compiler
  • optimum-rbln
  • vllm-rbln

Note

Please note that rebel-compiler requires an RBLN Portal account.

Note

Please note that the meta-llama/Llama-3.1-8B-Instruct model on HuggingFace has restricted access, so you must request access first. Once access is granted, log in using the huggingface-cli command as shown below:

$ huggingface-cli login

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: *****

Compile Llama3.1-8B

To begin, import the RBLNLlamaForCausalLM class from optimum-rbln. This class's from_pretrained() method downloads the Llama 3.1 model from the HuggingFace Hub and compiles it using the RBLN Compiler. When exporting the model, specify the following parameters:

  • export: Must be True to compile the model.
  • rbln_batch_size: Defines the batch size for compilation.
  • rbln_max_seq_len: Defines the maximum sequence length.
  • rbln_tensor_parallel_size: Defines the number of NPUs to be used for inference.
  • rbln_kvcache_partition_len: Defines the length of each KV cache partition for Flash Attention. rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len and larger than it.
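The constraint on rbln_max_seq_len and rbln_kvcache_partition_len can be checked before starting a long compilation. The helper below is a hypothetical sketch, not part of optimum-rbln; it applies the two rules from the list above and returns how many partitions cover the maximum sequence length.

```python
def check_flash_attention_params(max_seq_len: int, kvcache_partition_len: int) -> int:
    """Validate the Flash Attention constraints and return the partition count.

    Hypothetical helper for illustration only; optimum-rbln performs its own checks.
    """
    if kvcache_partition_len <= 0:
        raise ValueError("rbln_kvcache_partition_len must be positive")
    if max_seq_len % kvcache_partition_len != 0:
        raise ValueError("rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len")
    if max_seq_len <= kvcache_partition_len:
        raise ValueError("rbln_max_seq_len must be larger than rbln_kvcache_partition_len")
    return max_seq_len // kvcache_partition_len


# Values used in this tutorial: 131_072 / 16_384 partitions.
print(check_flash_attention_params(131_072, 16_384))  # 8
```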

After compilation, save the model artifacts to disk using the save_pretrained() method. This will create a directory (e.g., rbln-Llama-3-1-8B-Instruct) containing the compiled model.

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=131_072,
    rbln_tensor_parallel_size=8,
    rbln_kvcache_partition_len=16_384,
)

# Save compiled results to disk
model.save_pretrained("rbln-Llama-3-1-8B-Instruct")

Note

You can select an appropriate batch size based on the model size and the specifications of the NPUs. Since vllm-rbln supports continuous batching, it’s important to configure the batch size carefully to ensure optimal throughput and resource utilization. For information on enabling dynamic batching, see Inference with Dynamic Batch Sizes.
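A rough KV cache size estimate helps explain why batch size and sequence length must be chosen carefully. The sketch below assumes the published Llama 3.1 8B configuration (32 layers, 8 key/value heads, head dimension 128) and 2-byte (16-bit) cache elements, and it ignores any device-specific layout, padding, or runtime overhead, so treat the numbers as a lower-bound ballpark rather than the actual RBLN memory footprint.

```python
def kv_cache_bytes(num_tokens, num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Approximate KV cache size for one sequence (assumes 16-bit elements, no overhead)."""
    # The leading factor of 2 accounts for one key tensor and one value tensor per layer.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * num_tokens


per_token = kv_cache_bytes(1)
per_seq = kv_cache_bytes(131_072)
print(per_token)            # 131072 bytes (128 KiB) per token
print(per_seq / 2**30)      # 16.0 GiB for one full 131_072-token sequence
print(per_seq / 8 / 2**30)  # 2.0 GiB per NPU, assuming an even split over 8 NPUs
```

Under these assumptions the estimate grows linearly with the batch size, so each additional sequence at rbln_max_seq_len=131_072 adds roughly another 16 GiB of KV cache.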

Use vLLM API for Inference

After compiling, you can use the model with vLLM APIs:

Note

For Flash Attention, block_size must match the rbln_kvcache_partition_len value used at compile time.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Please make sure the engine configurations match the parameters used when compiling.
model_id = "rbln-Llama-3-1-8B-Instruct"
max_seq_len = 131_072
batch_size = 1
block_size = 16_384  # Must match `rbln_kvcache_partition_len` for Flash Attention.

llm = LLM(
    model=model_id,
    device="rbln",
    max_num_seqs=batch_size,
    max_num_batched_tokens=max_seq_len,
    max_model_len=max_seq_len,
    block_size=block_size,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

sampling_params = SamplingParams(
    temperature=0.0,
    skip_special_tokens=True,
    stop_token_ids=[tokenizer.eos_token_id],
)

conversation = [
    {
        "role": "user",
        "content": "What is the first letter of the English alphabet?"
    }
]

chat = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False
)

outputs = llm.generate(chat, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
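Because the engine configuration must match the compile-time parameters, one option is to keep those parameters in a single place and derive the LLM(...) arguments from them. The sketch below is a hypothetical pattern (the COMPILE_PARAMS dictionary and vllm_engine_kwargs function are illustrative names, not part of optimum-rbln or vLLM); it reproduces the mapping used in the example above.

```python
# Hypothetical single source of truth for the values passed at compile time.
COMPILE_PARAMS = {
    "rbln_batch_size": 1,
    "rbln_max_seq_len": 131_072,
    "rbln_tensor_parallel_size": 8,
    "rbln_kvcache_partition_len": 16_384,
}


def vllm_engine_kwargs(compile_params):
    """Map compile-time parameters to the LLM(...) arguments used in this tutorial."""
    return {
        "max_num_seqs": compile_params["rbln_batch_size"],
        "max_num_batched_tokens": compile_params["rbln_max_seq_len"],
        "max_model_len": compile_params["rbln_max_seq_len"],
        # For Flash Attention, block_size must equal the KV cache partition length.
        "block_size": compile_params["rbln_kvcache_partition_len"],
    }


print(vllm_engine_kwargs(COMPILE_PARAMS))
```

With this pattern, the same dictionary could be unpacked into RBLNLlamaForCausalLM.from_pretrained(model_id, export=True, **COMPILE_PARAMS) and the engine created with LLM(model="rbln-Llama-3-1-8B-Instruct", device="rbln", **vllm_engine_kwargs(COMPILE_PARAMS)), so the two configurations cannot drift apart.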

More vLLM API usage examples, including encoder-decoder and multi-modal models, are available in the RBLN Model Zoo.

Please refer to the vLLM Docs for more information on the vLLM API.