YOLOv8 서빙
이 페이지에서는 TorchServe에서 RBLN SDK를 활용하여, 컴파일된 YOLOv8
모델을 서빙할 수 있는 방법을 소개합니다. TorchServe 환경 구성 방법에 대해서는 TorchServe를 참고 바랍니다.
아래에 소개된 모델 컴파일 및 TorchServe 설정에 필요한 커맨드들을 확인하려면 모델 주를 참고 바랍니다.
참고
이 튜토리얼은 사용자가 RBLN SDK 기반의 모델 컴파일 및 추론에 대해 잘 이해하고 있다는 가정하에 작성되었습니다. RBLN SDK 사용법에 익숙하지 않을 경우 파이토치/텐서플로우 튜토리얼 및 파이썬 API 페이지를 참고 바랍니다.
사전준비
시작하기에 앞서 TorchServe 환경과 컴파일된 YOLOv8
모델, 그리고 테스트에 사용할 coco 라벨파일이 필요합니다.
TorchServe를 이용한 모델 서빙
TorchServe에서 모델 서빙은 모델 아카이브(.mar
) 파일 단위로 이루어집니다. .mar
파일에는 모델 서빙에 필요한 모든 정보가 포함됩니다. 본 섹션에서는 .mar
파일 생성 및 생성된 .mar
파일을 이용하여 모델을 서빙하는 방법에 대해 설명합니다.
모델 핸들러 작성
아래는 TorchServe의 BaseHandler를 상속받아 YOLOv8
모델의 요청을 처리하는 간단한 Custom handler 예시입니다. 이 Handler는 모델 서빙을 위해 initialize()
, inference()
, postprocess()
, handle()
메서드를 정의합니다. initialize()
메서드는 모델이 model_store
에서 로드될 때 호출되며, handle()
메서드는 TorchServe inference API의 predictions 요청에 의해 호출됩니다.
yolov8l_handler.py |
---|
| # yolov8l_handler.py
"""
ModelHandler defines a custom model handler.
"""
import os
import torch
import rebel # RBLN Runtime
import PIL.Image as Image
import numpy as np
import yaml
import io
from ultralytics.data.augment import LetterBox
from ultralytics.yolo.utils.ops import non_max_suppression as nms, scale_boxes
from ts.torch_handler.base_handler import BaseHandler
class YOLOv8_Handler(BaseHandler):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.explain = False
self.target = 0
self.input_image = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
# load the model, refer 'custom handler class' above for details
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.module = rebel.Runtime(model_path, tensor_type="pt")
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
preprocessed_data = data[0].get("data")
if preprocessed_data is None:
preprocessed_data = data[0].get("body")
image = Image.open(io.BytesIO(preprocessed_data)).convert("RGB")
image = np.array(image)
preprocessed_data = LetterBox(new_shape=(640, 640))(image=image)
preprocessed_data = preprocessed_data.transpose((2, 0, 1))[::-1]
preprocessed_data = np.ascontiguousarray(
preprocessed_data, dtype=np.float32
)
preprocessed_data = preprocessed_data[None]
preprocessed_data /= 255
self.input_image = preprocessed_data
return torch.from_numpy(preprocessed_data)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
# Do some inference call to engine here and return output
model_output = self.module.run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
# Take output from network and post-process to desired format
pred = nms(
inference_output, 0.25, 0.45, None, False, max_det=1000
)[0]
pred[:, :4] = scale_boxes(
self.input_image.shape[2:], pred[:, :4], self.input_image.shape
)
yaml_path = "./coco128.yaml"
postprocess_output = []
with open(yaml_path) as f:
data = yaml.safe_load(f)
names = list(data["names"].values())
for *xyxy, conf, cls in reversed(pred):
xyxy_str = f"{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}"
postprocess_output.append(
f"xyxy : {xyxy_str}, conf : {conf}, cls : {names[int(cls)]}"
)
return postprocess_output
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
return [{"result": self.postprocess(model_output[0])}]
|
모델 서빙 설정 작성
아래와 같이 TorchServe로 YOLOv8
모델을 서빙하기 위한 설정인 config.properties
파일을 작성합니다. 이 파일은 모델의 서빙에 필요한 정보를 담을 수 있으며, 이 예제에서는 한 개의 Worker로 제한하기 위해 default_workers_per_model
을 1로 설정합니다. 본 예제에서 사용할 Input Image의 Size를 고려하여 max_request_size
를 여유있게 100MB 정도로 설정합니다.
config.properties |
---|
| max_request_size=104857600
max_response_size=104857600
default_workers_per_model:1
models={\
"yolov8l": {\
"1.0": {\
"marName": "yolov8l.mar",\
"responseTimeout": 120\
}\
}\
}
|
torch-model-archiver
를 이용한 모델 아카이빙
아래와 같이 모델 아카이브(.mar
) 파일이 저장될 경로인 model_store
디렉토리를 생성합니다. 이 디렉토리에 YOLOv8
모델 아카이브 파일이 저장됩니다.
이제 모델 아카이브 파일을 만들기 위해 필요한 내용이 모두 준비되었습니다. torch-model-archiver
도구를 이용해 모델 아카이브 파일을 만들 수 있습니다. 생성된 모델 아카이브 파일이 위치한 model_store
디렉토리는 TorchServe를 실행할 때 파라미터로 전달됩니다.
| $ torch-model-archiver --model-name yolov8l \
--version 1.0 \
--serialized-file ./yolov8l.rbln \
--handler ./yolov8l_handler.py \
--extra-files ./coco128.yaml \
--export-path ./model_store
|
사용된 옵션은 아래와 같습니다.
--model-name
: 서빙할 모델이름으로 yolov8l
으로 설정합니다.
--version
: TorchServe로 서빙할 모델에 대한 버전입니다.
--serialized-file
: 가중치 파일을 지정하는데 사용합니다. 여기서는 yolov8l.rbln
파일을 지정합니다.
--handler
: 요청 모델에 대한 Handler Script를 지정하는 옵션이며, 위에서 작성한 ./yolov8l_handler.py
를 지정해줍니다.
--extra-files
: 추가로 Archive에 포함되어야 하는 파일이며, 예제에서 사용할 라벨 파일인 coco128.yaml
파일을 지정해줍니다.
--export-path
: 아카이빙 결과물의 저장 경로를 설정하는 옵션으로, 위에서 생성한 model_store
디렉토리로 설정합니다.
실행을 하면 --export-path
에 설정한 model_store 경로에 yolov8l.mar
파일이 생성됩니다. 이제 TorchServe을 사용해 YOLOv8을 서빙할 준비가 완료되었습니다.
| +-- (YOUR_PATH)/
| +-- model_store/
| | +--yolov8l.mar
| +-- yolov8l.rbln
| +-- yolov8l_handler.py
| +-- coco128.yaml
| +-- config.properties
|
torchserve
실행
torchserve
명령어로 아래와 같이 서빙을 시작할 수 있습니다. TorchServe 환경에서 YOLOv8
모델 서빙을 간단히 테스트하는 것이 목적이므로, --disable-token-auth
옵션을 사용하여 토큰 인증을 비활성화합니다.
| $ torchserve --start --ncs \
--ts-config ./config.properties \
--model-store ./model_store \
--models yolov8l=yolov8l.mar \
--disable-token-auth
|
--start
: model-server를 시작합니다.
--ncs
: --no-config-snapshots 옵션입니다.
--ts-config
: torchserve
로 전달되는 모델 서빙에 대한 설정 파일을 지정하는 옵션입니다. config.properties
로 지정합니다.
--model-store
: 모델을 로드하거나 기본적으로 로드할 모델들의 경로를 지정합니다.
--models
: 서빙할 모델을 지정합니다. all
을 지정하면, model_store
경로 내 모든 모델이 서빙할 모델로 지정됩니다.
--disable-token-auth
: Token authorization을 비활성화 합니다.
TorchServe이 정상적으로 실행되면 백그라운드에서 동작합니다. TorchServe의 동작을 중지하기 위한 명령어는 아래와 같습니다.
TorchServe는 기본 설정으로 Management API는 8081
번 포트를, Inference API는 8080
번 포트를 사용합니다.
아래의 Management API를 통해서 서빙되고 있는 모델 리스트를 확인할 수 있습니다.
| $ curl -X GET "http://localhost:8081/models"
|
정상적으로 동작할 경우 YOLOv8
모델이 서빙되고 있는 것이 확인할 수 있습니다.
| {
"models": [
{
"modelName": "yolov8l",
"modelUrl": "yolov8l.mar"
}
]
}
|
TorchServe Inference API
기반 추론 요청
torchserve
로 서빙된 YOLOv8
서빙을 테스트하기 위해서 TorchServe Inference API의 Prediction API를 이용하여 추론 요청을 합니다.
YOLOv8
추론 요청에 이용할 샘플 이미지를 다운로드합니다.
| $ wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/people4.jpg
|
curl을 이용하여 TorchServe Inference API
기반 추론 요청을 요청합니다.
| $ curl -X POST "http://127.0.0.1:8080/predictions/yolov8l" \
-H "Content-Type: application/octet-stream" \
--data-binary @./people4.jpg
|
정상적으로 동작할 경우 아래와 같이 응답을 확인할 수 있습니다.
| {
"result": [
"xyxy : 1.5238770246505737, 0.10898438096046448, 1.8791016340255737, 1.0, conf : 0.91015625, cls : person",
"xyxy : 0.6436523795127869, 0.2138671875, 0.968994140625, 1.0, conf : 0.916015625, cls : person",
"xyxy : 0.90380859375, 0.29179689288139343, 1.296240210533142, 1.0, conf : 0.9296875, cls : person",
"xyxy : 1.9107422828674316, 0.17695312201976776, 2.558789014816284, 1.0, conf : 0.943359375, cls : person"
]
}
|
고급 기능
배치 추론
(Batch Inference)
TorchServe는 여러 추론 요청을 모아 한 번에 처리하는 배치 추론
기능을 지원합니다.
배치 추론
설정
TorchServe에서 배치 추론
을 사용하기 위해 모델 설정에서 아래 2개 옵션을 지정해야 합니다.
batchSize
: 모델이 사용 가능한 최대 배치 사이즈
입니다.
maxBatchDeplay
: TorchServe에서 batchSize
만큼의 요청을 기다릴 수 있는 최대 지연 시간(ms)이며, 수신된 요청이 최대 배치 사이즈
만큼 수신되지 못한 채 설정된 지연시간을 초과하면, 현재까지 수신된 요청을 Handler로 전달합니다.
아래는 config.properties 파일에 설정하는 방법 예시입니다. config.properties에서 설정할 때에 각 배치 설정은 batchSize
, maxBatchDelay
키 값으로 아래와 같이 설정합니다.
config_b4.properties |
---|
| max_request_size=104857600
max_response_size=104857600
default_workers_per_model=1
models={\
"yolov8l": {\
"1.0": {\
"marName": "yolov8l.mar",\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
|
모델 컴파일
RBLN 컴파일러는 다양한 입력 형상에 대해 모델을 여러 차례 컴파일하는 버케팅(Bucketing)
모델을 지원합니다. 이 모델을 배치 추론
에 활용하여 추론 성능을 최적화할 수 있습니다.
아래는 배치 범위 1-4
에 대해 버케팅
모델을 생성하는 예시 코드입니다:
| size = 640 # 이미지의 넓이와 높이
batches = [1, 2, 3, 4] # 지원되는 배치의 크기
input_infos = []
# 개별 배치 크기에 대한 입력 정보를 생성
for i, batch in enumerate(batches):
input_info = [("input_np", [batch, 3, size, size], "float32")]
input_infos.append(input_info)
# 미리 정의된 입력 정보를 기반으로 모델 컴파일
compiled_model = rebel.compile_from_torch(model, input_info=input_infos)
# 컴파일된 모델 저장
compiled_model.save("resnet50.rbln")
|
컴파일 된 모델을 저장할 때 파일 이름은 모델 핸들러
에서 불러오기 위해서 torch-model-archiver
의 --serialized-file
파라미터와 일치해야 합니다.
모델 핸들러
모델 핸들러에서 특정 배치 크기만큼 런타임을 생성하고, 생성된 런타임을 통해 제공된 입력 데이터 기반의 추론 연산을 정의할 수 있습니다.
yolov8l_batch_handler.py |
---|
| # yolov8l_batch_handler.py
"""
ModelHandler defines a custom model handler.
"""
import io
import os
import numpy as np
import PIL.Image as Image
import rebel # RBLN Runtime
import torch
import yaml
from ts.torch_handler.base_handler import BaseHandler
from ultralytics.data.augment import LetterBox
from ultralytics.yolo.utils.ops import non_max_suppression as nms
from ultralytics.yolo.utils.ops import scale_boxes
class YOLOv8_Handler(BaseHandler):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.explain = False
self.target = 0
self.input_images = []
self.batch_size = None
self.max_batch_size = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
# load the model, refer 'custom handler class' above for details
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
self.max_batch_size = context.system_properties["batch_size"]
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.modules = []
compiled_model = rebel.RBLNCompiledModel(model_path)
for i in range(self.max_batch_size):
self.modules.append(compiled_model.create_runtime(input_info_index=i, tensor_type="pt"))
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
self.batch_size = num_requests = len(data)
assert self.batch_size <= self.max_batch_size, print(
f"[RBLN][ERROR] Inputed batched number({self.batch_size})"
f" is over the batchSize({self.max_batch_size}) in configuration."
)
self.input_images.clear()
for i in range(num_requests):
preprocessed_data = data[0].get("data")
if preprocessed_data is None:
preprocessed_data = data[0].get("body")
image = Image.open(io.BytesIO(preprocessed_data)).convert("RGB")
image = np.array(image)
preprocessed_data = LetterBox(new_shape=(640, 640))(image=image)
preprocessed_data = preprocessed_data.transpose((2, 0, 1))[::-1]
preprocessed_data = np.ascontiguousarray(
preprocessed_data, dtype=np.float32
)
preprocessed_data = preprocessed_data[None]
preprocessed_data /= 255
self.input_images.append(preprocessed_data)
preprocessed_datas = np.concatenate(self.input_images, axis=0).copy()
return torch.from_numpy(preprocessed_datas)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
# Do some inference call to engine here and return output
model_output = self.modules[self.batch_size - 1].run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
# Take output from network and post-process to desired format
chunky_batched_result = np.array_split(
inference_output, self.batch_size, axis=0
)
postprocess_outputs = []
for idx, result in enumerate(chunky_batched_result):
nms_result = nms(
result, 0.25, 0.45, None, False, max_det=1000
)
pred = nms_result[0]
pred[:, :4] = scale_boxes(
self.input_images[idx].shape[2:],
pred[:, :4],
self.input_images[idx].shape,
)
yaml_path = "./coco128.yaml"
postprocess_output = []
with open(yaml_path) as f:
data = yaml.safe_load(f)
names = list(data["names"].values())
for *xyxy, conf, cls in reversed(pred):
xyxy_str = f"{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}"
postprocess_output.append(
f"xyxy : {xyxy_str}, conf : {conf}, cls : {names[int(cls)]}"
)
postprocess_outputs.append(
[{f"result[{len(postprocess_outputs)}]": postprocess_output}]
)
return postprocess_outputs
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
return self.postprocess(model_output[0])
|
모델 서빙
위에서 생성한 설정
, 모델
, 모델 핸들러
를 이용한 "torch-model-archiver를 이용한 모델 아카이빙
", "torchserve 실행
"해서 모델 서빙을 시작합니다.
아래의 Management API
를 통해 설정이 정상적으로 적용되었는지 확인할 수 있습니다.
| $ curl -X GET "http://localhost:8081/models/yolov8l"
|
응답에서 batchSize
와 maxBatchDelay
가 지정된 값으로 설정되었는지 확인할 수 있습니다.
| [
{
"modelName": "yolov8l",
"modelVersion": "1.0",
"modelUrl": "yolov8l.mar",
"runtime": "python",
"minWorkers": 1,
"maxWorkers": 1,
"batchSize": 4,
"maxBatchDelay": 100,
:
:
"workers": [
{
:
:
}
],
:
:
}
]
|