ResNet50 서빙
이 페이지에서는 TorchServe에서 RBLN SDK를 활용하여 컴파일된 ResNet50
모델을 서빙하는 방법을 소개합니다. TorchServe 환경 구성 방법에 대해서는 TorchServe를 참고 바랍니다.
아래에 소개된 모델 컴파일 및 Torchserve 설정에 필요한 커맨드들을 확인하시려면 모델 주를 참고 바랍니다.
참고
이 튜토리얼은 사용자가 RBLN SDK 기반의 모델 컴파일 및 추론에 대해 잘 이해하고 있다는 가정하에 작성되었습니다. RBLN SDK 사용법에 익숙하지 않을 경우 파이토치/텐서플로우 튜토리얼 및 파이썬 API 페이지를 참고 바랍니다.
사전준비
시작하기에 앞서 TorchServe 환경과 컴파일된 ResNet50
모델이 필요합니다.
TorchServe를 이용한 모델 서빙
TorchServe에서 모델 서빙은 모델 아카이브(.mar
) 파일 단위로 이루어집니다. .mar
파일에는 모델 서빙에 필요한 모든 정보가 포함됩니다. 본 섹션에서는 .mar
파일 생성 및 생성된 .mar
파일을 이용하여 모델을 서빙하는 방법에 대해 설명합니다.
모델 핸들러 작성
아래는 TorchServe의 BaseHandler를 상속받아서 ResNet50
모델의 요청을 처리하는 간단한 Custom handler 예시입니다. 이 Handler는 모델 서빙을 위해 initialize()
, inference()
, postprocess()
, handle()
메서드를 정의합니다. initialize()
메서드는 모델이 model_store에서 로드될 때 호출되며, handle()
메서드는 TorchServe inference API의 predictions 요청에 의해 호출됩니다.
resnet50_handler.py |
---|
| # resnet50_handler.py
import os
import torch
from torchvision.models import ResNet50_Weights
import rebel # RBLN Runtime
import PIL.Image as Image
import io
from ts.torch_handler.base_handler import BaseHandler
class Resnet50Handler(BaseHandler):
def __init__(self):
self._context = None
self.initialized = False
self.model = None
self.weights = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
# load the model, refer 'custom handler class' above for details
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.module = rebel.Runtime(model_path, tensor_type="pt")
self.weights = ResNet50_Weights.DEFAULT
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
input_data = data[0].get("data")
if input_data is None:
input_data = data[0].get("body")
assert input_data is not None, print(
"[RBLN][ERROR] Data not found with client request."
)
if not isinstance(input_data, (bytes, bytearray)):
raise ValueError("[RBLN][ERROR] Preprocessed data is not binary data.")
try:
image = Image.open(io.BytesIO(input_data))
except Exception as e:
raise ValueError(f"[RBLN][ERROR]Invalid image data: {e}")
prep = self.weights.transforms()
batch = prep(image).unsqueeze(0)
preprocessed_data = batch.numpy()
return torch.from_numpy(preprocessed_data)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
model_output = self.module.run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
score, class_id = torch.topk(inference_output, 1, dim=1)
category_name = self.weights.meta["categories"][class_id]
return category_name
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
category_name = self.postprocess(model_output)
print("[RBLN][INFO] Top1 category: ", category_name)
return [{"result": category_name}]
|
모델 서빙 설정 작성
아래와 같이 TorchServe로 ResNet50
모델을 서빙하기 위한 설정인 config.properties
파일을 작성합니다. 이 파일은 모델의 서빙에 필요한 정보를 담을 수 있으며, 이 예제에서는 한 개의 Worker로 제한하기 위해 default_workers_per_model
을 1로 설정합니다.
config.properties |
---|
| default_workers_per_model:1
models={\
"resnet50":{\
"1.0":{\
"marName": "resnet50.mar",\
"responseTimeout": 120\
}\
}\
}
|
torch-model-archiver
를 이용한 모델 아카이빙
아래와 같이 모델 아카이브(.mar
) 파일이 저장될 경로인 model_store
디렉토리를 생성합니다. 이 디렉토리에 ResNet50
모델 아카이브 파일이 저장됩니다.
이제 모델 아카이브 파일을 만들기 위해 필요한 내용이 모두 준비되었습니다. torch-model-archiver
도구를 이용해 모델 아카이브 파일을 만들 수 있습니다. 생성된 모델 아카이브 파일이 위치한 model_store
디렉토리는 TorchServe를 실행할 때 파라미터로 전달됩니다.
| $ torch-model-archiver \
--model-name resnet50 \
--version 1.0 \
--serialized-file ./resnet50.rbln \
--handler ./resnet50_handler.py \
--export-path ./model_store/
|
사용된 옵션은 아래와 같습니다.
--model-name
: 서빙할 모델이름으로 resnet50
으로 설정합니다.
--version
: TorchServe로 서빙할 모델에 대한 버전입니다.
--serialized-file
: 가중치 파일을 지정하는데 사용합니다. 여기서는 resnet50.rbln
파일을 지정합니다.
--handler
: 요청 모델에 대한 Handler Script를 지정하는 옵션이며, 위에서 작성한 resnet50_handler.py
를 지정해줍니다.
--export-path
: 아카이빙 결과물 저장 경로를 설정하는 옵션으로, 위에서 생성한 model_store
디렉토리로 설정합니다.
실행을 하면 --export-path
에 설정한 model_store
경로에 resnet50.mar
파일이 생성됩니다.
| +--(YOUR_PATH)/
| +--model_store/
| | +--resnet50.mar
| +--resnet50.rbln
| +--resnet50_handler.py
| +--config.properties
|
torchserve
실행
torchserve
명령어로 아래와 같이 서빙을 시작할 수 있습니다. TorchServe 환경에서 ResNet50
모델 서빙에 대한 간단한 테스트에 대한 목적이므로, --disable-token-auth
옵션으로 토큰 인증을 비활성화 합니다.
| $ torchserve --start --ncs \
--ts-config ./config.properties \
--model-store ./model_store \
--models resnet50=resnet50.mar \
--disable-token-auth
|
--start
: 모델 서빙을 시작합니다.
--ncs
: --no-config-snapshots
옵션입니다.
--ts-config
: torchserve
로 전달되는 모델 서빙에 대한 설정 파일을 지정하는 옵션입니다. config.properties
로 지정합니다.
--model-store
: 모델을 로드하거나 기본적으로 로드할 모델들의 경로를 지정합니다.
--models
: 서빙할 모델을 지정합니다. all
을 지정하면, model_store
경로 내 모든 모델이 서빙할 모델로 지정됩니다.
--disable-token-auth
: Token authorization을 비활성화 합니다.
TorchServe가 정상적으로 실행되면 백그라운드에서 동작합니다. TorchServe의 동작을 중지하기 위한 명령어는 아래와 같습니다.
TorchServe는 기본 설정으로 Management API는 8081
번 포트를, Inference API는 8080
번 포트를 사용합니다.
아래의 Management API를 통해서 서빙되고 있는 모델 리스트를 확인할 수 있습니다.
| $ curl -X GET "http://localhost:8081/models"
|
정상적으로 동작할 경우 ResNet50
모델이 서빙되고 있음을 확인할 수 있습니다.
| {
"models": [
{
"modelName": "resnet50",
"modelUrl": "resnet50.mar"
}
]
}
|
TorchServe Inference API
기반 추론 요청
torchserve
로 서빙된 ResNet50 모델을 테스트하기 위해 TorchServe Inference API의 Prediction API를 이용하여 추론 요청을 합니다.
ResNet50
추론 요청에 이용할 샘플 이미지를 다운로드 합니다.
| $ wget https://rbln-public.s3.ap-northeast-2.amazonaws.com/images/tabby.jpg
|
curl을 이용하여 TorchServe Inference API
기반 추론 요청을 요청합니다.
| $ curl -X POST "http://127.0.0.1:8080/predictions/resnet50" \
-H "Content-Type: application/octet-stream" \
--data-binary @./tabby.jpg
|
정상적으로 동작할 경우 아래와 같이 응답이 확인됩니다.
고급 기능
배치 추론
(Batch Inference)
TorchServe는 여러 추론 요청을 모아 한 번에 처리하는 배치 추론
기능을 지원합니다.
배치 추론
설정
TorchServe에서 배치 추론
을 사용하기 위해 모델 설정에서 아래 2개 옵션을 지정해야 합니다.
batchSize
: 모델이 사용 가능한 최대 배치 사이즈
입니다.
maxBatchDeplay
: TorchServe에서 batchSize
만큼의 요청을 기다릴 수 있는 최대 지연 시간(ms)이며, 수신된 요청이 최대 배치 사이즈
만큼 수신되지 못한 채 설정된 지연시간을 초과하면, 현재까지 수신된 모든 요청을 Handler로 전달합니다.
config.properties
에 각 배치 설정을 batchSize
, maxBatchDelay
키 값으로 아래와 같이 설정합니다.
config_b4.properties |
---|
| default_workers_per_model=1
models={\
"resnet50":{\
"1.0":{\
"marName": "resnet50.mar",\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
|
모델 컴파일
RBLN 컴파일러는 다양한 입력 형상에 대해 모델을 여러 차례 컴파일하는 버케팅(Bucketing)
모델을 지원합니다. 이 모델을 배치 추론
에 활용하여 추론 성능을 최적화할 수 있습니다.
아래는 배치 범위 1-4
에 대해 버케팅
모델을 생성하는 예시 코드입니다:
| size = 224 # 이미지의 넓이와 높이
batches = [1, 2, 3, 4] # 지원되는 배치의 크기
input_infos = []
# 개별 배치 크기에 대한 입력 정보를 생성
for i, batch in enumerate(batches):
input_info = [("x", [batch, 3, size, size], "float32")]
input_infos.append(input_info)
# 미리 정의된 입력 정보를 기반으로 모델 컴파일
compiled_model = rebel.compile_from_torch(model, input_info=input_infos)
# 컴파일된 모델 저장
compiled_model.save("resnet50.rbln")
|
컴파일 된 모델을 저장할 때 파일 이름은 모델 핸들러
에서 불러오기 위해서 torch-model-archiver
의 --serialized-file
파라미터와 일치해야 합니다.
모델 핸들러
모델 핸들러에서 특정 배치 크기만큼 런타임을 생성하고, 생성된 런타임을 통해 제공된 입력 데이터 기반의 추론 연산을 정의할 수 있습니다.
resnet50_batch_handler.py |
---|
| # resnet50_batch_handler.py
import io
import os
import numpy as np
import PIL.Image as Image
import rebel # RBLN Runtime
import torch
from torchvision.models import ResNet50_Weights
from ts.torch_handler.base_handler import BaseHandler
class Resnet50Handler(BaseHandler):
def __init__(self):
self._context = None
self.initialized = False
self.model = None
self.weights = None
self.prep = None
self.batch_size = None
self.max_batch_size = None
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
model_dir = context.system_properties.get("model_dir")
serialized_file = context.manifest["model"].get("serializedFile")
self.max_batch_size = context.system_properties["batch_size"]
model_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_path):
raise RuntimeError(
f"[RBLN ERROR] File not found at the specified model_path({model_path})."
)
self.modules = []
compiled_model = rebel.RBLNCompiledModel(model_path)
for i in range(self.max_batch_size):
self.modules.append(compiled_model.create_runtime(input_info_index=i, tensor_type="pt"))
self.weights = ResNet50_Weights.DEFAULT
self.prep = self.weights.transforms()
self.initialized = True
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
self.batch_size = num_requests = len(data)
assert self.batch_size <= self.max_batch_size, print(
f"[RBLN][ERROR] Inputed batched number({self.batch_size})"
f" is over the batchSize({self.max_batch_size}) in configuration."
)
images = []
for i in range(num_requests):
input_data = data[i].get("data")
if input_data is None:
input_data = data[i].get("body")
assert input_data is not None, print(
"[RBLN][ERROR] Data not found with client request."
)
if not isinstance(input_data, (bytes, bytearray)):
raise ValueError("[RBLN][ERROR] Preprocessed data is not binary data.")
try:
image = Image.open(io.BytesIO(input_data))
except Exception as e:
raise ValueError(f"[RBLN][ERROR]Invalid image data: {e}")
batch = self.prep(image).unsqueeze(0)
images.append(batch.numpy())
preprocessed_data = np.concatenate(images, axis=0).copy()
return torch.from_numpy(preprocessed_data)
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
model_output = self.modules[self.batch_size - 1].run(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
category_names = []
chunky_batched_result = np.array_split(
inference_output, self.batch_size, axis=0
)
for result in chunky_batched_result:
score, class_id = torch.topk(result, 1, dim=1)
category_names.append(self.weights.meta["categories"][class_id])
return category_names
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
category_names = self.postprocess(model_output)
results = []
for idx, category_name in enumerate(category_names):
print("[RBLN][INFO][", idx, "] Top1 category: ", category_name)
results.append(f"result[{idx}] : {category_name}")
return results
|
모델 서빙
위에서 생성한 설정
, 모델
, 모델 핸들러
를 이용하여 "torch-model-archiver를 이용한 모델 아카이빙
", "torchserve 실행
" 방법과 동일하게 진행하여 모델 서빙을 시작합니다.
아래의 Management API
를 통해 설정이 정상적으로 적용되었는지 확인할 수 있습니다.
| $ curl -X GET "http://localhost:8081/models/resnet50"
|
응답에서 batchSize
와 maxBatchDelay
가 지정된 값으로 설정되었는지 확인할 수 있습니다.
| [
{
"modelName": "resnet50",
"modelVersion": "1.0",
"modelUrl": "resnet50.mar",
"runtime": "python",
"minWorkers": 1,
"maxWorkers": 1,
"batchSize": 4,
"maxBatchDelay": 100,
:
:
"workers": [
{
:
:
}
],
:
:
}
]
|