YOLOv8
In this tutorial, we will guide you through the steps required to integrate the RBLN SDK with TorchServe using a precompiled YOLOv8 model. For instructions on setting up the TorchServe environment, refer to TorchServe.
You can check out the actual commands required to compile the model and serve it with TorchServe in our model zoo.
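If you have not compiled the model yet, the snippet below sketches one way to export a single-batch yolov8l.rbln with the RBLN Compiler. It assumes the PyTorch module is obtained through the ultralytics YOLO API; the exact flow used in the model zoo script may differ.
import rebel  # RBLN Compiler
from ultralytics import YOLO  # assumed source of the YOLOv8 weights

# Load the underlying PyTorch module of YOLOv8-L
model = YOLO("yolov8l.pt").model.eval()

# Single input: batch 1, 3 channels, 640x640
input_info = [("input_np", [1, 3, 640, 640], "float32")]

# Compile for the RBLN NPU and save the artifact used throughout this tutorial
compiled_model = rebel.compile_from_torch(model, input_info=input_info)
compiled_model.save("yolov8l.rbln")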
Note
This tutorial assumes that you are familiar with compiling and running inference using the RBLN SDK. If you are not familiar with RBLN SDK, refer to PyTorch/TensorFlow tutorials and the API Documentation.
Prerequisites
Before we start, please make sure you have prepared the following prerequisites in your system:
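An RBLN NPU with the RBLN SDK (driver, compiler, and rebel runtime) installed
TorchServe and torch-model-archiver (see TorchServe)
The ultralytics package, used by the handler for pre/post-processing
A precompiled yolov8l.rbln model and the coco128.yaml label file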
Quick Start with TorchServe
In TorchServe, models are served as Model Archive (.mar) units, which contain all the information necessary for serving the model. The following guide explains how to create a .mar file and use it for model serving.
Write the Model Request Handler
Below is a simple handler that inherits from the TorchServe BaseHandler to serve YOLOv8 inference requests. This handler defines initialize(), inference(), postprocess(), and handle() for model serving. The initialize() method is called when the model is loaded from the model_store directory, and the handle() method is invoked for each predictions request made through the TorchServe Inference API.
yolov8l_handler.py
# yolov8l_handler.py
"""
ModelHandler defines a custom model handler.
"""
import os
import torch
import rebel # RBLN Runtime
import PIL.Image as Image
import numpy as np
import yaml
import io
from ultralytics.data.augment import LetterBox
from ultralytics.yolo.utils.ops import non_max_suppression as nms, scale_boxes
from ts.torch_handler.base_handler import BaseHandler
class YOLOv8_Handler(BaseHandler):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.explain = False
self.target = 0
self.input_image = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
# load the model, refer 'custom handler class' above for details
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.module = rebel.Runtime(model_path, tensor_type="pt")
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
preprocessed_data = data[0].get("data")
if preprocessed_data is None:
preprocessed_data = data[0].get("body")
image = Image.open(io.BytesIO(preprocessed_data)).convert("RGB")
image = np.array(image)
preprocessed_data = LetterBox(new_shape=(640, 640))(image=image)
preprocessed_data = preprocessed_data.transpose((2, 0, 1))[::-1]
preprocessed_data = np.ascontiguousarray(
preprocessed_data, dtype=np.float32
)
preprocessed_data = preprocessed_data[None]
preprocessed_data /= 255
self.input_image = preprocessed_data
return torch.from_numpy(preprocessed_data)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
# Do some inference call to engine here and return output
model_output = self.module.run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
# Take output from network and post-process to desired format
pred = nms(
inference_output, 0.25, 0.45, None, False, max_det=1000
)[0]
pred[:, :4] = scale_boxes(
self.input_image.shape[2:], pred[:, :4], self.input_image.shape
)
yaml_path = "./coco128.yaml"
postprocess_output = []
with open(yaml_path) as f:
data = yaml.safe_load(f)
names = list(data["names"].values())
for *xyxy, conf, cls in reversed(pred):
xyxy_str = f"{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}"
postprocess_output.append(
f"xyxy : {xyxy_str}, conf : {conf}, cls : {names[int(cls)]}"
)
return postprocess_output
def handle(self, data, context):
"""
        Invoked by TorchServe for a prediction request.
        Performs pre-processing of the data, prediction using the model, and post-processing of the prediction output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
return [{"result": self.postprocess(model_output[0])}]
Write the Model Configuration
Create a config.properties file as follows. This file contains the information necessary for serving the model. In this tutorial, default_workers_per_model is set to 1 to limit the number of workers to a single instance, and max_request_size is set to 100 MB (104857600 bytes) to accommodate the input image size.
config.properties
max_request_size=104857600
max_response_size=104857600
default_workers_per_model:1
models={\
"yolov8l": {\
"1.0": {\
"marName": "yolov8l.mar",\
"responseTimeout": 120\
}\
}\
}
Model Archiving with torch-model-archiver
The model_store directory stores the .mar files used for serving, including the YOLOv8 model archive created in this tutorial.
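If the model_store directory does not exist yet, create it first:
$ mkdir -p ./model_store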
Now that the setup for model archiving is complete, run the torch-model-archiver command to create the model archive file. The model_store folder, where the generated yolov8l.mar archive file is placed, will be passed as a parameter when TorchServe starts.
$ torch-model-archiver --model-name yolov8l \
--version 1.0 \
--serialized-file ./yolov8l.rbln \
--handler ./yolov8l_handler.py \
--extra-files ./coco128.yaml \
--export-path ./model_store
The options passed to torch-model-archiver are as follows.
--model-name: Specifies the name of the model to be served, set to yolov8l.
--version: Defines the version of the model to be served with TorchServe.
--serialized-file: Specifies the compiled model file, set to yolov8l.rbln.
--handler: Specifies the handler script for the model, set to yolov8l_handler.py.
--extra-files: Specifies additional files to include in the archive, set to coco128.yaml.
--export-path: Specifies the output directory for the archive file. The previously created model_store folder is used as the destination.
After executing the command, the yolov8l.mar file is generated in the model_store directory specified by --export-path.
+-- (YOUR_PATH)/
|   +-- model_store/
|   |   +-- yolov8l.mar
|   +-- yolov8l.rbln
|   +-- yolov8l_handler.py
|   +-- coco128.yaml
|   +-- config.properties
Run torchserve
TorchServe can be started using the following command. For a simple test where token authentication is not required, you can use the --disable-token-auth option.
$ torchserve --start --ncs --ts-config ./config.properties --model-store ./model_store --models yolov8l.mar --disable-token-auth
--start: Starts the TorchServe service.
--ncs: Disables the snapshot feature.
--ts-config: Specifies the TorchServe configuration file.
--model-store: Specifies the directory containing the model archive (.mar) files.
--models: Specifies the model(s) to serve. If all is specified, all models in the model_store directory are served.
--disable-token-auth: Disables token authentication.
When TorchServe is successfully started, it operates in the background. The command to stop TorchServe is shown below:
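$ torchserve --stop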
TorchServe provides the Management API on port 8081 and the Inference API on port 8080 by default.
You can check the list of models currently being served using the following Management API call.
$ curl -X GET "http://localhost:8081/models"
If the operation is successful, you can verify that the YOLOv8 model is being served.
{
"models": [
{
"modelName": "yolov8l",
"modelUrl": "yolov8l.mar"
}
]
}
Inference Request with TorchServe Inference API
Now we can test the YOLOv8 model served with TorchServe by sending an inference request to the Predictions endpoint of the TorchServe Inference API.
Download a sample image for the YOLOv8 inference request.
$ wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/people4.jpg
Make an inference request using the TorchServe Inference API with curl.
$ curl -X POST "http://127.0.0.1:8080/predictions/yolov8l" -H "Content-Type: application/octet-stream" --data-binary @./people4.jpg
If the inference request is successful, the following response is returned.
{
"result": [
"xyxy : 1.5238770246505737, 0.10898438096046448, 1.8791016340255737, 1.0, conf : 0.91015625, cls : person",
"xyxy : 0.6436523795127869, 0.2138671875, 0.968994140625, 1.0, conf : 0.916015625, cls : person",
"xyxy : 0.90380859375, 0.29179689288139343, 1.296240210533142, 1.0, conf : 0.9296875, cls : person",
"xyxy : 1.9107422828674316, 0.17695312201976776, 2.558789014816284, 1.0, conf : 0.943359375, cls : person"
]
}
Advanced Features
Batch Inference in TorchServe
TorchServe supports Batch Inference, a method of grouping multiple inference requests together and processing them all at once.
Batch Inference Configuration
To use Batch Inference in TorchServe, the model configuration must include the following two required settings:
batchSize: The maximum batch size that the model can handle.
maxBatchDelay: The maximum wait time (in milliseconds) that TorchServe will hold requests in order to reach the defined batchSize. If the number of received requests does not reach the maximum batch size within the specified delay, all currently received requests are sent to the handler for processing.
In the configuration file (config_b4.properties in this example), specify the batch settings using batchSize and maxBatchDelay as shown below.
config_b4.properties
max_request_size=104857600
max_response_size=104857600
default_workers_per_model=1
models={\
"yolov8l": {\
"1.0": {\
"marName": "yolov8l.mar",\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
Model Compilation
Bucketing is the process of compiling a model multiple times with different target input shapes to create optimized bucketed models. The RBLN Compiler supports bucketing by compiling models for various input shapes, which enables Batch Inference and improves memory efficiency.
Below is an example code snippet demonstrating how to define a bucketed model that supports batch sizes ranging from 1 to 4:
import rebel  # RBLN Compiler

# `model` is the PyTorch YOLOv8 module to be compiled
size = 640  # Width and height of the input image
batches = [1, 2, 3, 4]  # Supported batch sizes

# Create input information for each batch size
input_infos = []
for batch in batches:
    input_info = [("input_np", [batch, 3, size, size], "float32")]
    input_infos.append(input_info)

# Compile the model with the pre-defined input information
compiled_model = rebel.compile_from_torch(model, input_info=input_infos)

# Save the compiled model
compiled_model.save("yolov8l.rbln")
When saving the compiled model, the file name must match the --serialized-file parameter passed to torch-model-archiver so that the model handler can load it correctly.
Model Handler
The model handler creates a runtime for each supported batch size and, for a given request batch, selects the matching runtime to perform inference on the provided input data.
yolov8l_batch_handler.py
# yolov8l_batch_handler.py
"""
ModelHandler defines a custom model handler.
"""
import io
import os
import numpy as np
import PIL.Image as Image
import rebel # RBLN Runtime
import torch
import yaml
from ts.torch_handler.base_handler import BaseHandler
from ultralytics.data.augment import LetterBox
from ultralytics.yolo.utils.ops import non_max_suppression as nms
from ultralytics.yolo.utils.ops import scale_boxes
class YOLOv8_Handler(BaseHandler):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.explain = False
self.target = 0
self.input_images = []
self.batch_size = None
self.max_batch_size = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
# load the model, refer 'custom handler class' above for details
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
self.max_batch_size = context.system_properties["batch_size"]
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.modules = []
compiled_model = rebel.RBLNCompiledModel(model_path)
for i in range(self.max_batch_size):
self.modules.append(compiled_model.create_runtime(input_info_index=i, tensor_type="pt"))
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
self.batch_size = num_requests = len(data)
        assert self.batch_size <= self.max_batch_size, (
            f"[RBLN][ERROR] The number of batched inputs ({self.batch_size})"
            f" exceeds the batchSize ({self.max_batch_size}) set in the configuration."
        )
self.input_images.clear()
        for i in range(num_requests):
            # Each request in the batch carries its own payload
            preprocessed_data = data[i].get("data")
            if preprocessed_data is None:
                preprocessed_data = data[i].get("body")
image = Image.open(io.BytesIO(preprocessed_data)).convert("RGB")
image = np.array(image)
preprocessed_data = LetterBox(new_shape=(640, 640))(image=image)
preprocessed_data = preprocessed_data.transpose((2, 0, 1))[::-1]
preprocessed_data = np.ascontiguousarray(
preprocessed_data, dtype=np.float32
)
preprocessed_data = preprocessed_data[None]
preprocessed_data /= 255
self.input_images.append(preprocessed_data)
preprocessed_datas = np.concatenate(self.input_images, axis=0).copy()
return torch.from_numpy(preprocessed_datas)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
# Do some inference call to engine here and return output
model_output = self.modules[self.batch_size - 1].run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
# Take output from network and post-process to desired format
chunky_batched_result = np.array_split(
inference_output, self.batch_size, axis=0
)
postprocess_outputs = []
for idx, result in enumerate(chunky_batched_result):
nms_result = nms(
result, 0.25, 0.45, None, False, max_det=1000
)
pred = nms_result[0]
pred[:, :4] = scale_boxes(
self.input_images[idx].shape[2:],
pred[:, :4],
self.input_images[idx].shape,
)
yaml_path = "./coco128.yaml"
postprocess_output = []
with open(yaml_path) as f:
data = yaml.safe_load(f)
names = list(data["names"].values())
for *xyxy, conf, cls in reversed(pred):
xyxy_str = f"{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}"
postprocess_output.append(
f"xyxy : {xyxy_str}, conf : {conf}, cls : {names[int(cls)]}"
)
postprocess_outputs.append(
[{f"result[{len(postprocess_outputs)}]": postprocess_output}]
)
return postprocess_outputs
def handle(self, data, context):
"""
        Invoked by TorchServe for a prediction request.
        Performs pre-processing of the data, prediction using the model, and post-processing of the prediction output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
return self.postprocess(model_output[0])
Model Serving
Using the previously created configuration (config_b4.properties), compiled model, and model handler (yolov8l_batch_handler.py), start model serving by following the steps in "Model Archiving with torch-model-archiver" and "Run torchserve" above.
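For the batch setup, the same commands can be reused with the batch handler and configuration file defined above, for example (add --force to torch-model-archiver if yolov8l.mar already exists in model_store):
$ torch-model-archiver --model-name yolov8l \
      --version 1.0 \
      --serialized-file ./yolov8l.rbln \
      --handler ./yolov8l_batch_handler.py \
      --extra-files ./coco128.yaml \
      --export-path ./model_store

$ torchserve --start --ncs --ts-config ./config_b4.properties --model-store ./model_store --models yolov8l.mar --disable-token-auth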
You can verify whether the configuration has been applied correctly by querying the Management API of the running server:
$ curl -X GET "http://localhost:8081/models/yolov8l"
Check whether batchSize and maxBatchDelay are set to the specified values in the response.
[
{
"modelName": "yolov8l",
"modelVersion": "1.0",
"modelUrl": "yolov8l.mar",
"runtime": "python",
"minWorkers": 1,
"maxWorkers": 1,
"batchSize": 4,
"maxBatchDelay": 100,
:
:
"workers": [
{
:
:
}
],
:
:
}
]
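To see batching in action, you can send several requests concurrently so that TorchServe can group them within maxBatchDelay, for example by backgrounding multiple curl calls with the sample image downloaded earlier:
$ for i in 1 2 3 4; do
    curl -s -X POST "http://127.0.0.1:8080/predictions/yolov8l" \
         -H "Content-Type: application/octet-stream" \
         --data-binary @./people4.jpg &
  done; wait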