Kompute
Sequence.hpp
1 // SPDX-License-Identifier: Apache-2.0
2 #pragma once
3 
4 #include "kompute/Core.hpp"
5 
6 #include "kompute/operations/OpAlgoDispatch.hpp"
7 #include "kompute/operations/OpBase.hpp"
8 
9 namespace kp {
10 
14 class Sequence : public std::enable_shared_from_this<Sequence>
15 {
16  public:
27  Sequence(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
28  std::shared_ptr<vk::Device> device,
29  std::shared_ptr<vk::Queue> computeQueue,
30  uint32_t queueIndex,
31  uint32_t totalTimestamps = 0);
37 
48  std::shared_ptr<Sequence> record(std::shared_ptr<OpBase> op);
49 
61  template<typename T, typename... TArgs>
62  std::shared_ptr<Sequence> record(
63  std::vector<std::shared_ptr<Tensor>> tensors,
64  TArgs&&... params)
65  {
66  std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
67  return this->record(op);
68  }
81  template<typename T, typename... TArgs>
82  std::shared_ptr<Sequence> record(std::shared_ptr<Algorithm> algorithm,
83  TArgs&&... params)
84  {
85  std::shared_ptr<T> op{ new T(algorithm,
86  std::forward<TArgs>(params)...) };
87  return this->record(op);
88  }
89 
96  std::shared_ptr<Sequence> eval();
97 
105  std::shared_ptr<Sequence> eval(std::shared_ptr<OpBase> op);
106 
116  template<typename T, typename... TArgs>
117  std::shared_ptr<Sequence> eval(std::vector<std::shared_ptr<Tensor>> tensors,
118  TArgs&&... params)
119  {
120  std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
121  return this->eval(op);
122  }
133  template<typename T, typename... TArgs>
134  std::shared_ptr<Sequence> eval(std::shared_ptr<Algorithm> algorithm,
135  TArgs&&... params)
136  {
137  std::shared_ptr<T> op{ new T(algorithm,
138  std::forward<TArgs>(params)...) };
139  return this->eval(op);
140  }
141 
150  std::shared_ptr<Sequence> evalAsync();
159  std::shared_ptr<Sequence> evalAsync(std::shared_ptr<OpBase> op);
169  template<typename T, typename... TArgs>
170  std::shared_ptr<Sequence> evalAsync(
171  std::vector<std::shared_ptr<Tensor>> tensors,
172  TArgs&&... params)
173  {
174  std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
175  return this->evalAsync(op);
176  }
187  template<typename T, typename... TArgs>
188  std::shared_ptr<Sequence> evalAsync(std::shared_ptr<Algorithm> algorithm,
189  TArgs&&... params)
190  {
191  std::shared_ptr<T> op{ new T(algorithm,
192  std::forward<TArgs>(params)...) };
193  return this->evalAsync(op);
194  }
195 
203  std::shared_ptr<Sequence> evalAwait(uint64_t waitFor = UINT64_MAX);
204 
209  void clear();
210 
215  std::vector<std::uint64_t> getTimestamps();
216 
221  void begin();
222 
227  void end();
228 
234  bool isRecording() const;
235 
242  bool isInit() const;
243 
249  void rerecord();
250 
257  bool isRunning() const;
258 
263  void destroy();
264 
265  private:
266  // -------------- NEVER OWNED RESOURCES
267  std::shared_ptr<vk::PhysicalDevice> mPhysicalDevice = nullptr;
268  std::shared_ptr<vk::Device> mDevice = nullptr;
269  std::shared_ptr<vk::Queue> mComputeQueue = nullptr;
270  uint32_t mQueueIndex = -1;
271 
272  // -------------- OPTIONALLY OWNED RESOURCES
273  std::shared_ptr<vk::CommandPool> mCommandPool = nullptr;
274  bool mFreeCommandPool = false;
275  std::shared_ptr<vk::CommandBuffer> mCommandBuffer = nullptr;
276  bool mFreeCommandBuffer = false;
277 
278  // -------------- ALWAYS OWNED RESOURCES
279  vk::Fence mFence;
280  std::vector<std::shared_ptr<OpBase>> mOperations{};
281  std::shared_ptr<vk::QueryPool> timestampQueryPool = nullptr;
282 
283  // State
284  bool mRecording = false;
285  bool mIsRunning = false;
286 
287  // Create functions
288  void createCommandPool();
289  void createCommandBuffer();
290  void createTimestampQueryPool(uint32_t totalTimestamps);
291 };
292 
293 } // End namespace kp
Definition: Sequence.hpp:15
bool isRecording() const
std::shared_ptr< Sequence > eval(std::shared_ptr< Algorithm > algorithm, TArgs &&... params)
Definition: Sequence.hpp:134
std::shared_ptr< Sequence > record(std::shared_ptr< Algorithm > algorithm, TArgs &&... params)
Definition: Sequence.hpp:82
std::shared_ptr< Sequence > evalAwait(uint64_t waitFor=UINT64_MAX)
std::shared_ptr< Sequence > record(std::vector< std::shared_ptr< Tensor >> tensors, TArgs &&... params)
Definition: Sequence.hpp:62
std::vector< std::uint64_t > getTimestamps()
std::shared_ptr< Sequence > evalAsync(std::shared_ptr< Algorithm > algorithm, TArgs &&... params)
Definition: Sequence.hpp:188
std::shared_ptr< Sequence > eval()
bool isRunning() const
std::shared_ptr< Sequence > eval(std::vector< std::shared_ptr< Tensor >> tensors, TArgs &&... params)
Definition: Sequence.hpp:117
std::shared_ptr< Sequence > eval(std::shared_ptr< OpBase > op)
void rerecord()
bool isInit() const
std::shared_ptr< Sequence > evalAsync(std::vector< std::shared_ptr< Tensor >> tensors, TArgs &&... params)
Definition: Sequence.hpp:170
std::shared_ptr< Sequence > record(std::shared_ptr< OpBase > op)
std::shared_ptr< Sequence > evalAsync(std::shared_ptr< OpBase > op)
std::shared_ptr< Sequence > evalAsync()
Sequence(std::shared_ptr< vk::PhysicalDevice > physicalDevice, std::shared_ptr< vk::Device > device, std::shared_ptr< vk::Queue > computeQueue, uint32_t queueIndex, uint32_t totalTimestamps=0)
void destroy()