PyTorch BERT-base

Overview

In this tutorial, we demonstrate how to compile and run inference with the Hugging Face BERT-base model for masked language modeling using both the RBLN Python API and PyTorch's torch.compile() API.

Setup & Installation

Before you begin, ensure that your system environment is properly configured and that all required packages are installed.

Note

The RBLN SDK is distributed as a .whl package. Note that rebel-compiler and rbln_driver require an RBLN Portal account.

Using RBLN Python API

Model Compilation

Load the pre-trained BERT-base model from Hugging Face, set it to evaluation mode, compile it using the RBLN compiler, and save the compiled model to local storage.

import torch  
from transformers import BertForMaskedLM  
import rebel  # RBLN Compiler  

# Instantiate the BERT-base model  
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=False)  
bert_model.eval()  

# Compile the model  
MAX_SEQ_LEN = 128  
input_info = [  
    ('input_ids', [1, MAX_SEQ_LEN], 'int64'),  
    ('attention_mask', [1, MAX_SEQ_LEN], 'int64'),  
    ('token_type_ids', [1, MAX_SEQ_LEN], 'int64'),  
]  
compiled_model = rebel.compile_from_torch(  
    bert_model,  
    input_info,  
)  
compiled_model.save('bert_base.rbln')  
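
The input_info list above fixes every input to shape [1, MAX_SEQ_LEN], so inputs supplied at inference time are expected to have that same length. As a rough, library-free illustration of what padding to a fixed length produces, the sketch below builds the three input lists by hand. The helper name pad_to_max_length and the example token IDs are hypothetical, and a padding ID of 0 is assumed; in practice the tokenizer produces these tensors for you.

```python
MAX_SEQ_LEN = 128
PAD_ID = 0  # assumption: padding token ID 0


def pad_to_max_length(token_ids, max_len=MAX_SEQ_LEN, pad_id=PAD_ID):
    """Hypothetical helper: pad one sequence to a fixed length.

    Returns the three lists named in input_info, each of length max_len.
    """
    if len(token_ids) > max_len:
        raise ValueError(f"sequence length {len(token_ids)} exceeds {max_len}")
    n_pad = max_len - len(token_ids)
    input_ids = token_ids + [pad_id] * n_pad
    # 1 marks real tokens, 0 marks padding
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    # Single-segment input: every position belongs to segment 0
    token_type_ids = [0] * max_len
    return input_ids, attention_mask, token_type_ids


# Made-up token IDs standing in for a short tokenized sentence
example_ids = [101, 1996, 3609, 103, 1012, 102]
input_ids, attention_mask, token_type_ids = pad_to_max_length(example_ids)
print(len(input_ids), sum(attention_mask), len(token_type_ids))
```

All three lists come out with length MAX_SEQ_LEN, and the attention mask marks only the original tokens as real.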

Model Loading and Inference

Prepare the input using BertTokenizer, load the compiled model with RBLN Runtime, run inference, and display the predicted word for the masked token.

import torch  
from transformers import BertTokenizer, pipeline  
import rebel  # RBLN Runtime  

MAX_SEQ_LEN = 128  
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  
text = 'the color of rose is [MASK].'  
inputs = tokenizer(text, return_tensors='pt', padding='max_length', max_length=MAX_SEQ_LEN)  

# Load the compiled model  
module = rebel.Runtime('bert_base.rbln', tensor_type='pt')  

# Run inference  
output = module(**inputs)  

# Display results  
unmasker = pipeline('fill-mask', model='bert-base-uncased', framework='pt')  
print(unmasker.postprocess({'input_ids': inputs.input_ids, 'logits': output}))  

The results will look like this:

[{'score': 0.23419028520584106, 'token': 2317, 'token_str': 'white', 'sequence': 'the color of rose is white.'}, {'score': 0.1072201207280159, 'token': 2417, 'token_str': 'red', 'sequence': 'the color of rose is red.'}, {'score': 0.07844392210245132, 'token': 2304, 'token_str': 'black', 'sequence': 'the color of rose is black.'}, {'score': 0.07031667977571487, 'token': 3756, 'token_str': 'yellow', 'sequence': 'the color of rose is yellow.'}, {'score': 0.051444780081510544, 'token': 2630, 'token_str': 'blue', 'sequence': 'the color of rose is blue.'}]
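
Each entry in this list is a dictionary with score, token, token_str, and sequence keys. As a small illustration of consuming the output, the sketch below picks the highest-scoring candidate from the first three entries shown above. The scores are rounded to four decimals and the sequence key is omitted for brevity; the variable names are illustrative.

```python
# First three entries of the output above, scores rounded to four decimals
results = [
    {'score': 0.2342, 'token': 2317, 'token_str': 'white'},
    {'score': 0.1072, 'token': 2417, 'token_str': 'red'},
    {'score': 0.0784, 'token': 2304, 'token_str': 'black'},
]

# Pick the candidate with the highest score
best = max(results, key=lambda r: r['score'])
print(f"Top prediction: {best['token_str']} ({best['score']:.2%})")
```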

Using torch.compile() API

Model and Input Preparation

Load the BERT-base model and prepare the input using BertTokenizer.

import torch  
from transformers import BertForMaskedLM, BertTokenizer  

if torch.__version__ >= '2.5.0':  
    torch._dynamo.config.inline_inbuilt_nn_modules = False  

# Instantiate the model  
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')  
bert_model.eval()  

MAX_SEQ_LEN = 128  
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  
text = 'The color of a rose is [MASK].'  
inputs = tokenizer(text, return_tensors='pt', padding='max_length', max_length=MAX_SEQ_LEN)  

Compilation and Inference

Compile the model using torch.compile() with the RBLN backend, trigger compilation on the first forward pass, and run inference on the input.

Note

To use the RBLN backend with torch.compile(), you must specify backend='rbln', which requires importing rebel before compilation.

import rebel  # Required to use TorchDynamo's "rbln" backend

compiled_model = torch.compile(  
    bert_model,  
    backend='rbln',  
    options={'cache_dir': './.cache'},  
    dynamic=False  
)  

# Trigger compilation with the first forward pass  
compiled_model(**inputs)  

# Run inference  
logits = compiled_model(**inputs).logits  

# Get the mask token index  
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]  

# Get the predicted token ID  
predicted_token_id = logits[0, mask_token_index].argmax(dim=-1)  

# Print the predicted word  
print(f'Predicted word: {tokenizer.decode(predicted_token_id)}')  

The results will look like this:

Predicted word: white
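
The last steps of the code above locate the [MASK] position, take the logits at that position, and pick the highest-scoring vocabulary entry. The following library-free sketch mirrors that logic on a toy example. The four-word vocabulary, token IDs, and logit values are made up for illustration and are not outputs of the model.

```python
# Toy stand-ins (all values are made up for illustration)
vocab = ['[MASK]', 'rose', 'white', 'red']
MASK_ID = 0
input_ids = [1, 0, 3]  # 'rose [MASK] red'

# One row of logits per input position, one column per vocabulary entry
logits = [
    [0.1, 2.0, 0.3, 0.2],
    [0.0, 0.5, 3.1, 2.4],
    [0.2, 0.1, 0.4, 1.9],
]

# Locate the masked position (mirrors the input_ids == mask_token_id comparison)
mask_index = input_ids.index(MASK_ID)

# Argmax over the vocabulary dimension at the masked position
row = logits[mask_index]
predicted_id = max(range(len(row)), key=row.__getitem__)

print(f'Predicted word: {vocab[predicted_id]}')
```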

Summary and References

This tutorial demonstrated how to compile the Hugging Face BERT-base model and run inference with it using both the RBLN Python API and PyTorch's torch.compile() API. The compiled model can then run efficiently on an RBLN NPU for masked language modeling.
