Extending Kompute with Custom C++ Operations

Kompute provides an extensible architecture that allows its core components to be extended by building custom operations.

Building operations is intuitive, but it requires knowing some nuances about the order in which each of the operation's class functions is called as a sequence is executed.

These nuances are important for more advanced users of Kompute, as they give further intuition into the specific functions and components that the native operations (like OpTensorCreate, OpAlgoBase, etc.) contain and that define their specific behaviour.

Flow of Function Calls

The top-level operation that all operations inherit from is the kp::OpBase class. The “Core Native Operations”, such as kp::OpTensorCopy and kp::OpTensorCreate, inherit from this base operation class.

The kp::OpAlgoBase is another base operation, built specifically to let users create their own operations with custom shader logic (i.e. operations that require Compute Pipelines, DescriptorSets, etc.). The next section contains an example that shows how to extend the OpAlgoBase class.

Below you can find the functions an operation can implement, together with when each of them is called:

OpBase(…, tensors, freeTensors): Constructor for the class, where you can load or define resources such as shaders.

~OpBase(): Destructor that frees GPU resources (if owned). It should be used to manage any memory allocations created through the operation.

init(): Called by the Sequence / Manager during the record step. This function allows the relevant objects to be initialised within the operation.

record(): Called by the Sequence / Manager during the record step, after init(). In this function you can record commands directly into the vk::CommandBuffer.

preEval(): Called across all operations when the Sequence is evaluated, before the batch of recorded commands is dispatched to the GPU. This is useful, for example, if you need to copy data from local to host memory.

postEval(): Called across all operations after the Sequence has been evaluated. When running asynchronously, postEval() is called when you call evalAwait(), which is why it’s important to always run evalAwait() to ensure the process doesn’t end up in an inconsistent state.
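The call order above can be illustrated with a small, self-contained sketch that does not depend on Kompute. MockOpBase, LoggingOp, MockSequence and runLifecycleDemo are hypothetical names invented for this example; they only mimic the order of the hooks described above and are not the real kp::OpBase or kp::Sequence interfaces, which also deal with command buffers, fences and GPU resources.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an operation base class. It mirrors the hooks
// listed above but is NOT the real kp::OpBase interface.
class MockOpBase
{
  public:
    virtual ~MockOpBase() = default;
    virtual void init() = 0;
    virtual void record() = 0;
    virtual void preEval() = 0;
    virtual void postEval() = 0;
};

// Hypothetical operation that logs every hook so the call order is visible.
class LoggingOp : public MockOpBase
{
  public:
    LoggingOp(std::string name, std::vector<std::string>& log)
      : name_(std::move(name)), log_(log) {}

    void init() override { log_.push_back(name_ + ".init"); }
    void record() override { log_.push_back(name_ + ".record"); }
    void preEval() override { log_.push_back(name_ + ".preEval"); }
    void postEval() override { log_.push_back(name_ + ".postEval"); }

  private:
    std::string name_;
    std::vector<std::string>& log_;
};

// Hypothetical sequence that drives the hooks in the order described above.
class MockSequence
{
  public:
    explicit MockSequence(std::vector<std::string>& log) : log_(log) {}

    // Record step: init() and then record() on the operation being added.
    void record(std::shared_ptr<MockOpBase> op)
    {
        assert(op != nullptr);
        op->init();
        op->record();
        ops_.push_back(std::move(op));
    }

    // preEval() runs on every operation before the batch is "submitted";
    // postEval() is deferred until evalAwait().
    void evalAsync()
    {
        for (auto& op : ops_) op->preEval();
        log_.push_back("submit");
        pending_ = true;
    }

    // Only here does postEval() run, so skipping evalAwait() means the
    // operations never get their postEval() call.
    void evalAwait()
    {
        if (!pending_) return;
        log_.push_back("wait");
        for (auto& op : ops_) op->postEval();
        pending_ = false;
    }

    // In this mock, a synchronous eval is evalAsync() followed by evalAwait().
    void eval()
    {
        evalAsync();
        evalAwait();
    }

  private:
    std::vector<std::string>& log_;
    std::vector<std::shared_ptr<MockOpBase>> ops_;
    bool pending_ = false;
};

// Records two operations, evaluates them asynchronously and returns the log.
std::vector<std::string> runLifecycleDemo()
{
    std::vector<std::string> log;
    MockSequence seq(log);
    seq.record(std::make_shared<LoggingOp>("a", log));
    seq.record(std::make_shared<LoggingOp>("b", log));
    seq.evalAsync();
    seq.evalAwait();
    return log;
}
```

runLifecycleDemo() returns a.init, a.record, b.init, b.record, a.preEval, b.preEval, submit, wait, a.postEval, b.postEval: both operations are initialised and recorded first, every preEval() runs before the submission, and no postEval() runs until evalAwait() is called.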

Simple Operation Extending OpAlgoBase

You can find an example in the Advanced Examples documentation section that shows how to create your own custom operation.

You can also see an implementation in the codebase through the OpMult class:

// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "kompute/Core.hpp"

#include "ShaderOpMult.hpp"

#include "kompute/Algorithm.hpp"
#include "kompute/Tensor.hpp"

#include "kompute/operations/OpAlgoDispatch.hpp"

namespace kp {

/**
 * Operation that performs multiplication on two tensors and outputs the
 * result on a third tensor.
 */
class OpMult : public OpAlgoDispatch
{
  public:
    /**
     * Default constructor with parameters that provides the bare minimum
     * requirements for the operations to be able to create and manage their
     * sub-components.
     *
     * @param tensors Tensors that are to be used in this operation
     * @param algorithm An algorithm that will be overridden with the OpMult
     * shader data and the tensors provided, of which exactly 3 are expected
    OpMult(std::vector<std::shared_ptr<Tensor>> tensors,
           std::shared_ptr<Algorithm> algorithm)
      : OpAlgoDispatch(algorithm)
    {
        KP_LOG_DEBUG("Kompute OpMult constructor with params");

        // OpMult works on exactly three tensors: two inputs and one output
        if (tensors.size() != 3) {
            throw std::runtime_error(
              "Kompute OpMult expected 3 tensors but got " +
              std::to_string(tensors.size()));
        }

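        // Copy the precompiled SPIR-V words from ShaderOpMult.hpp into a vector
        const std::vector<uint32_t> spirv = std::vector<uint32_t>(
          SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end());

        // Rebuild the algorithm with the three tensors and the OpMult shader
        algorithm->rebuild<>(tensors, spirv);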
    }

    /**
     * Default destructor, which is in charge of destroying the algorithm
     * components but does not destroy the underlying tensors
     */
    ~OpMult() override { KP_LOG_DEBUG("Kompute OpMult destructor started"); }
};

} // End namespace kp
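The OpMult constructor above does two small things before handing off to the algorithm: it checks that exactly three tensors were passed, and it copies the precompiled SPIR-V words into a std::vector<uint32_t>. The sketch below isolates those two steps so they compile without Kompute. FakeTensor, kFakeShaderSpv, checkTensorCount and loadSpirv are hypothetical names, and kFakeShaderSpv is a placeholder rather than a real shader (only its first word, the SPIR-V magic number, is meaningful). It also assumes, without confirming, that the generated SHADEROPMULT_COMP_SPV is a container of 32-bit words exposing begin()/end().

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical placeholder for kp::Tensor; only used to count arguments.
struct FakeTensor
{};

// Hypothetical stand-in for a generated shader header like ShaderOpMult.hpp.
// Only the first word (the SPIR-V magic number) is meaningful; the rest are
// placeholders, so this is NOT a valid shader module.
static const std::array<std::uint32_t, 4> kFakeShaderSpv = { 0x07230203u, 0u, 0u, 0u };

// Step 1: reject anything other than exactly three tensors
// (two inputs and one output), mirroring the OpMult constructor check.
void checkTensorCount(const std::vector<std::shared_ptr<FakeTensor>>& tensors)
{
    if (tensors.size() != 3) {
        throw std::runtime_error("expected 3 tensors but got " +
                                 std::to_string(tensors.size()));
    }
}

// Step 2: copy the compiled SPIR-V words into a std::vector, the form the
// OpMult example passes to algorithm->rebuild<>().
std::vector<std::uint32_t> loadSpirv()
{
    std::vector<std::uint32_t> spirv(kFakeShaderSpv.begin(), kFakeShaderSpv.end());
    // Sanity check: every SPIR-V module starts with the magic number 0x07230203.
    assert(!spirv.empty() && spirv[0] == 0x07230203u);
    return spirv;
}
```

Calling checkTensorCount with anything other than three tensors throws std::runtime_error, just as OpMult does; in the real operation the vector returned by the copy step is passed to algorithm->rebuild<>() together with the tensors.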

The implementation then defines the functions that perform the actions described above: