Kompute
Algorithm.hpp
1 // SPDX-License-Identifier: Apache-2.0
2 #pragma once
3 
4 #include "kompute/Core.hpp"
5 
6 #include "fmt/format.h"
7 #include "kompute/Tensor.hpp"
8 #include "logger/Logger.hpp"
9 
10 namespace kp {
11 
16 class Algorithm
17 {
18  public:
37  template<typename S = float, typename P = float>
38  Algorithm(std::shared_ptr<vk::Device> device,
39  const std::vector<std::shared_ptr<Tensor>>& tensors = {},
40  const std::vector<uint32_t>& spirv = {},
41  const Workgroup& workgroup = {},
42  const std::vector<S>& specializationConstants = {},
43  const std::vector<P>& pushConstants = {})
44  {
45  KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
46 
47  this->mDevice = device;
48 
49  if (tensors.size() && spirv.size()) {
50  KP_LOG_INFO(
51  "Kompute Algorithm initialising with tensor size: {} and "
52  "spirv size: {}",
53  tensors.size(),
54  spirv.size());
55  this->rebuild(tensors,
56  spirv,
57  workgroup,
58  specializationConstants,
59  pushConstants);
60  } else {
61  KP_LOG_INFO(
62  "Kompute Algorithm constructor with empty tensors and or "
63  "spirv so not rebuilding vulkan components");
64  }
65  }
66 
83  template<typename S = float, typename P = float>
84  void rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
85  const std::vector<uint32_t>& spirv,
86  const Workgroup& workgroup = {},
87  const std::vector<S>& specializationConstants = {},
88  const std::vector<P>& pushConstants = {})
89  {
90  KP_LOG_DEBUG("Kompute Algorithm rebuild started");
91 
92  this->mTensors = tensors;
93  this->mSpirv = spirv;
94 
95  if (specializationConstants.size()) {
96  if (this->mSpecializationConstantsData) {
97  free(this->mSpecializationConstantsData);
98  }
99  uint32_t memorySize =
100  sizeof(decltype(specializationConstants.back()));
101  uint32_t size = specializationConstants.size();
102  uint32_t totalSize = size * memorySize;
103  this->mSpecializationConstantsData = malloc(totalSize);
104  memcpy(this->mSpecializationConstantsData,
105  specializationConstants.data(),
106  totalSize);
107  this->mSpecializationConstantsDataTypeMemorySize = memorySize;
108  this->mSpecializationConstantsSize = size;
109  }
110 
111  if (pushConstants.size()) {
112  if (this->mPushConstantsData) {
113  free(this->mPushConstantsData);
114  }
115  uint32_t memorySize = sizeof(decltype(pushConstants.back()));
116  uint32_t size = pushConstants.size();
117  uint32_t totalSize = size * memorySize;
118  this->mPushConstantsData = malloc(totalSize);
119  memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
120  this->mPushConstantsDataTypeMemorySize = memorySize;
121  this->mPushConstantsSize = size;
122  }
123 
124  this->setWorkgroup(
125  workgroup, this->mTensors.size() ? this->mTensors[0]->size() : 1);
126 
127  // Descriptor pool is created first so if available then destroy all
128  // before rebuild
129  if (this->isInit()) {
130  this->destroy();
131  }
132 
133  this->createParameters();
134  this->createShaderModule();
135  this->createPipeline();
136  }
137 
143 
150  void recordDispatch(const vk::CommandBuffer& commandBuffer);
151 
158  void recordBindCore(const vk::CommandBuffer& commandBuffer);
159 
168  void recordBindPush(const vk::CommandBuffer& commandBuffer);
169 
176  bool isInit();
177 
186  void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1);
195  template<typename T>
196  void setPushConstants(const std::vector<T>& pushConstants)
197  {
198  uint32_t memorySize = sizeof(decltype(pushConstants.back()));
199  uint32_t size = pushConstants.size();
200 
201  this->setPushConstants(pushConstants.data(), size, memorySize);
202  }
203 
213  void setPushConstants(void* data, uint32_t size, uint32_t memorySize)
214  {
215 
216  uint32_t totalSize = memorySize * size;
217  uint32_t previousTotalSize =
218  this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize;
219 
220  if (totalSize != previousTotalSize) {
221  throw std::runtime_error(fmt::format(
222  "Kompute Algorithm push "
223  "constant total memory size provided is {} but expected {} bytes",
224  totalSize,
225  previousTotalSize));
226  }
227  if (this->mPushConstantsData) {
228  free(this->mPushConstantsData);
229  }
230 
231  this->mPushConstantsData = malloc(totalSize);
232  memcpy(this->mPushConstantsData, data, totalSize);
233  this->mPushConstantsDataTypeMemorySize = memorySize;
234  this->mPushConstantsSize = size;
235  }
236 
244  const Workgroup& getWorkgroup();
251  template<typename T>
252  const std::vector<T> getSpecializationConstants()
253  {
254  return { (T*)this->mSpecializationConstantsData,
255  ((T*)this->mSpecializationConstantsData) +
256  this->mSpecializationConstantsSize };
257  }
263  template<typename T>
264  const std::vector<T> getPushConstants()
265  {
266  return { (T*)this->mPushConstantsData,
267  ((T*)this->mPushConstantsData) + this->mPushConstantsSize };
268  }
274  const std::vector<std::shared_ptr<Tensor>>& getTensors();
275 
276  void destroy();
277 
278  private:
279  // -------------- NEVER OWNED RESOURCES
280  std::shared_ptr<vk::Device> mDevice;
281  std::vector<std::shared_ptr<Tensor>> mTensors;
282 
283  // -------------- OPTIONALLY OWNED RESOURCES
284  std::shared_ptr<vk::DescriptorSetLayout> mDescriptorSetLayout;
285  bool mFreeDescriptorSetLayout = false;
286  std::shared_ptr<vk::DescriptorPool> mDescriptorPool;
287  bool mFreeDescriptorPool = false;
288  std::shared_ptr<vk::DescriptorSet> mDescriptorSet;
289  bool mFreeDescriptorSet = false;
290  std::shared_ptr<vk::ShaderModule> mShaderModule;
291  bool mFreeShaderModule = false;
292  std::shared_ptr<vk::PipelineLayout> mPipelineLayout;
293  bool mFreePipelineLayout = false;
294  std::shared_ptr<vk::PipelineCache> mPipelineCache;
295  bool mFreePipelineCache = false;
296  std::shared_ptr<vk::Pipeline> mPipeline;
297  bool mFreePipeline = false;
298 
299  // -------------- ALWAYS OWNED RESOURCES
300  std::vector<uint32_t> mSpirv;
301  void* mSpecializationConstantsData = nullptr;
302  uint32_t mSpecializationConstantsDataTypeMemorySize = 0;
303  uint32_t mSpecializationConstantsSize = 0;
304  void* mPushConstantsData = nullptr;
305  uint32_t mPushConstantsDataTypeMemorySize = 0;
306  uint32_t mPushConstantsSize = 0;
307  Workgroup mWorkgroup;
308 
309  // Create util functions
310  void createShaderModule();
311  void createPipeline();
312 
313  // Parameters
314  void createParameters();
315 };
316 
317 } // End namespace kp
Definition: Algorithm.hpp:17
void recordBindCore(const vk::CommandBuffer &commandBuffer)
void rebuild(const std::vector< std::shared_ptr< Tensor >> &tensors, const std::vector< uint32_t > &spirv, const Workgroup &workgroup={}, const std::vector< S > &specializationConstants={}, const std::vector< P > &pushConstants={})
Definition: Algorithm.hpp:84
void setWorkgroup(const Workgroup &workgroup, uint32_t minSize=1)
void setPushConstants(const std::vector< T > &pushConstants)
Definition: Algorithm.hpp:196
void recordBindPush(const vk::CommandBuffer &commandBuffer)
Algorithm(std::shared_ptr< vk::Device > device, const std::vector< std::shared_ptr< Tensor >> &tensors={}, const std::vector< uint32_t > &spirv={}, const Workgroup &workgroup={}, const std::vector< S > &specializationConstants={}, const std::vector< P > &pushConstants={})
Definition: Algorithm.hpp:38
const Workgroup & getWorkgroup()
const std::vector< std::shared_ptr< Tensor > > & getTensors()
void recordDispatch(const vk::CommandBuffer &commandBuffer)
void setPushConstants(void *data, uint32_t size, uint32_t memorySize)
Definition: Algorithm.hpp:213
const std::vector< T > getPushConstants()
Definition: Algorithm.hpp:264
const std::vector< T > getSpecializationConstants()
Definition: Algorithm.hpp:252