vLLM 네이티브 API
vllm-rbln
를 이용하여, vLLM API 를 대형 언어 모델(LLMs)에 손쉽게 활용 할 수 있습니다. 이 튜토리얼에서는 vLLM API를 사용하여 Llama3-8B와 Llama3.1-8B 모델을 각각 Eager Attention과 Flash Attention으로 추론을 수행하는 방법을 배웁니다.
사전 준비
rebel-compiler
, optimum-rbln
, vllm-rbln
패키지의 최신 버전이 설치되어 있어야 합니다. 각 패키지를 설치하기 위해 리벨리온 사설 PyPI 서버의 접근 권한이 필요합니다. 관련 내용은 설치 가이드를 참고하시기 바랍니다. 각 패키지의 최신 버전은 릴리즈 노트에서 확인 할 수 있습니다.
| $ pip3 install --extra-index https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.3" "optimum-rbln>=0.7.3.post2" "vllm-rbln>=0.7.3"
|
기본 모델 예제: Llama3-8B
1 단계. Llama3-8B 컴파일
튜토리얼 예시에서는 Llama3-8B 모델을 사용하여 진행합니다. 먼저, optimum-rbln을 사용하여 Llama3-8B 모델을 컴파일합니다.
| from optimum.rbln import RBLNLlamaForCausalLM
import os
# HuggingFace PyTorch Llama3 모델을 RBLN 컴파일된 모델로 내보내기
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True, # True인 경우 모델 컴파일 진행
rbln_max_seq_len=8192, # Maximum sequence length
rbln_tensor_parallel_size=4, # Rebellions Scalable Design (RSD)를 위한 ATOM+ 개수
rbln_batch_size=4, # Continuous batching을 위함, batch_size > 1 권장
)
# 컴파일 결과를 저장하기
compiled_model.save_pretrained(os.path.basename(model_id))
|
Note
서빙에 사용할 적절한 배치 크기를 선택해야 합니다. 여기에서는 4로 설정합니다.
2 단계. 추론을 위한 vLLM API 사용
vLLM의 API를 사용해 컴파일된 모델을 실행할 수 있습니다. 다음은 앞서 컴파일한 모델을 vLLM 엔진을 통해 초기화를 진행한 후 추론을 수행하는 코드입니다.
vllm_api_example_llama3_8B.py |
---|
| import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
# Please make sure the engine configurations match the parameters used when compiling.
model_id = "Meta-Llama-3-8B-Instruct"
max_seq_len = 8192
batch_size = 4
engine_args = AsyncEngineArgs(
model=model_id,
device="rbln",
max_num_seqs=batch_size,
max_num_batched_tokens=max_seq_len,
max_model_len=max_seq_len,
block_size=max_seq_len,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def stop_tokens():
eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
if eot_id is not None:
return [tokenizer.eos_token_id, eot_id]
else:
return [tokenizer.eos_token_id]
sampling_params = SamplingParams(
temperature=0.0,
skip_special_tokens=True,
stop_token_ids=stop_tokens(),
)
# Runs a single inference for an example
async def run_single(chat, request_id):
results_generator = engine.generate(chat, sampling_params, request_id=request_id)
final_result = None
async for result in results_generator:
# You can use the intermediate `result` here, if needed.
final_result = result
return final_result
conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
print(result)
async def run_multi(chats):
tasks = [asyncio.create_task(run_single(chat, i)) for (i, chat) in enumerate(chats)]
return [await task for task in tasks]
# Runs multiple inferences in parallel
conversations = [
[{"role": "user", "content": "What is the first letter of English alphabets?"}],
[{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
for conversation in conversations
]
results = asyncio.run(run_multi(chats))
for result in results:
assert len(result.outputs) > 0, "Invalid output."
print(result.outputs[0].text)
|
vLLM API를 이용해 여러 인코더-디코더 모델이나 멀티모달 모델을 실행할 수 있습니다. 모델 주에서 가능한 모델을 확인하실 수 있습니다.
vLLM API에 대한 더 자세한 내용은 vLLM 문서를 참고하시기 바랍니다.
응용 예제: Flash Attention 을 이용한 Llama3.1-8B
Flash Attention은 메모리 사용량을 줄이고 처리량을 향상시켜 Llama3.1-8B
등의 모델에서 긴 컨텍스트를 효율적으로 처리할 수 있습니다. optimum-rbln
으로 컴파일할 때 rbln_kvcache_partition_len
매개변수를 추가하면 Flash Attention을 활성화할 수 있습니다.
1 단계. Llama3.1-8B 컴파일
| from optimum.rbln import RBLNLlamaForCausalLM
import os
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# HuggingFace PyTorch Llama3.1 모델을 RBLN 컴파일된 모델로 내보내기
model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True, # True인 경우 모델 컴파일 진행
rbln_batch_size=1, # Batch size
rbln_max_seq_len=131_072, # Maximum sequence length
rbln_tensor_parallel_size=8, # Tensor parallelism
rbln_kvcache_partition_len=16_384, # Flash Attention 을 사용하기 위한 KV cache 파티션 크기
)
# 컴파일 결과를 저장하기
model.save_pretrained(os.path.basename(model_id))
|
Note
배치 크기는 요구사항에 적합하게 선택하면 됩니다. 여기서는 1로 설정합니다.
2 단계. 추론을 위한 vLLM API 사용
컴파일 후에는 vLLM API로 모델을 사용할 수 있습니다:
Note
Flash Attention을 사용하기 위해서는, block_size
가 rbln_kvcache_partition_len
과 일치해야 합니다.
vllm_api_example_llama3_1_8B.py |
---|
| import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
# Please make sure the engine configurations match the parameters used when compiling.
model_id = "Llama-3.1-8B-Instruct"
max_seq_len = 131_072
batch_size = 1
block_size = 16_384 # Should match to `rbln_kvcache_partition_len` for flash attention.
engine_args = AsyncEngineArgs(
model=model_id,
device="rbln",
max_num_seqs=batch_size,
max_num_batched_tokens=max_seq_len,
max_model_len=max_seq_len,
block_size=block_size,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def stop_tokens():
eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
if eot_id is not None:
return [tokenizer.eos_token_id, eot_id]
else:
return [tokenizer.eos_token_id]
sampling_params = SamplingParams(
temperature=0.0,
skip_special_tokens=True,
stop_token_ids=stop_tokens(),
)
# Runs a single inference for an example
async def run_single(chat, request_id):
results_generator = engine.generate(chat, sampling_params, request_id=request_id)
final_result = None
async for result in results_generator:
# You can use the intermediate `result` here, if needed.
final_result = result
return final_result
conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
assert len(result.outputs) > 0, "Invalid output."
print(result.outputs[0].text)
async def run_multi(chats):
tasks = [asyncio.create_task(run_single(chat, i)) for (i, chat) in enumerate(chats)]
return [await task for task in tasks]
conversations = [
[{"role": "user", "content": "What is the first letter of English alphabets?"}],
[{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
for conversation in conversations
]
results = asyncio.run(run_multi(chats))
for result in results:
assert len(result.outputs) > 0, "Invalid output."
print(result.outputs[0].text)
|
vLLM API에 대한 더 자세한 내용은 vLLM 문서를 참고하시기 바랍니다.