LCOV - code coverage report
Current view: top level - src/include/kompute/operations - OpAlgoDispatch.hpp (source / functions)
Test: lcov.info
Date: 2024-01-20 13:42:20
Coverage summary (Hit / Total / Coverage): Lines: 13 / 13 / 100.0 %; Functions: 4 / 5 / 80.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: Apache-2.0
       2             : #pragma once
       3             : 
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <type_traits>
#include <vector>

#include "kompute/Algorithm.hpp"
#include "kompute/Core.hpp"
#include "kompute/Tensor.hpp"
#include "kompute/operations/OpBase.hpp"
       8             : 
       9             : namespace kp {
      10             : 
      11             : /**
      12             :  * Operation that provides a general abstraction that simplifies the use of
      13             :  * algorithm and parameter components which can be used with shaders.
      14             :  * By default it enables the user to provide a dynamic number of tensors
      15             :  * which are then passed as inputs.
      16             :  */
      17             : class OpAlgoDispatch : public OpBase
      18             : {
      19             :   public:
      20             :     /**
      21             :      * Constructor that stores the algorithm to use as well as the relevant
      22             :      * push constants to override when recording.
      23             :      *
      24             :      * @param algorithm The algorithm object to use for dispatch
      25             :      * @param pushConstants The push constants to use for override
      26             :      */
      27             :     template<typename T = float>
      28          46 :     OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm,
      29             :                    const std::vector<T>& pushConstants = {})
      30          46 :     {
      31          92 :         KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor");
      32             : 
      33          46 :         this->mAlgorithm = algorithm;
      34             : 
      35          46 :         if (pushConstants.size()) {
      36          11 :             uint32_t memorySize = sizeof(decltype(pushConstants.back()));
      37          11 :             uint32_t size = pushConstants.size();
      38          11 :             uint32_t totalSize = size * memorySize;
      39          11 :             this->mPushConstantsData = malloc(totalSize);
      40          11 :             memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
      41          11 :             this->mPushConstantsDataTypeMemorySize = memorySize;
      42          11 :             this->mPushConstantsSize = size;
      43             :         }
      44          46 :     }
      45             : 
      46             :     /**
      47             :      * Default destructor, which is in charge of destroying the algorithm
      48             :      * components but does not destroy the underlying tensors
      49             :      */
      50             :     virtual ~OpAlgoDispatch() override;
      51             : 
      52             :     /**
      53             :      * This records the commands that are to be sent to the GPU. This includes
      54             :      * the barriers that ensure the memory has been copied before going in and
      55             :      * out of the shader, as well as the dispatch operation that sends the
      56             :      * shader processing to the gpu. This function also records the GPU memory
      57             :      * copy of the output data for the staging buffer so it can be read by the
      58             :      * host.
      59             :      *
      60             :      * @param commandBuffer The command buffer to record the command into.
      61             :      */
      62             :     virtual void record(const vk::CommandBuffer& commandBuffer) override;
      63             : 
      64             :     /**
      65             :      * Does not perform any preEval commands.
      66             :      *
      67             :      * @param commandBuffer The command buffer to record the command into.
      68             :      */
      69             :     virtual void preEval(const vk::CommandBuffer& commandBuffer) override;
      70             : 
      71             :     /**
      72             :      * Does not perform any postEval commands.
      73             :      *
      74             :      * @param commandBuffer The command buffer to record the command into.
      75             :      */
      76             :     virtual void postEval(const vk::CommandBuffer& commandBuffer) override;
      77             : 
      78             :   private:
      79             :     // -------------- ALWAYS OWNED RESOURCES
      80             :     std::shared_ptr<Algorithm> mAlgorithm;
      81             :     void* mPushConstantsData = nullptr;
      82             :     uint32_t mPushConstantsDataTypeMemorySize = 0;
      83             :     uint32_t mPushConstantsSize = 0;
      84             : };
      85             : 
      86             : } // End namespace kp

Generated by: LCOV version 1.14