Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : #pragma once 3 : 4 : #include "kompute/Algorithm.hpp" 5 : #include "kompute/Core.hpp" 6 : #include "kompute/Tensor.hpp" 7 : #include "kompute/operations/OpBase.hpp" 8 : 9 : namespace kp { 10 : 11 : /** 12 : * Operation that provides a general abstraction that simplifies the use of 13 : * algorithm and parameter components which can be used with shaders. 14 : * By default it enables the user to provide a dynamic number of tensors 15 : * which are then passed as inputs. 16 : */ 17 : class OpAlgoDispatch : public OpBase 18 : { 19 : public: 20 : /** 21 : * Constructor that stores the algorithm to use as well as the relevant 22 : * push constants to override when recording. 23 : * 24 : * @param algorithm The algorithm object to use for dispatch 25 : * @param pushConstants The push constants to use for override 26 : */ 27 : template<typename T = float> 28 46 : OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm, 29 : const std::vector<T>& pushConstants = {}) 30 46 : { 31 92 : KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor"); 32 : 33 46 : this->mAlgorithm = algorithm; 34 : 35 46 : if (pushConstants.size()) { 36 11 : uint32_t memorySize = sizeof(decltype(pushConstants.back())); 37 11 : uint32_t size = pushConstants.size(); 38 11 : uint32_t totalSize = size * memorySize; 39 11 : this->mPushConstantsData = malloc(totalSize); 40 11 : memcpy(this->mPushConstantsData, pushConstants.data(), totalSize); 41 11 : this->mPushConstantsDataTypeMemorySize = memorySize; 42 11 : this->mPushConstantsSize = size; 43 : } 44 46 : } 45 : 46 : /** 47 : * Default destructor, which is in charge of destroying the algorithm 48 : * components but does not destroy the underlying tensors 49 : */ 50 : virtual ~OpAlgoDispatch() override; 51 : 52 : /** 53 : * This records the commands that are to be sent to the GPU. This includes 54 : * the barriers that ensure the memory has been copied before going in and 55 : * out of the shader, as well as the dispatch operation that sends the 56 : * shader processing to the gpu. This function also records the GPU memory 57 : * copy of the output data for the staging buffer so it can be read by the 58 : * host. 59 : * 60 : * @param commandBuffer The command buffer to record the command into. 61 : */ 62 : virtual void record(const vk::CommandBuffer& commandBuffer) override; 63 : 64 : /** 65 : * Does not perform any preEval commands. 66 : * 67 : * @param commandBuffer The command buffer to record the command into. 68 : */ 69 : virtual void preEval(const vk::CommandBuffer& commandBuffer) override; 70 : 71 : /** 72 : * Does not perform any postEval commands. 73 : * 74 : * @param commandBuffer The command buffer to record the command into. 75 : */ 76 : virtual void postEval(const vk::CommandBuffer& commandBuffer) override; 77 : 78 : private: 79 : // -------------- ALWAYS OWNED RESOURCES 80 : std::shared_ptr<Algorithm> mAlgorithm; 81 : void* mPushConstantsData = nullptr; 82 : uint32_t mPushConstantsDataTypeMemorySize = 0; 83 : uint32_t mPushConstantsSize = 0; 84 : }; 85 : 86 : } // End namespace kp