콘텐츠로 이동

vLLM 네이티브 API

vllm-rbln를 이용하여, vLLM API 를 대형 언어 모델(LLMs)에 손쉽게 활용 할 수 있습니다.

사전 준비

rebel-compiler, optimum-rbln, vllm-rbln 패키지의 최신 버전이 설치되어 있어야 합니다. 각 패키지를 설치하기 위해 리벨리온 사설 PyPI 서버의 접근 권한이 필요합니다. 관련 내용은 설치 가이드를 참고 바랍니다. 각 패키지의 최신 버전은 릴리즈 노트에서 확인 할 수 있습니다.

$ pip3 install -i https://pypi.rbln.ai/simple/ "rebel-compiler>=0.7.1" "optimum-rbln>=0.2.0" "vllm-rbln>=0.2.0"

Llama2-7B 컴파일

OpenAI 호환 API 서버의 Tutorial 예제로 Llama2-7B를 사용합니다. 먼저, optimum-rbln을 사용하여 Llama2-7B 모델을 컴파일합니다.

from optimum.rbln import RBLNLlamaForCausalLM

# HuggingFace PyTorch Llama2 모델을 RBLN 컴파일된 모델로 내보내기
model_id = "meta-llama/Llama-2-7b-chat-hf"
compiled_model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_max_seq_len=4096,
    rbln_tensor_parallel_size=4,  # Rebellions Scalable Design (RSD)를 위한 ATOM+ 개수
    rbln_batch_size=4,            # Continuous batching을 위해 batch_size > 1 권장
)

compiled_model.save_pretrained("rbln-Llama-2-7b-chat-hf")

서빙에 사용할 적절한 배치 크기를 선택해야 합니다. 여기에서는 4로 설정합니다.

vLLM API 사용

vLLM의 API를 사용해 컴파일된 모델을 실행할 수 있습니다. 다음 코드는 앞서 컴파일된 모델을 이용해 vLLM 엔진을 초기화하고 초기화된 엔진으로 인퍼런스를 수행하는 방법을 보여줍니다.

vllm_api_example.py
import asyncio
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


# Please make sure the engine configurations match the parameters used when compiling.
model_id = "rbln-Llama-2-7b-chat-hf"
max_seq_len = 4096
batch_size = 4

engine_args = AsyncEngineArgs(
  model=model_id,
  device="rbln",
  max_num_seqs=batch_size,
  max_num_batched_tokens=max_seq_len,
  max_model_len=max_seq_len,
  block_size=max_seq_len,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

tokenizer = AutoTokenizer.from_pretrained(model_id)

def stop_tokens():
  eot_id = next((k for k, t in tokenizer.added_tokens_decoder.items() if t.content == "<|eot_id|>"), None)
  if eot_id is not None:
    return [tokenizer.eos_token_id, eot_id]
  else:
    return [tokenizer.eos_token_id]

sampling_params = SamplingParams(
  temperature=0.0,
  skip_special_tokens=True,
  stop_token_ids=stop_tokens(),
)


# Runs a single inference for an example
async def run_single(chat, request_id):
  results_generator = engine.generate(chat, sampling_params, request_id=request_id)
  final_result = None
  async for result in results_generator:
    # You can use the intermediate `result` here, if needed.
    final_result = result
  return final_result


conversation = [{"role": "user", "content": "What is the first letter of English alphabets?"}]
chat = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
result = asyncio.run(run_single(chat, "123"))
print(result)


async def run_multi(chats):
  tasks = [asyncio.create_task(run_single(chat, i)) for (i, chat) in enumerate(chats)]
  return [await task for task in tasks]

# Runs multiple inferences in parallel
conversations = [
  [{"role": "user", "content": "What is the first letter of English alphabets?"}],
  [{"role": "user", "content": "What is the last letter of English alphabets?"}],
]
chats = [
  tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
  for conversation in conversations
]
results = asyncio.run(run_multi(chats))
print(results)

vLLM API 를 이용해 인코더-디코더 모델이나 멀티모달 모델을 사용하는 예제 코드를 확인하시려면 모델 주를 참고 바랍니다.

vLLM API에 대한 더 자세한 내용은 vLLM 문서를 참고 바랍니다.