Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : 3 : #include "kompute/Tensor.hpp" 4 : 5 : #include "kompute/operations/OpTensorSyncLocal.hpp" 6 : 7 : namespace kp { 8 : 9 39 : OpTensorSyncLocal::OpTensorSyncLocal( 10 39 : const std::vector<std::shared_ptr<Tensor>>& tensors) 11 : { 12 78 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal constructor with params"); 13 : 14 39 : if (tensors.size() < 1) { 15 0 : throw std::runtime_error( 16 0 : "Kompute OpTensorSyncLocal called with less than 1 tensor"); 17 : } 18 : 19 39 : this->mTensors = tensors; 20 39 : } 21 : 22 78 : OpTensorSyncLocal::~OpTensorSyncLocal() 23 : { 24 78 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal destructor started"); 25 78 : } 26 : 27 : void 28 40 : OpTensorSyncLocal::record(const vk::CommandBuffer& commandBuffer) 29 : { 30 80 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal record called"); 31 : 32 120 : for (size_t i = 0; i < this->mTensors.size(); i++) { 33 80 : if (this->mTensors[i]->tensorType() == Tensor::TensorTypes::eDevice) { 34 : 35 78 : this->mTensors[i]->recordPrimaryBufferMemoryBarrier( 36 : commandBuffer, 37 : vk::AccessFlagBits::eShaderWrite, 38 : vk::AccessFlagBits::eTransferRead, 39 : vk::PipelineStageFlagBits::eComputeShader, 40 : vk::PipelineStageFlagBits::eTransfer); 41 : 42 78 : this->mTensors[i]->recordCopyFromDeviceToStaging(commandBuffer); 43 : 44 78 : this->mTensors[i]->recordPrimaryBufferMemoryBarrier( 45 : commandBuffer, 46 : vk::AccessFlagBits::eTransferWrite, 47 : vk::AccessFlagBits::eHostRead, 48 : vk::PipelineStageFlagBits::eTransfer, 49 : vk::PipelineStageFlagBits::eHost); 50 : } 51 : } 52 40 : } 53 : 54 : void 55 237 : OpTensorSyncLocal::preEval(const vk::CommandBuffer& /*commandBuffer*/) 56 : { 57 474 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal preEval called"); 58 237 : } 59 : 60 : void 61 237 : OpTensorSyncLocal::postEval(const vk::CommandBuffer& /*commandBuffer*/) 62 : { 63 474 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal postEval called"); 64 : 65 474 : KP_LOG_DEBUG("Kompute OpTensorSyncLocal mapping data into tensor local"); 66 237 : } 67 : 68 : }