Qwen3 (커스텀 커널)¶
개요¶
이 튜토리얼은 vLLM RBLN에서 Triton 커스텀 커널을 활성화하고 Qwen3-0.6B 예제를 실행하는 방법을 설명합니다.
모든 경로는 vllm-rbln 저장소 루트를 기준으로 작성됩니다.
모델 호출 지점은 torch.ops.rbln_triton_ops.*를 사용하고, 실행 경로를 처음부터 끝까지 검증합니다.
환경 설정 및 설치 확인¶
시작하기 전에 시스템 환경이 올바르게 구성되어 있으며, 필요한 모든 필수 패키지가 설치되어 있는지 확인하십시오. 다음 항목이 포함됩니다:
- 시스템 요구 사항:
- 필수 패키지:
- RBLN Compiler
- vllm-rbln – 1단계에서 설치
- 설치 명령어:
Note
rebel-compiler를 사용하려면 RBLN 포털 계정이 필요합니다.
실행¶
단계 구분
- Action(실행): 예제를 실행하기 위해 필요한 단계입니다.
- Reference(참고): 배경/연동 설명입니다. 필요할 때 참고하세요.
1) 실행: 저장소 루트에서 시작¶
vllm-rbln을 클론한 뒤 저장소 루트로 이동합니다.
아래 단계는 모두 저장소 루트 vllm-rbln/에서 작업한다고 가정합니다.
2) 참고: 연동 지점 파악¶
아래 표에서 등록, 호출, 실행을 위한 파일들을 찾을 수 있습니다.
| 경로 | 이유 | 하는 일 |
|---|---|---|
vllm_rbln/__init__.py |
초기화 시 import hook(register_ops()). |
register_ops()가 초기화 시 vllm_rbln.triton_kernels.*를 import하므로 연산자가 등록됩니다. 새 연산자 모듈을 추가할 때는 여기서 import되도록 하세요. |
vllm_rbln/triton_kernels/ |
연산자 정의, 스키마, 등록 패턴. | 참고: 연산자 인터페이스 확인용(3단계). 이 가이드에서는 이 파일들을 수정하지 않고 진행해도 됩니다. |
vllm_rbln/v1/attention/backends/flash_attention.py |
모델 호출 지점(연산자 호출 연결). | 참고: RBLN_USE_CUSTOM_KERNEL=1일 때 torch.ops.rbln_triton_ops.*를 선택하는지 확인(4단계). |
examples/experimental/offline_inference_basic.py |
실행 진입점. | 환경 변수를 켠 상태로 컴파일 및 추론 실행(5단계). |
3) 참고: 연산자 인터페이스 및 등록¶
이 단계는 참고용으로, vLLM RBLN에서 쓰는 연산자 등록을 이해하는 데 사용합니다.
용어(및 코드와의 대응)
| 용어 | 의미 | 코드 위치 |
|---|---|---|
| 커스텀 커널 | torch.ops.rbln_triton_ops.* 아래 PyTorch 연산자로 모델 코드에 노출되는 Triton 커널 구현. |
vllm_rbln/triton_kernels/ 아래 @triton.jit 커널 함수 |
| 등록(Registration) | PyTorch 연산자를 torch.ops.rbln_triton_ops.<op_name>으로 사용 가능하게 함. |
@triton_op("rbln_triton_ops::<op_name>", ...) 및 @register_fake("rbln_triton_ops::<op_name>") |
| 호출(Invocation) | 런타임에 모델 코드에서 연산자를 호출함. | torch.ops.rbln_triton_ops.<op_name>(...) 호출 지점 |
| 인터페이스 항목 | 중요한 이유 |
|---|---|
연산자 이름 (<op_name>) |
모델이 Operator 이름으로 연산자를 찾습니다. 이름을 바꾸면 찾을 수 없습니다. |
| 연산자 스키마(시그니처) | 모델이 고정 인자 순서로 텐서를 넘깁니다. 불일치 시 런타임/컴파일 시 실패합니다. |
| 텐서 dtype / shape / layout | 커널은 특정 dtype/layout을 가정하는 경우가 있습니다. 불일치 시 컴파일 실패나 잘못된 결과가 나올 수 있습니다. |
참고
- 스키마 호환성: 일부 연산자는 스키마를 유지하기 위해 placeholder 텐서 인자(예:
dummy0)를 둡니다.
예: @triton.jit 커널 구현
# 발췌: `vllm_rbln/triton_kernels/attention.py`
import torch
from rebel import triton
from rebel.triton import language as tl
@triton.jit
def flash_attention_naive_prefill(
query,
key,
value,
kv_cache,
mask,
output,
qk_scale,
seq_idx,
block_table,
block_size,
H: tl.constexpr,
G: tl.constexpr,
D: tl.constexpr,
L: tl.constexpr,
NB: tl.constexpr,
P: tl.constexpr,
C: tl.constexpr,
B: tl.constexpr,
DIM_BLOCK_TABLE: tl.constexpr,
):
NP: tl.constexpr = C // P
for batch_id in tl.static_range(0, NB, 1):
Q_block_ptr = tl.make_block_ptr(
base=query,
shape=(NB, H, G, L, D),
strides=(H * G * L * D, G * L * D, L * D, D, 1),
offsets=(batch_id, 0, 0, 0, 0),
block_shape=(1, H, G, L, D),
order=(4, 3, 2, 1, 0),
)
# ... more block pointers + compute ...
예: @triton_op 연산자 등록
# 발췌: `vllm_rbln/triton_kernels/attention.py`
@triton_op("rbln_triton_ops::flash_attention_naive_prefill", mutates_args=())
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
mask: torch.Tensor,
qk_scale: torch.Tensor,
seq_idx: torch.Tensor,
block_table: torch.Tensor,
dummy0: torch.Tensor,
) -> torch.Tensor:
...
@triton_op("rbln_triton_ops::flash_attention_naive_decode", mutates_args=())
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
mask: torch.Tensor,
qk_scale: torch.Tensor,
seq_idx: torch.Tensor,
block_table: torch.Tensor,
dummy0: torch.Tensor,
) -> torch.Tensor:
...
예: @register_fake 연산자 스텁
# 발췌: `vllm_rbln/triton_kernels/attention.py`
@register_fake("rbln_triton_ops::flash_attention_naive_prefill")
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
mask: torch.Tensor,
qk_scale: torch.Tensor,
seq_idx: torch.Tensor,
block_table: torch.Tensor,
dummy0: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(query)
@register_fake("rbln_triton_ops::flash_attention_naive_decode")
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
mask: torch.Tensor,
qk_scale: torch.Tensor,
seq_idx: torch.Tensor,
block_table: torch.Tensor,
dummy0: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(query)
4) 참고: 모델 호출 지점에서 torch.ops.rbln_triton_ops 선택 확인¶
모델 호출 지점이 torch.ops.rbln_triton_ops.*를 선택하는지 확인하여, 런타임에 커스텀 커널 경로가 사용되도록 합니다.
대상: vllm_rbln/v1/attention/backends/flash_attention.py
RBLN_USE_CUSTOM_KERNEL=1이면 어텐션에 쓰이는 커스텀 커널이 활성화됩니다:
발췌: flash_attention_naive_* 커널 모드 선택
5) 실행: 커스텀 커널을 켠 상태로 추론 실행¶
위치: vllm-rbln/
추론 실행:
명령에서 사용하는 환경 변수:
| 환경 변수 | 설명 |
|---|---|
RBLN_USE_CUSTOM_KERNEL=1 |
커스텀 커널 실행 경로(torch.ops.rbln_triton_ops.*)를 활성화합니다. |
VLLM_RBLN_COMPILE_MODEL=1 |
디바이스 실행을 위해 모델을 컴파일합니다(첫 실행 시 필요). |
VLLM_RBLN_COMPILE_STRICT_MODE=1 |
지원하지 않는 연산/커널을 먼저 드러냅니다(개발 시 권장). |
VLLM_RBLN_USE_VLLM_MODEL=1 |
vLLM RBLN에서 기대하는 vLLM 모델 연동 경로를 사용합니다. |
VLLM_DISABLE_COMPILE_CACHE=1 |
컴파일 캐시를 사용하지 않고 깨끗하게 다시 빌드합니다. |
VLLM_USE_V1=1 |
vLLM V1 실행 경로를 사용합니다. |
Tip
커널 코드를 수정했고 다시 빌드되도록 하려면, 해당 실행 시 VLLM_DISABLE_COMPILE_CACHE=1을 설정하세요.
문제 해결¶
문제 해결 체크리스트¶
설치/컴파일/실행 중 문제가 있으면 여기서부터 확인하시면 됩니다:
- Operator 찾기: 초기화 시 import가 실행되는지, 연산자 이름이 등록된 이름과 일치하는지 확인이 필요합니다.
- 변경 후 재빌드: 전체 모델 재빌드를 하려면
VLLM_DISABLE_COMPILE_CACHE=1을 설정해야 합니다. - Strict 모드: 개발 중에는
VLLM_RBLN_COMPILE_STRICT_MODE=1을 설정하면 문제를 빨리 발견할 수 있습니다.
자주 나오는 문제와 해결 방법은 문제 해결을 참고하세요.