Qwen2.5-VL-7B (Multimodal)
Overview
This tutorial explains how to run a multimodal model on vLLM using multiple RBLN NPUs. For this guide, we will use the Qwen/Qwen2.5-VL-7B-Instruct model, which accepts both images and videos as input.
Setup & Installation
Before you begin, ensure that your system environment is properly configured and that all required packages are installed. This includes:
- System Requirements:
- Package Requirements:
- Installation Command:
| pip install optimum-rbln>=0.8.2 vllm-rbln>=0.8.2
pip install --extra-index-url https://pypi.rbln.ai/simple/ rebel-compiler>=0.8.2
|
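As an optional sanity check after installation, you can confirm the installed versions of the packages listed above. This is a minimal sketch using only the standard library, not an official verification step:
| # Print the installed versions of the packages from the install command above.
import importlib.metadata as md

for pkg in ("optimum-rbln", "vllm-rbln", "rebel-compiler"):
    print(pkg, md.version(pkg))
|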
Execution
Model Compilation
You can modify the parameters of the main module as well as the submodules through rbln_config. For the original source code, refer to the RBLN Model Zoo. If you need the API reference, see RblnModelConfig.
- visual submodule:
    - max_seq_lens: Defines the maximum sequence length for the Vision Transformer (ViT), which corresponds to the number of patches in an image.
    - device: Defines the device allocation for each submodule during runtime.
        - As Qwen2.5-VL consists of multiple submodules, loading them all onto a single device may exceed its memory capacity, especially as the batch size increases. Distributing submodules across devices optimizes memory usage for efficient runtime performance.
- main module:
    - export: Must be True to compile the model.
    - tensor_parallel_size: Defines the number of NPUs to be used for inference.
    - kvcache_partition_len: Defines the length of the KV cache partitions used for flash attention.
    - max_seq_len: Defines the maximum position embedding for the language model; it must be a multiple of kvcache_partition_len (a quick arithmetic check follows this list).
    - device: Defines the device allocation for all modules other than the explicitly device-allocated submodules.
    - batch_size: Defines the batch size for compilation.
    - decoder_batch_sizes: Defines the batch sizes used for dynamic batching.
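To make the max_seq_len constraint concrete: the compilation example below uses max_seq_len = 114_688 and kvcache_partition_len = 16_384, and 114_688 = 7 × 16_384, so the constraint holds. A quick illustrative check:
| # Illustrative check: max_seq_len must be a multiple of kvcache_partition_len.
max_seq_len = 114_688
kvcache_partition_len = 16_384
assert max_seq_len % kvcache_partition_len == 0  # 114_688 == 7 * 16_384
|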
Note
Select the batch size based on the model size and the NPU's specifications. In addition, vllm-rbln supports Dynamic Batching to ensure optimal throughput and resource utilization. See Dynamic Batching for details.
| from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    export=True,
    rbln_config={
        "visual": {
            "max_seq_lens": 6400,
            "device": 0,
        },
        "tensor_parallel_size": 8,
        "kvcache_partition_len": 16_384,
        "max_seq_len": 114_688,
        "device": [0, 1, 2, 3, 4, 5, 6, 7],
        "batch_size": 2,
        "decoder_batch_sizes": [2, 1],
    },
)
model.save_pretrained("rbln-Qwen2-5-7B-Instruct")
|
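The compiled artifacts are saved to rbln-Qwen2-5-7B-Instruct, which is the path the vLLM example below points at. If you want to reload the compiled model directly with optimum-rbln instead, a minimal sketch would look like this (the key point is export=False, which loads the saved artifacts rather than recompiling):
| from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

# Reload the compiled model from the saved directory without recompiling.
model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
    "rbln-Qwen2-5-7B-Instruct",
    export=False,
)
|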
Inference using vLLM
You can use the compiled model with vLLM. The example below shows how to set up the vLLM engine using a compiled model and run inference.
Note
Parameters passed to from_pretrained typically require the rbln prefix (e.g., rbln_batch_size, rbln_max_seq_len). In contrast, parameters within rbln_config should not include the prefix. Avoid using the rbln prefix when specifying the same parameters in rbln_config.
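As a minimal sketch of this naming convention (compiling this model normally also requires the other parameters shown earlier), the two calls below pass the same batch size in the two supported styles:
| from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"

# Style 1: prefixed keyword argument passed directly to `from_pretrained`.
model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, export=True, rbln_batch_size=2,
)

# Style 2: the same parameter grouped under `rbln_config`, without the prefix.
model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, export=True, rbln_config={"batch_size": 2},
)
|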
| from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams

# If the video is too long,
# set `VLLM_ENGINE_ITERATION_TIMEOUT_S` to a higher timeout value.
VIDEO_URLS = [
    "https://duguang-labelling.oss-cn-shanghai.aliyuncs.com/qiansun/video_ocr/videos/50221078283.mp4",
]

model_id = "rbln-Qwen2-5-7B-Instruct"
batch_size = 2
max_seq_len = 114688
kvcache_partition_len = 16384


def generate_prompts_video(model_id):
    processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
    video_nums = len(VIDEO_URLS)
    messages = [[
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": VIDEO_URLS[i],
                },
                {
                    "type": "text",
                    "text": "Describe this video.",
                },
            ],
        },
    ] for i in range(video_nums)]
    texts = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    arr_video_inputs = []
    arr_video_kwargs = []
    for i in range(video_nums):
        image_inputs, video_inputs, video_kwargs = process_vision_info(
            messages[i], return_video_kwargs=True)
        arr_video_inputs.append(video_inputs)
        arr_video_kwargs.append(video_kwargs)
    return [{
        "prompt": text,
        "multi_modal_data": {
            "video": video_inputs,
        },
        "mm_processor_kwargs": {
            "min_pixels": 1024 * 14 * 14,
            "max_pixels": 5120 * 14 * 14,
            **video_kwargs,
        },
    } for text, video_inputs, video_kwargs in zip(
        texts, arr_video_inputs, arr_video_kwargs)]


llm = LLM(
    model=model_id,
    device="rbln",
    max_num_seqs=batch_size,
    max_num_batched_tokens=max_seq_len,
    max_model_len=max_seq_len,
    block_size=kvcache_partition_len,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
sampling_params = SamplingParams(
    temperature=0,
    ignore_eos=False,
    skip_special_tokens=True,
    stop_token_ids=[tokenizer.eos_token_id],
    max_tokens=200,
)

inputs = generate_prompts_video(model_id)
outputs = llm.generate(inputs, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(generated_text)
|
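As the comment in the example notes, very long videos can exceed vLLM's default engine iteration timeout. A sketch of raising it before the LLM engine is constructed (the 600-second value is an arbitrary example):
| import os

# Increase vLLM's per-iteration timeout (in seconds) before creating the LLM.
os.environ["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "600"
|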
Example Output:
The video showcases a clear plastic food container being used to hold various fruits, demonstrating its features and benefits. The container is placed on a wooden table with a decorative background that includes a bouquet of artificial flowers and a plate with a mango and some berries.
The video begins with a close-up of the container filled with peaches, accompanied by text highlighting its versatility for different types of fruits such as longan, sliced watermelon, strawberries, and cherries. The container is then shown being opened and closed, emphasizing its easy-to-use design and secure locking mechanism.
Next, the container is filled with cherries, and the text explains that it is made of PET material, ensuring durability and quality. The container is then subjected to a durability test by placing two bricks on top of it, demonstrating its strength and ability to withstand pressure.
The video concludes with a final shot of the container filled with peaches, showcasing its transparency and the clear view of the contents inside. The text reiterates the container
|
References