Kompute
OpAlgoDispatch.hpp
1 // SPDX-License-Identifier: Apache-2.0
2 #pragma once
3 
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <vector>

#include "kompute/Algorithm.hpp"
#include "kompute/Core.hpp"
#include "kompute/Tensor.hpp"
#include "kompute/operations/OpBase.hpp"
8 
9 namespace kp {
10 
17 class OpAlgoDispatch : public OpBase
18 {
19  public:
27  template<typename T = float>
28  OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm,
29  const std::vector<T>& pushConstants = {})
30  {
31  KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor");
32 
33  this->mAlgorithm = algorithm;
34 
35  if (pushConstants.size()) {
36  uint32_t memorySize = sizeof(decltype(pushConstants.back()));
37  uint32_t size = pushConstants.size();
38  uint32_t totalSize = size * memorySize;
39  this->mPushConstantsData = malloc(totalSize);
40  memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
41  this->mPushConstantsDataTypeMemorySize = memorySize;
42  this->mPushConstantsSize = size;
43  }
44  }
45 
50  virtual ~OpAlgoDispatch() override;
51 
62  virtual void record(const vk::CommandBuffer& commandBuffer) override;
63 
69  virtual void preEval(const vk::CommandBuffer& commandBuffer) override;
70 
76  virtual void postEval(const vk::CommandBuffer& commandBuffer) override;
77 
78  private:
79  // -------------- ALWAYS OWNED RESOURCES
80  std::shared_ptr<Algorithm> mAlgorithm;
81  void* mPushConstantsData = nullptr;
82  uint32_t mPushConstantsDataTypeMemorySize = 0;
83  uint32_t mPushConstantsSize = 0;
84 };
85 
86 } // End namespace kp
Definition: OpAlgoDispatch.hpp:18
virtual void preEval(const vk::CommandBuffer &commandBuffer) override
virtual ~OpAlgoDispatch() override
virtual void record(const vk::CommandBuffer &commandBuffer) override
OpAlgoDispatch(const std::shared_ptr< kp::Algorithm > &algorithm, const std::vector< T > &pushConstants={})
Definition: OpAlgoDispatch.hpp:28
virtual void postEval(const vk::CommandBuffer &commandBuffer) override
Definition: OpBase.hpp:19