Text Generation

This tutorial explains how to deploy a Llama3-8b model using the RBLN SDK C/C++ Runtime API. The model is compiled with the RBLN SDK Python API, and the resulting *.rbln files are deployed with the RBLN SDK C/C++ Runtime API.

This approach combines the ease of model preparation in Python with the performance benefits of C/C++ for inference.

The tutorial is divided into two parts:

  1. How to compile the Llama3-8b model with the RBLN SDK Python API
  2. How to run inference on the compiled model with the RBLN SDK C/C++ Runtime API

Prerequisites

Before starting, ensure you have installed the packages used in this tutorial: the RBLN SDK Python API with optimum-rbln for compilation, transformers (plus torch and numpy) for pre- and post-processing, and the RBLN SDK C/C++ Runtime library together with CMake 3.26 or later for building the inference example.

Step 1. Compiling the Model

The RBLN Python API handles both compilation and inference, while the RBLN SDK C/C++ Runtime API is optimized solely for inference.

Here, we use the RBLN Python API for model compilation and the RBLN SDK C/C++ Runtime API for inference.

Compile the Model

Import the RBLNLlamaForCausalLM class from optimum-rbln. Then, use the from_pretrained() method to download the Llama3-8b model from the Hugging Face Hub and compile it.

The compiled model can be saved to disk using the model.save_pretrained() method. The compiled model files, prefill.rbln and decoder.rbln, will be located in the Meta-Llama-3-8B-Instruct directory.

# compile.py

import os
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id=model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=8192,
    rbln_tensor_parallel_size=4,
)

# Save the compiled model
model.save_pretrained(os.path.basename(model_id))
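
Optionally, you can add a quick sanity check that compilation produced both artifacts. The minimal sketch below (the file name check_artifacts.py is only illustrative) verifies that prefill.rbln and decoder.rbln exist in the directory written by save_pretrained().

# check_artifacts.py (optional sanity check)

import os

compiled_dir = "Meta-Llama-3-8B-Instruct"  # matches os.path.basename(model_id)
for name in ("prefill.rbln", "decoder.rbln"):
    path = os.path.join(compiled_dir, name)
    # Both files should be present after model.save_pretrained() completes
    print(f"{path}: {'found' if os.path.isfile(path) else 'missing'}")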

Generate Input Data

Tokenize the input using AutoTokenizer from the transformers library.

Note

This tutorial demonstrates how to use the RBLN SDK C/C++ Runtime API for Llama3-8b. Because the example focuses on C/C++ based inference, pre- and post-processing (i.e., tokenization and detokenization) are handled with Python APIs.

The input binary file, c_input_ids.bin, can be generated with the following Python script.

# pre_process.py

from transformers import AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
input_text = "Hey, are you conscious? Can you talk to me?"
batch_size = 1

# Prepare inputs
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
conversation = [[{"role": "user", "content": input_text}]] * batch_size
text = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.numpy()

# Save the generated input binary files
input_ids.tofile("c_input_ids.bin")

After the script runs successfully, you can find the generated input binary file c_input_ids.bin in the working directory.
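
If you want to inspect the file, the minimal sketch below (check_input_ids.py is only an illustrative name) reloads the binary as a flat int64 array and decodes it back to text, assuming the same model_id and prompt as pre_process.py. The reported token count should match the input tensor shape used later in llama_main.cc (1 x 23 for this prompt).

# check_input_ids.py (optional check)

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# ndarray.tofile() wrote a flat array of int64 token IDs
input_ids = np.fromfile("c_input_ids.bin", dtype=np.int64)
print("number of tokens:", input_ids.shape[0])
print(tokenizer.decode(input_ids))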

Step 2. Perform Inference with the RBLN SDK C/C++ Runtime API

Now, we can deploy the model using the RBLN SDK C/C++ Runtime API. This involves loading the model, running inference, and decoding the output.

Note

${YOUR_SAMPLE_PATH} refers to the directory containing the CMake file and inference code.

Prepare CMake Build Script

The following CMake script describes the dependencies on external packages and how to link them with our example application code.

# CMakeLists.txt

cmake_minimum_required(VERSION 3.26)
project(llama_binding_example LANGUAGES CXX)  # project name here is only an example

# Collect all source files (llama_main.cc and llama_class_example.cc)
file(GLOB SOURCE_FILES "*.cc")

# Define executable
add_executable(llama_binding ${SOURCE_FILES})

# Link RBLN runtime
find_package(rbln CONFIG REQUIRED)
target_link_libraries(llama_binding rbln::rbln_runtime)

# Add header files directory
target_include_directories(llama_binding PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)

Prepare Code for Inference

Here are the main implementation files for inference; the details of each are as follows:

  • llama_main.cc: Inference application that handles model loading and execution workflow
  • llama_tensor_example.hpp: Custom Tensor class (data structure and basic operators) implementation
  • llama_class_example.hpp/cc: Inference wrapper using RBLN SDK C/C++ Runtime API, including:
    • Prefill and decode stage management of Llama3-8b
    • Input/output buffer handling
    • Model execution flow control
  • llama_tensor_op_example.hpp: A set of tensor manipulation operations
llama_main.cc
#include "llama_class_example.hpp"

int main() {
  LLamaClass llama_cls;
  // Init Model configuration
  llama_cls.InitConfig();

  // Create Model & Runtime
  llama_cls.Prepare();

  // Init LLamaClass
  llama_cls.Init();

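  // The 1 x 23 shape corresponds to the number of tokens in c_input_ids.bin for this example prompt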
  auto input_ids = Tensor<int64_t>(1, 23);
  if (!LoadBinary<int64_t>(llama_cls.GetIdsPath(), input_ids)) {
    std::cout << "Failed to load " << llama_cls.GetIdsPath() << std::endl;
    return -1;
  }

  auto past_cached_length = Tensor<int32_t>();
  llama_cls.PrepareInput(input_ids, past_cached_length);

  // Process of Prefill phase
  llama_cls.DoPrefill();

  // Process of Decode phase
  llama_cls.DoDecode();

  // Generate c_text2text_generation_gen_id.bin
  llama_cls.GenerateBinary();

  // Reset LLamaClass for iteration
  llama_cls.Reset();

  // Deinit LLamaClass
  llama_cls.DeInit();

  return 0;
}
llama_class_example.hpp
#ifndef RBLN_LLAMA_H
#define RBLN_LLAMA_H

#include <assert.h>
#include <rbln/rbln.h>

#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

#include "llama_tensor_example.hpp"
#include "llama_tensor_op_example.hpp"

constexpr uint32_t kDecodeInputCount = 3;
constexpr uint32_t kPrefillInputCount = 4;

template <typename T>
bool LoadBinary(const std::string &filename, Tensor<T> &data) {
  std::ifstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cout << "Could not open file: " + filename << std::endl;
    return false;
  }

  file.seekg(0, std::ios::end);
  const size_t fileSize = file.tellg();
  file.seekg(0, std::ios::beg);

  if (fileSize % sizeof(T) != 0) {
    std::cout << "File size(" << fileSize << ") is not a multiple of data type size("
              << sizeof(T) << ")" << std::endl;
    return false;
  }

  file.read(const_cast<char *>(static_cast<const char *>(data.GetData())),
            fileSize);
  if (file.fail()) {
    std::cout << "Failed to read file: " << filename << std::endl;
    return false;
  }
  return true;
}

int WriteToFile(const std::string &filePath, const void *data,
                uint32_t data_len);

class LLamaClass {
public:
  LLamaClass() = default;
  ~LLamaClass() = default;

  // Init Model configuration
  void InitConfig() {
    prefill_id_ = "${YOUR_SAMPLE_PATH}/Meta-Llama-3-8B-Instruct/prefill.rbln";
    dec_id_ = "${YOUR_SAMPLE_PATH}/Meta-Llama-3-8B-Instruct/decoder.rbln";
    input_ids_path_ = "${YOUR_SAMPLE_PATH}/c_input_ids.bin";
    batch_size_ = 1;
    max_seq_len_ = 8192;
    prefill_chunk_size_ = 128;
  }

  // Init LLamaClass
  void Init();

  // Reset LLamaClass for iteration
  void Reset();

  // Deinit LLamaClass
  void DeInit();

  // Create Model & Runtime
  void Prepare();

  // Process of Prefill phase
  void DoPrefill();

  // Process of Decode phase
  void DoDecode();

  // Generate c_text2text_generation_gen_id.bin
  void GenerateBinary();

  template <typename T0, typename T1>
  void PrepareInput(Tensor<T0> &input_ids, Tensor<T1> &v0) {
    if (!v0.GetSize()) {
      auto input_tensors = input_ids;
      auto batch_size = input_tensors.GetRows();
      std::vector<Tensor<int64_t>> l_input_tensors;
      std::vector<Tensor<int32_t>> cache_positions;
      auto past_cached_length = Tensor<int32_t>(batch_size, 1);

      for (int i = 0; i < batch_size; i++) {
        auto input_tensor =
            tensor_ops::Reshape(input_tensors, input_tensors.GetCols());

        auto valid_len = input_tensor.GetCols();
        auto cache_position = Tensor<int32_t>();
        tensor_ops::Arange(cache_position, 0, valid_len);
        tensor_ops::Reshape(cache_position, 1, valid_len);

        past_cached_length[i] = valid_len;
        l_input_tensors.emplace_back(tensor_ops::UnSqueeze(input_tensor));
        cache_positions.emplace_back(tensor_ops::UnSqueeze(cache_position));
      }
      mdl_input_ = ModelInput{l_input_tensors[0], cache_positions[0],
                              past_cached_length};
    } else {
      auto input_tensor = tensor_ops::SelectLastColumn(input_ids);
      auto cache_positions = v0;
      auto past_cached_length = v0 + 1;
      mdl_input_ =
          ModelInput{input_tensor, cache_positions, past_cached_length};
    }
  }

  const std::string &GetIdsPath() { return input_ids_path_; }

  RBLNModel *prefill_mdl_;
  RBLNModel *dec_mdl_;
  RBLNRuntime *prefill_rt_;
  RBLNRuntime *dec_rt_;

private:
  void ForwardPrefill();
  void ForwardDecode();

  typedef struct {
    Tensor<int64_t> input_ids;
    Tensor<int32_t> cache_position;
    Tensor<int32_t> past_cached_length;
  } ModelInput;

  ModelInput mdl_input_;
  int max_seq_len_;
  int batch_size_;
  int prefill_chunk_size_;
  bool unfinished_sequences_;

  std::string prefill_id_;
  std::string dec_id_;
  std::string input_ids_path_;

  Tensor<float> output_logits_;
  Tensor<int64_t> input_ids_;
};

#endif
llama_class_example.cc
#include "llama_class_example.hpp"
#include <iostream>

int WriteToFile(const std::string &filePath, const void *data,
                uint32_t data_len) {
  std::ofstream fout;
  fout.open(filePath, std::ios::out | std::ios::binary);
  if (fout.is_open()) {
    fout.write((const char *)data, data_len);
    fout.close();
    return 1;
  }
  return 0;
}

// Prefill forward method
void LLamaClass::ForwardPrefill() {
  // Get input tensors and cache position
  auto input_tensors = mdl_input_.input_ids;
  auto cache_position = mdl_input_.cache_position;
  // Get query length (number of tokens in the input sequence)
  int query_length = input_tensors.GetCols();

  // Process input in chunks (divided into chunks of size prefill_chunk_size)
  for (auto step = 0; step < query_length; step += prefill_chunk_size_) {
    // If the last chunk is incomplete (remaining tokens less than chunk size)
    if ((step + prefill_chunk_size_) > query_length) {
      // Calculate and add necessary padding to the input tensor
      int padding_needed = step + prefill_chunk_size_ - query_length;
      input_tensors = tensor_ops::Pad(input_tensors, 0, padding_needed);

      // Extend cache positions (concatenate current cache positions with additional range)
      auto new_cache_position = tensor_ops::ConcatenateWithRange(
          cache_position, query_length, step + prefill_chunk_size_);

      // Slice input tensors and cache positions for the current chunk
      auto sliced_input_tensors = tensor_ops::VerticalSlicing(
          input_tensors, step, step + prefill_chunk_size_);
      auto sliced_cache_positions = tensor_ops::VerticalSlicing(
          new_cache_position, step, step + prefill_chunk_size_);

      // Create query index and empty block tables(with value 0) for KV-cache management
      Tensor<int> query_idx(query_length % prefill_chunk_size_ - 1);
      Tensor<int16_t> block_tables(0);

      // Check if prefill input count exceeds expected limit
      if (rbln_get_num_inputs(prefill_rt_) > kPrefillInputCount) {
        throw std::runtime_error(
            "You appear to be running on ATOM(RBLN-CA02). RSD is only "
            "available on ATOM+(RBLN-CA12). Check your NPU type with "
            "'rbln-stat' command.");
      }

      // Set inputs for the model runtime
      rbln_set_input(prefill_rt_, 0, sliced_input_tensors.GetData());
      rbln_set_input(prefill_rt_, 1, sliced_cache_positions.GetData());
      rbln_set_input(prefill_rt_, 2, query_idx.GetData());
      rbln_set_input(prefill_rt_, 3, block_tables.GetData());

      // Run the model
      rbln_run(prefill_rt_);

      // Get output logits and convert to tensor
      float *logit = static_cast<float *>(rbln_get_output(prefill_rt_, 0));
      auto layout = rbln_get_output_layout(prefill_rt_, 0);
      output_logits_ = Tensor<float>(logit, layout->shape[1], layout->shape[2]);
    }
  }

  // Predict the next token from logits using Argmax
  auto next_tokens = tensor_ops::GetArgmax<float, int64_t>(output_logits_);

  // Concatenate existing input IDs with the predicted next token
  input_ids_ = tensor_ops::Concatenate(mdl_input_.input_ids, next_tokens);
}

// Decoder forward method
void LLamaClass::ForwardDecode() {
  // Get input tensors for decoding from prefill step
  auto dec_input_tensors = mdl_input_.input_ids;
  // Get batch size from the number of rows in input tensors
  auto dec_batch_size = dec_input_tensors.GetRows();
  // Get cache positions for decoding from prefill step
  auto dec_cache_position = mdl_input_.cache_position;

  // For each item in the batch
  for (auto b_idx = 0; b_idx < dec_batch_size; b_idx++) {
    // Get the current decoding step
    auto decoding_step = dec_cache_position[b_idx];
  }

  // Initialize block tables for KV-cache management with shape of (batch_size, 1)
  Tensor<int16_t> block_tables(batch_size_, 1);

  // Check if decoder input count exceeds expected limit
  if (rbln_get_num_inputs(dec_rt_) > kDecodeInputCount) {
    throw std::runtime_error(
        "You appear to be running on ATOM(RBLN-CA02). RSD is only available on "
        "ATOM+(RBLN-CA12). Check your NPU type with 'rbln-stat' command.");
  }

  // Set inputs for decoder runtime
  rbln_set_input(dec_rt_, 0, dec_input_tensors.GetData());  
  rbln_set_input(dec_rt_, 1, dec_cache_position.GetData());
  rbln_set_input(dec_rt_, 2, block_tables.GetData());

  // Run the decoder
  rbln_run(dec_rt_);

  // Get output logits from the decoder
  float *dec_logit = static_cast<float *>(rbln_get_output(dec_rt_, 0));
  auto dec_layout = rbln_get_output_layout(dec_rt_, 0);
  // Convert output to tensor format
  output_logits_ =
      Tensor<float>(dec_logit, dec_layout->shape[1], dec_layout->shape[2]);
}

// Prefill forward wrapper
void LLamaClass::DoPrefill() {
  // Run the prefill phase to process the input sequence and generate the next token
  ForwardPrefill();
}

// Decoder forward wrapper
void LLamaClass::DoDecode() {
  while (unfinished_sequences_) {
    // Prepare input for the model with current token IDs and past cache info
    PrepareInput(input_ids_, mdl_input_.past_cached_length);
    // Run the decoder to get the next token logits
    ForwardDecode();
    // Get the next token using Argmax
    auto dec_next_tokens =
        tensor_ops::GetArgmax<float, int64_t>(output_logits_);
    // Append/Concatenate the new tokens to the existing sequence
    input_ids_ = tensor_ops::Concatenate(input_ids_, dec_next_tokens);

    auto stopping_criteria = [](const auto &array) -> bool {
      const int32_t eos_token_id = 128009;
      // Stop generation if EOS token is found at the last position
      if (array(0, array.GetCols() - 1) == eos_token_id)
        return false;
      return true;
    };
    unfinished_sequences_ = stopping_criteria(input_ids_);
  }
}

void LLamaClass::Init() {
  unfinished_sequences_ = true;
}

void LLamaClass::Reset() {
  output_logits_ = Tensor<float>();
  input_ids_ = Tensor<int64_t>();
}

void LLamaClass::DeInit() {
  // Destroy runtime
  rbln_destroy_runtime(prefill_rt_);
  rbln_destroy_runtime(dec_rt_);
  // Destroy model
  rbln_destroy_model(prefill_mdl_);
  rbln_destroy_model(dec_mdl_);
}

void LLamaClass::Prepare() {
  // Create prefill/decoder model
  prefill_mdl_ = rbln_create_model(prefill_id_.c_str());
  dec_mdl_ = rbln_create_model(dec_id_.c_str());

  // Create prefill/decoder runtime
  prefill_rt_ = rbln_create_runtime(prefill_mdl_, nullptr, 0, 0);
  dec_rt_ = rbln_create_runtime(dec_mdl_, nullptr, 0, 0);
}

void LLamaClass::GenerateBinary() {
  if (!WriteToFile("c_text2text_generation_gen_id.bin", input_ids_.GetData(),
                   input_ids_.GetSize() * sizeof(int64_t))) {
    std::cout << "Failed to save c_text2text_generation_gen_id.bin" << std::endl;
  }
}
llama_tensor_example.hpp
#ifndef RBLN_TENSOR_H
#define RBLN_TENSOR_H

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

template <typename T> class Tensor {
public:
  Tensor() : depth_(1), rows_(0), cols_(0) { array_.reserve(0); }
  Tensor(T val) : depth_(1), rows_(0), cols_(1) { array_.resize(1, T{val}); }
  Tensor(size_t row, size_t col) : depth_(1), rows_(row), cols_(col) {
    array_.resize(GetCapacity(), T{});
  }

  Tensor(const void *data, size_t row, size_t col)
      : depth_(1), rows_(row), cols_(col) {
    const T *ptr = static_cast<const T *>(data);
    array_.assign(ptr, ptr + GetCapacity());
  }

  Tensor(size_t depth, size_t row, size_t col)
      : depth_(depth), rows_(row), cols_(col) {
    array_.resize(GetCapacity(), T{});
  }

  ~Tensor() = default;

  Tensor(const Tensor &other) {
    array_ = other.array_;
    depth_ = other.depth_;
    rows_ = other.rows_;
    cols_ = other.cols_;
  }

  Tensor(Tensor &&other) {
    array_ = std::move(other.array_);
    depth_ = other.depth_;
    rows_ = other.rows_;
    cols_ = other.cols_;
  }

  T &operator[](size_t i) { return array_[i]; }
  T operator[](size_t i) const { return array_[i]; }

  Tensor &operator=(const Tensor &other) {
    if (this != &other) {
      array_ = other.array_;
      depth_ = other.depth_;
      rows_ = other.rows_;
      cols_ = other.cols_;
    }
    return *this;
  }

  T &operator()(size_t r_idx, size_t c_idx) {
    if (r_idx >= rows_ || c_idx >= cols_) {
      throw std::out_of_range("Index out of bounds");
    }
    return array_[cols_ * r_idx + c_idx];
  }

  T &operator()(size_t col) {
    if (col >= cols_) {
      throw std::out_of_range("Index out of bounds");
    }
    return array_[col];
  }

  T operator()(size_t r_idx, size_t c_idx) const {
    if (r_idx >= rows_ || c_idx >= cols_) {
      throw std::out_of_range("Index out of bounds");
    }
    return array_[cols_ * r_idx + c_idx];
  }

  Tensor operator+(T val) {
    Tensor ret(rows_, cols_);
    for (auto r = 0; r < rows_; r++) {
      for (auto c = 0; c < cols_; c++) {
        ret(r, c) = array_[r * cols_ + c] + val;
      }
    }
    return ret;
  }

  void *GetData() { return array_.data(); }
  size_t GetRows() const { return rows_; }
  size_t GetCols() const { return cols_; }
  size_t GetDepth() const { return depth_; }
  size_t GetSize() const { return array_.size(); }
  void Ones() { std::fill(array_.begin(), array_.end(), T{1}); }
  void Zeros() { std::fill(array_.begin(), array_.end(), T{0}); }

  size_t GetCapacity(size_t r, size_t c) const {
    return std::max<size_t>(1, r) * std::max<size_t>(1, c);
  }

  size_t GetCapacity() const {
    return GetCapacity(rows_, cols_) * std::max<size_t>(1, depth_);
  }

  void Resize(size_t row, size_t col) {
    rows_ = row;
    cols_ = col;
    array_.resize(GetCapacity(row, col));
  }

private:
  std::vector<T> array_;
  size_t depth_;
  size_t rows_;
  size_t cols_;
};
#endif
llama_tensor_op_example.hpp
#ifndef RBLN_LLAMA_OPS_H
#define RBLN_LLAMA_OPS_H

#include <stdexcept>

#include "llama_tensor_example.hpp"

// Tensor operations implementation
namespace tensor_ops {

template <typename T>
Tensor<T> Reshape(const Tensor<T> &tensor, int row, int col) {
  if (tensor.GetCapacity(row, col) != tensor.GetCapacity()) {
    throw std::runtime_error("Cannot reshape: total size must remain the same");
  }

  Tensor<T> ret(tensor);
  std::vector<T> temp(tensor.GetSize());
  for (size_t i = 0; i < row; ++i) {
    for (size_t j = 0; j < col; ++j) {
      size_t new_idx = i * col + j;
      size_t old_idx = (new_idx / col) * tensor.GetCols() + (new_idx % col);
      temp[new_idx] = tensor[old_idx];
    }
  }

  for (size_t i = 0; i < temp.size(); ++i) {
    ret[i] = temp[i];
  }
  ret.Resize(row, col);
  return ret;
}

template <typename T> Tensor<T> Reshape(const Tensor<T> &tensor, int col) {
  if (col != tensor.GetCapacity()) {
    throw std::runtime_error("Cannot reshape: total size must remain the same");
  }

  Tensor<T> ret(tensor);
  ret.Resize(0, col);
  return ret;
}

template <typename T> void Arange(Tensor<T> &tensor, int start, int stop) {
  tensor.Resize(0, stop - start);
  for (size_t i = 0; i < stop - start; ++i) {
    tensor[i] = static_cast<T>(start + i);
  }
}

template <typename T> Tensor<T> UnSqueeze(const Tensor<T> &tensor) {
  Tensor<T> ret(1, tensor.GetSize());
  for (size_t i = 0; i < tensor.GetSize(); ++i) {
    ret(0, i) = tensor[i];
  }
  return ret;
}

template <typename T> Tensor<T> SelectLastColumn(const Tensor<T> &tensor) {
  Tensor<T> result(tensor.GetRows(), 1);
  size_t last_col = tensor.GetCols() - 1;

  for (size_t i = 0; i < tensor.GetRows(); ++i) {
    result(i, 0) = tensor(i, last_col);
  }
  return result;
}

template <typename T>
Tensor<T> Pad(const Tensor<T> &tensor, size_t start_pos, size_t end_pos) {
  Tensor<T> padded(tensor.GetRows(), tensor.GetCols() + start_pos + end_pos);
  for (size_t i = 0; i < tensor.GetRows(); ++i) {
    for (size_t j = 0; j < tensor.GetCols(); ++j) {
      padded(i, start_pos + j) = tensor(i, j);
    }
  }
  return padded;
}

template <typename T>
Tensor<T> VerticalSlicing(Tensor<T> &tensor, size_t start_pos, size_t end_pos) {
  Tensor<T> ret(tensor);

  std::vector<T> temp(ret.GetCapacity(ret.GetRows(), (end_pos - start_pos)));
  for (size_t i = 0; i < ret.GetRows(); ++i) {
    for (size_t j = start_pos; j < end_pos; ++j) {
      temp[i * (end_pos - start_pos) + (j - start_pos)] = ret(i, j);
    }
  }
  for (size_t i = 0; i < temp.size(); ++i) {
    ret[i] = temp[i];
  }
  return ret;
}

template <typename T>
void SetCausalMask(Tensor<T> &tensor, const Tensor<T> &mask_tensor,
                   size_t start_pos, size_t end_pos) {
  if (end_pos > tensor.GetCols()) {
    throw std::out_of_range("Index range out of bounds");
  }
  for (size_t d = 0; d < tensor.GetDepth(); ++d) {
    for (size_t r = 0; r < tensor.GetRows(); ++r) {
      size_t base_idx = (d * tensor.GetRows() + r) * tensor.GetCols();
      for (size_t idx = start_pos; idx < end_pos; ++idx) {
        tensor[base_idx + idx] = mask_tensor(r, idx - start_pos);
      }
    }
  }
}

template <typename T>
Tensor<T> ConcatenateWithRange(const Tensor<T> &tensor, size_t start_pos,
                               size_t end_pos) {
  Tensor<T> range;
  tensor_ops::Arange(range, start_pos, end_pos);

  size_t total_cols = tensor.GetCols() + range.GetSize();
  Tensor<T> result(1, total_cols);
  for (size_t i = 0; i < tensor.GetCols(); ++i) {
    result(0, i) = tensor[i];
  }
  for (size_t i = 0; i < range.GetSize(); ++i) {
    result(0, tensor.GetCols() + i) = range[i];
  }
  return result;
}

template <typename T, typename T1>
Tensor<T1> GetArgmax(const Tensor<T> &tensor) {
  Tensor<T1> next_tokens(tensor.GetRows(), 1);
  for (size_t i = 0; i < tensor.GetRows(); ++i) {
    size_t max_idx = 0;
    T max_val = tensor(i, 0);

    for (size_t j = 1; j < tensor.GetCols(); ++j) {
      if (tensor(i, j) > max_val) {
        max_val = tensor(i, j);
        max_idx = j;
      }
    }
    next_tokens(i, 0) = static_cast<T1>(max_idx);
  }
  return next_tokens;
}

template <typename T>
Tensor<T> Concatenate(const Tensor<T> &tensor, const Tensor<T> &other) {
  Tensor<T> result(tensor.GetRows(), tensor.GetCols() + 1);
  for (size_t i = 0; i < tensor.GetRows(); ++i) {
    for (size_t j = 0; j < tensor.GetCols(); ++j) {
      result(i, j) = tensor(i, j);
    }
    result(i, tensor.GetCols()) = other(i, 0);
  }
  return result;
}

template <typename T>
void SetMaskAtPos(Tensor<T> &tensor, size_t pos, T value) {
  if (pos >= tensor.GetCols()) {
    throw std::out_of_range("Index out of bounds");
  }

  for (size_t i = 0; i < tensor.GetRows(); ++i) {
    tensor(i, pos) = value;
  }
}

template <typename T>
void SetMaskUpToPos(Tensor<T> &tensor, size_t batch_idx, size_t pos, T value) {
  if (pos > tensor.GetCols()) {
    throw std::out_of_range("Index out of bounds");
  }

  for (size_t r = 0; r < tensor.GetRows(); ++r) {
    for (size_t i = 0; i < pos; ++i) {
      tensor(r, i) = value;
    }
  }
}

template <typename T> Tensor<T> TriuMask(size_t row, size_t col) {
  Tensor<T> mask(row, col);
  mask.Ones();

  for (size_t i = 0; i < row; ++i) {
    for (size_t j = i + 1; j < col; ++j) {
      mask(i, j) = 0;
    }
  }
  return mask;
}

template <typename T>
Tensor<T> FilterByMask(const Tensor<T> &tensor, const Tensor<T> &mask,
                       size_t i) {
  size_t count = 0;
  for (size_t j = 0; j < mask.GetCols(); ++j) {
    if (mask(i, j) == 1) {
      count++;
    }
  }

  Tensor<T> result(1, count);
  size_t idx = 0;
  for (size_t j = 0; j < mask.GetCols(); ++j) {
    if (mask(i, j) == 1) {
      result[idx++] = tensor[j];
    }
  }
  return result;
}

} // namespace tensor_ops

#endif

Build with CMake

mkdir ${YOUR_SAMPLE_PATH}/build
cd ${YOUR_SAMPLE_PATH}/build
cmake ..
make

Run the Executable

After following all the steps above, the compiled executable binary, llama_binding, will be located in the build directory.

./llama_binding

Executing the binary will generate a c_text2text_generation_gen_id.bin file in the working directory. This file contains the token ID sequence generated by the Llama3-8b model. You can use the Python code below to decode it into human-readable text.

Generate Text from Output Data

# post_process.py

import numpy as np
import torch
from transformers import AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
input_text = "Hey, are you conscious? Can you talk to me?"
batch_size = 1

# Prepare inputs
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
conversation = [[{"role": "user", "content": input_text}]] * batch_size
text = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt", padding=True)
input_ids = inputs.input_ids
input_len = inputs.input_ids.shape[-1]

# Data decoding
output_sequence = torch.tensor(np.fromfile("c_text2text_generation_gen_id.bin", dtype=np.int64), dtype=torch.int64)
generated_texts = tokenizer.decode(
    output_sequence[input_len:], skip_special_tokens=True, clean_up_tokenization_spaces=True
)

print("--- input text ---")
print(input_text)
print("--- Decoded C Result ---")
print(generated_texts)

The result will look like this:

--- input text ---
Hey, are you conscious? Can you talk to me?
--- Decoded C Result ---
Hello! I'm an AI, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I don't have consciousness in the way that humans do, but I'm designed to be very responsive and interactive.

I can understand and respond to language, and I can even learn and improve over time based on the conversations I have with users like you. So, in a sense, I'm "awake" and ready to chat with you!

What would you like to talk about? Do you have a specific question or topic in mind, or do you just want to chat about something random? I'm here to listen and help if I can!