콘텐츠로 이동

Flash Attention을 이용한 Llama3.1-8B

개요

이 튜토리얼은 여러 개의 RBLN NPU를 사용하여 vLLM에서 모델을 실행하는 방법을 설명합니다. 이 가이드에서는 meta-llama/Llama-3.1-8B-Instruct 모델을 사용합니다.

Flash Attention은 메모리 효율성과 처리량(throughput)을 향상시켜, 긴 문맥을 처리하는 모델의 성능을 개선합니다. 컴파일 시 rbln_kvcache_partition_len 파라미터를 추가함으로써 Flash Attention 모드를 활성화할 수 있습니다.

환경 설정 및 설치 확인

시작하기 전에 시스템 환경이 올바르게 구성되어 있으며, 필요한 모든 필수 패키지가 설치되어 있는지 확인하십시오. 다음 항목이 포함됩니다:

Note

rebel-compiler를 사용하려면 RBLN 포털 계정이 필요하니 참고하십시오.

Note

HuggingFace의 meta-llama/Llama-3.1-8B-Instruct 모델은 접근이 제한되어 있습니다. 접근 권한을 부여받은 후, 아래와 같이 huggingface-cli 명령어를 사용하여 로그인할 수 있습니다:

$ huggingface-cli login

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) .
Token: *****

실행

모델 컴파일

먼저, optimum-rbln에서 RBLNLlamaForCausalLM 클래스를 임포트합니다. 이 클래스의 from_pretrained() 메서드는 HuggingFace Hub에서 Llama 3.1 모델을 다운로드하고 RBLN Compiler를 사용해 컴파일합니다. 모델을 내보낼 때는 다음과 같은 파라미터를 지정해야 합니다:

  • export: 모델을 컴파일하려면 True로 설정해야 합니다.

  • rbln_batch_size: 컴파일을 위한 배치 크기를 정의합니다.

  • rbln_max_seq_len: 최대 시퀀스 길이를 정의합니다.

  • rbln_tensor_parallel_size: 추론에 사용할 NPU의 수를 정의합니다.

  • rbln_kvcache_partition_len: Flash Attention을 위한 KV 캐시 파티션의 길이를 정의합니다.rbln_max_seq_lenrbln_kvcache_partition_len의 배수여야 하며, 그보다 큰 값이어야 합니다.

컴파일 후에는 save_pretrained() 메서드를 사용하여 모델 아티팩트를 디스크에 저장합니다. 이 과정은 컴파일된 모델을 포함하는 디렉터리(예: rbln-Llama-3-1-8B-Instruct)를 생성합니다.

Note

모델 크기와 NPU 사양에 따라 적절한 배치 사이즈를 선택하세요. 또한, vllm-rbln은 최적의 처리량과 자원 활용을 보장하기 위해 동적 배치(Dynamic Batching)을 지원합니다. 자세한 내용은 Dynamic Batching를 참고하세요.

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Compile and export
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=131_072,
    rbln_tensor_parallel_size=8,
    rbln_kvcache_partition_len=16_384,
)

# Save compiled results to disk
model.save_pretrained("rbln-Llama-3-1-8B-Instruct")

vLLM을 활용한 추론

컴파일된 모델은 vLLM과 함께 사용할 수 있습니다. 아래 예시는 컴파일된 모델을 사용하여 vLLM 엔진을 설정하고 추론을 수행하는 방법을 보여줍니다.

Note

Flash Attention을 사용하기 위해서는, block_sizerbln_kvcache_partition_len과 일치해야 합니다.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Please make sure the engine configurations match the parameters used when compiling.
model_id = "rbln-Llama-3-1-8B-Instruct"
max_seq_len = 131_072
batch_size = 1
block_size = 16_384  # Should match to `rbln_kvcache_partition_len` for flash attention.

llm = LLM(
    model=model_id,
    device="rbln",
    max_num_seqs=batch_size,
    max_num_batched_tokens=max_seq_len,
    max_model_len=max_seq_len,
    block_size=block_size,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

sampling_params = SamplingParams(
    temperature=0.0,
    skip_special_tokens=True,
    stop_token_ids=[tokenizer.eos_token_id],
    max_tokens=100
)

conversation = [
    {
        "role": "user",
        "content": "Who are you?"
    }
]

chat = tokenizer.apply_chat_template(
    conversation, 
    add_generation_prompt=True,
    tokenize=False
)

outputs = llm.generate(chat, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(generated_text)

예시 출력:

I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."

참고