
Llama3.1-8B with Flash Attention

Overview

This tutorial explains how to run a model with Flash Attention on vLLM. For this guide, we will use the meta-llama/Llama-3.1-8B-Instruct model.

Flash Attention improves memory efficiency and throughput, enabling better performance for models handling long contexts. Flash Attention mode is activated by adding the rbln_kvcache_partition_len parameter during compilation.

Setup & Installation

Before you begin, ensure that your system environment is properly configured and that all required packages are installed, including rebel-compiler, optimum-rbln, and vllm-rbln.

Note

rebel-compiler requires an RBLN Portal account.

Note

The meta-llama/Llama-3.1-8B-Instruct model on the HuggingFace Hub has restricted access. Once access is granted, you can log in using the huggingface-cli command as shown below:

$ huggingface-cli login

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: *****
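
Alternatively, you can authenticate programmatically with the huggingface_hub library, for example at the top of your compilation script. The snippet below is a minimal sketch that assumes your access token is stored in the HF_TOKEN environment variable; any other secure token source works as well.

import os
from huggingface_hub import login

# Read the access token from an environment variable (assumed to be set beforehand)
# and authenticate this process against the HuggingFace Hub.
login(token=os.environ["HF_TOKEN"])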

Execution

Model Compilation

To begin, import the RBLNLlamaForCausalLM class from optimum-rbln. This class's from_pretrained() method downloads the Llama 3.1 model from the HuggingFace Hub and compiles it using the RBLN Compiler. When exporting the model, specify the following parameters:

  • export: Must be True to compile the model.

  • rbln_batch_size: Defines the batch size for compilation.

  • rbln_max_seq_len: Defines the maximum sequence length.

  • rbln_tensor_parallel_size: Defines the number of NPUs to be used for inference.

  • rbln_kvcache_partition_len: Defines the KV cache partition length used for Flash Attention. rbln_max_seq_len must be a multiple of rbln_kvcache_partition_len and larger than it (in this tutorial, 131,072 = 8 × 16,384).

After compilation, save the model artifacts to disk using the save_pretrained() method. This will create a directory (e.g., rbln-Llama-3-1-8B-Instruct) containing the compiled model.

Note

Select the batch size based on the model size and NPU specifications. In addition, vllm-rbln supports Dynamic Batching to ensure optimal throughput and resource utilization. See Dynamic Batching for details.

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=131_072,
    rbln_tensor_parallel_size=8,
    rbln_kvcache_partition_len=16_384,
)

# Save compiled results to disk
model.save_pretrained("rbln-Llama-3-1-8B-Instruct")
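
If you want to reuse the compiled artifacts later (for example, on a dedicated inference host), they can be loaded back from the saved directory without recompiling. The snippet below is a minimal sketch that assumes optimum-rbln follows the usual optimum convention of loading an already compiled model with export=False; consult the optimum-rbln documentation for the exact loading options.

from optimum.rbln import RBLNLlamaForCausalLM

# Load the previously compiled model from disk instead of recompiling.
# `export=False` is assumed here, following the usual optimum convention.
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    "rbln-Llama-3-1-8B-Instruct",
    export=False,
)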

Inference using vLLM

You can use the compiled model with vLLM. The example below shows how to set up the vLLM engine using a compiled model and run inference.

Note

For Flash Attention, block_size should match rbln_kvcache_partition_len.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Make sure the engine configuration matches the parameters used at compile time.
model_id = "rbln-Llama-3-1-8B-Instruct"
max_seq_len = 131_072
batch_size = 1
block_size = 16_384  # Must match `rbln_kvcache_partition_len` for Flash Attention.

llm = LLM(
    model=model_id,
    device="rbln",
    max_num_seqs=batch_size,
    max_num_batched_tokens=max_seq_len,
    max_model_len=max_seq_len,
    block_size=block_size,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

sampling_params = SamplingParams(
    temperature=0.0,
    skip_special_tokens=True,
    stop_token_ids=[tokenizer.eos_token_id],
    max_tokens=100
)

conversation = [
    {
        "role": "user",
        "content": "Who are you?"
    }
]

chat = tokenizer.apply_chat_template(
    conversation, 
    add_generation_prompt=True,
    tokenize=False
)

outputs = llm.generate(chat, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(generated_text)

Example Output:

I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."
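
Since Flash Attention is aimed at long-context workloads, the same engine can be reused for prompts that approach the compiled rbln_max_seq_len. The snippet below is a minimal sketch, not part of the official example: it assumes the llm, tokenizer, sampling_params, and max_seq_len objects from the code above are still in scope, and my_long_document.txt is a hypothetical file containing your own long input.

# Hypothetical input file; replace with your own long text source.
long_document = open("my_long_document.txt").read()

conversation = [
    {
        "role": "user",
        "content": f"Summarize the following document:\n\n{long_document}"
    }
]

chat = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False
)

# Keep the prompt plus the generated tokens within the compiled maximum sequence length.
prompt_len = len(tokenizer(chat).input_ids)
assert prompt_len + sampling_params.max_tokens <= max_seq_len

outputs = llm.generate(chat, sampling_params)
print(outputs[0].outputs[0].text)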
