동적 배치 크기를 이용한 추론
개요
실제 서빙 환경에서는 다양한 수의 요청을 효율적으로 처리해야 하는 경우가 많습니다. 예를 들어, 때로는 1개의 요청, 때로는 3개, 5개, 또는 7개의 요청이 동시에 들어올 수 있습니다. 이때 항상 최대 배치 크기(예: 8)를 사용하면 그보다 작은 요청에 대해서는 연산 자원이 낭비됩니다. 이를 해결하기 위해 서로 다른 배치 크기를 가진 여러 디코더를 컴파일하고, 시스템이 실제 요청 수에 가장 가까운 배치 크기의 디코더를 자동으로 선택하도록 할 수 있습니다.
rbln_decoder_batch_sizes
파라미터를 사용하면 컴파일 시 여러 배치 크기를 지정할 수 있습니다. 이를 통해 실제 요청 수에 따라 가장 적절한 디코더를 자동으로 선택하여 처리량과 자원 활용도를 개선할 수 있습니다. 예를 들어, 3개의 요청이 들어오면 배치 크기 4의 디코더가, 7개의 요청이 들어오면 배치 크기 8의 디코더가 선택됩니다.
유사한 최적화 기법
이 방식은 vLLM의 다른 최적화 기법과 유사합니다:
- CUDA Graph:
cudagraph_capture_sizes
- 다양한 배치 크기로 CUDA 그래프를 사전 캡처
- Inductor 컴파일:
compile_sizes
- 특정 입력 크기로 커널을 사전 컴파일
모두 예상되는 입력 크기들을 미리 최적화하여 동적 서빙 성능을 향상시키는 공통 원리를 사용합니다.
환경 설정 및 설치 확인
시작하기 전에 시스템 환경이 올바르게 구성되어 있으며, 필요한 모든 필수 패키지가 설치되어 있는지 확인하십시오. 다음 항목이 포함됩니다:
- 시스템 요구 사항:
- 필수 패키지:
- 설치 명령어:
| pip install optimum-rbln>=0.8.2 vllm-rbln>=0.8.2
pip install --extra-index-url https://pypi.rbln.ai/simple/ rebel-compiler>=0.8.2
|
Note
rebel-compiler
를 사용하려면 RBLN 포털 계정이 필요하니 참고하십시오.
실행
여러 디코더 배치 크기로 모델 컴파일
rbln_decoder_batch_sizes
를 통해서 여러 디코더 배치 크기로 모델을 컴파일할 수 있습니다.
Note
rbln_decoder_batch_sizes
리스트는 자동으로 내림차순으로 정렬됩니다. 모든 값은 rbln_batch_size
보다 작거나 같아야 합니다. 최대 배치 크기가 리스트에 포함되지 않은 경우 자동으로 추가됩니다.
| from optimum.rbln import RBLNLlamaForCausalLM
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Compile with multiple decoder batch sizes
model = RBLNLlamaForCausalLM.from_pretrained(
model_id=model_id,
export=True, # To compile the model, this argument must be True
rbln_batch_size=8, # Maximum batch size for prefill
rbln_max_seq_len=8192, # Maximum sequence length
rbln_tensor_parallel_size=4, # Tensor parallelism
rbln_decoder_batch_sizes=[8, 4, 1], # Compile decoders for batch sizes 8, 4, and 1
)
# Save compiled results to disk
model.save_pretrained("rbln-dynamic-Llama-3-8B-Instruct")
|
vLLM을 활용한 추론
컴파일된 모델은 vLLM
과 함께 사용할 수 있습니다. 아래 예시는 컴파일된 모델을 사용하여 vLLM
엔진을 설정하고 추론을 수행하는 방법을 보여줍니다.
Note
from_pretrained
의 매개변수는 일반적으로 rbln_batch_size
, rbln_max_seq_len
과 같이 rbln
접두사가 필요합니다.
하지만 rbln_config
내부의 매개변수는 이러한 접두사가 필요하지 않습니다. 동일한 매개변수를 rbln_config에서 설정할 때는 절대로 rbln
접두사를 붙이지 마세요.
| small_batch_conversations = [
[{"role": "user", "content": "What is the first letter of English alphabets?"}]
]
medium_batch_conversations = [
[{"role": "user", "content": "Explain quantum computing in simple terms."}],
[{"role": "user", "content": "What are the benefits of renewable energy?"}],
[{"role": "user", "content": "Describe the process of photosynthesis."}],
[{"role": "user", "content": "How does machine learning work?"}],
]
large_batch_conversations = [
[{"role": "user", "content": "What is the theory of relativity?"}],
[{"role": "user", "content": "Explain blockchain technology."}],
[{"role": "user", "content": "Describe climate change effects."}],
[{"role": "user", "content": "How do neural networks learn?"}],
[{"role": "user", "content": "What is genetic engineering?"}],
[{"role": "user", "content": "Explain the water cycle."}],
[{"role": "user", "content": "How does the internet work?"}],
[{"role": "user", "content": "What is sustainable development?"}],
]
|
| from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Please make sure the engine configurations match the parameters used when compiling.
# This example assumes the model was compiled with rbln_decoder_batch_sizes=[8, 4, 1]
model_id = "rbln-dynamic-Llama-3-8B-Instruct"
max_seq_len = 8192
batch_size = 8 # Maximum batch size
llm = LLM(
model=model_id,
device="rbln",
max_num_seqs=batch_size,
max_num_batched_tokens=max_seq_len,
max_model_len=max_seq_len,
block_size=max_seq_len,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
sampling_params = SamplingParams(
temperature=0.0,
skip_special_tokens=True,
stop_token_ids=[tokenizer.eos_token_id],
)
conversations = [
small_batch_conversations,
medium_batch_conversations,
large_batch_conversations,
]
for conversation in conversations:
chats = [
tokenizer.apply_chat_template(
conv,
add_generation_prompt=True,
tokenize=False,
) for conv in conversation
]
outputs = llm.generate(chats, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
print("=====================================")
|
예시 출력:
| The first letter of the English alphabet is "A".'
=====================================
Quantum computing! It's a fascinating topic that can be a bit tricky to
Renewable energy has numerous benefits, including:
1. **Sustainability**:
Photosynthesis is the process by which plants, algae, and some bacteria convert light
Machine learning is a subfield of artificial intelligence that involves training algorithms to learn from
=====================================
The theory of relativity, developed by Albert Einstein, is a fundamental concept in
Blockchain technology is a decentralized, distributed ledger system that enables secure, transparent, and
Climate change is having a profound impact on our planet, and its effects are widespread
Neural networks learn through a process called supervised learning, unsupervised learning,
Genetic engineering, also known as genetic modification (GM), is the process of
The water cycle, also known as the hydrologic cycle, is the continuous process
What a great question! The internet is a complex system, but I'll try
Sustainable development is a concept that was first introduced in the 1987 report
|
참고