Kompute
OpTensorSyncDevice.hpp
1 // SPDX-License-Identifier: Apache-2.0
2 #pragma once
3 
4 #include "kompute/Core.hpp"
5 #include "kompute/Tensor.hpp"
6 #include "kompute/operations/OpBase.hpp"
7 
8 namespace kp {
9 
18 class OpTensorSyncDevice : public OpBase
19 {
20  public:
28  OpTensorSyncDevice(const std::vector<std::shared_ptr<Tensor>>& tensors);
29 
34  ~OpTensorSyncDevice() override;
35 
42  void record(const vk::CommandBuffer& commandBuffer) override;
43 
49  virtual void preEval(const vk::CommandBuffer& commandBuffer) override;
50 
56  virtual void postEval(const vk::CommandBuffer& commandBuffer) override;
57 
58  private:
59  // -------------- ALWAYS OWNED RESOURCES
60  std::vector<std::shared_ptr<Tensor>> mTensors;
61 };
62 
63 } // End namespace kp
Definition: OpBase.hpp:19
Definition: OpTensorSyncDevice.hpp:19
virtual void preEval(const vk::CommandBuffer &commandBuffer) override
OpTensorSyncDevice(const std::vector< std::shared_ptr< Tensor >> &tensors)
void record(const vk::CommandBuffer &commandBuffer) override
~OpTensorSyncDevice() override
virtual void postEval(const vk::CommandBuffer &commandBuffer) override