Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : 3 : #include "kompute/operations/OpTensorSyncDevice.hpp" 4 : 5 : namespace kp { 6 : 7 36 : OpTensorSyncDevice::OpTensorSyncDevice( 8 36 : const std::vector<std::shared_ptr<Tensor>>& tensors) 9 : { 10 72 : KP_LOG_DEBUG("Kompute OpTensorSyncDevice constructor with params"); 11 : 12 36 : if (tensors.size() < 1) { 13 0 : throw std::runtime_error( 14 0 : "Kompute OpTensorSyncDevice called with less than 1 tensor"); 15 : } 16 : 17 36 : this->mTensors = tensors; 18 36 : } 19 : 20 72 : OpTensorSyncDevice::~OpTensorSyncDevice() 21 : { 22 72 : KP_LOG_DEBUG("Kompute OpTensorSyncDevice destructor started"); 23 : 24 36 : this->mTensors.clear(); 25 72 : } 26 : 27 : void 28 37 : OpTensorSyncDevice::record(const vk::CommandBuffer& commandBuffer) 29 : { 30 74 : KP_LOG_DEBUG("Kompute OpTensorSyncDevice record called"); 31 : 32 120 : for (size_t i = 0; i < this->mTensors.size(); i++) { 33 83 : if (this->mTensors[i]->tensorType() == Tensor::TensorTypes::eDevice) { 34 77 : this->mTensors[i]->recordCopyFromStagingToDevice(commandBuffer); 35 : } 36 : } 37 37 : } 38 : 39 : void 40 245 : OpTensorSyncDevice::preEval(const vk::CommandBuffer& /*commandBuffer*/) 41 : { 42 490 : KP_LOG_DEBUG("Kompute OpTensorSyncDevice preEval called"); 43 245 : } 44 : 45 : void 46 245 : OpTensorSyncDevice::postEval(const vk::CommandBuffer& /*commandBuffer*/) 47 : { 48 490 : KP_LOG_DEBUG("Kompute OpTensorSyncDevice postEval called"); 49 245 : } 50 : 51 : }