Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : 3 : #include "kompute/operations/OpTensorCopy.hpp" 4 : #include "kompute/Tensor.hpp" 5 : 6 : namespace kp { 7 : 8 8 : OpTensorCopy::OpTensorCopy(const std::vector<std::shared_ptr<Tensor>>& tensors) 9 : { 10 16 : KP_LOG_DEBUG("Kompute OpTensorCopy constructor with params"); 11 : 12 8 : this->mTensors = tensors; 13 : 14 8 : if (this->mTensors.size() < 2) { 15 1 : throw std::runtime_error( 16 2 : "Kompute OpTensorCopy called with less than 2 tensor"); 17 : } 18 : 19 7 : kp::Tensor::TensorDataTypes dataType = this->mTensors[0]->dataType(); 20 7 : uint32_t size = this->mTensors[0]->size(); 21 22 : for (const std::shared_ptr<Tensor>& tensor : tensors) { 22 15 : if (tensor->dataType() != dataType) { 23 0 : throw std::runtime_error(fmt::format( 24 : "Attempting to copy tensors of different types from {} to {}", 25 0 : Tensor::toString(dataType), 26 0 : Tensor::toString(tensor->dataType()))); 27 : } 28 15 : if (tensor->size() != size) { 29 0 : throw std::runtime_error(fmt::format( 30 : "Attempting to copy tensors of different sizes from {} to {}", 31 : size, 32 0 : tensor->size())); 33 : } 34 : } 35 9 : } 36 : 37 14 : OpTensorCopy::~OpTensorCopy() 38 : { 39 14 : KP_LOG_DEBUG("Kompute OpTensorCopy destructor started"); 40 14 : } 41 : 42 : void 43 7 : OpTensorCopy::record(const vk::CommandBuffer& commandBuffer) 44 : { 45 14 : KP_LOG_DEBUG("Kompute OpTensorCopy record called"); 46 : 47 : // We iterate from the second tensor onwards and record a copy to all 48 15 : for (size_t i = 1; i < this->mTensors.size(); i++) { 49 8 : this->mTensors[i]->recordCopyFrom(commandBuffer, this->mTensors[0]); 50 : } 51 7 : } 52 : 53 : void 54 7 : OpTensorCopy::preEval(const vk::CommandBuffer& /*commandBuffer*/) 55 : { 56 14 : KP_LOG_DEBUG("Kompute OpTensorCopy preEval called"); 57 7 : } 58 : 59 : void 60 7 : OpTensorCopy::postEval(const vk::CommandBuffer& /*commandBuffer*/) 61 : { 62 14 : KP_LOG_DEBUG("Kompute OpTensorCopy postEval called"); 63 : 64 : // Do not copy on CPU side if source is storage tensor 65 7 : if (this->mTensors[0]->tensorType() == kp::Tensor::TensorTypes::eStorage) 66 : { 67 2 : KP_LOG_DEBUG("Kompute OpTensorCopy not copying tensor source given it's of eStorage type"); 68 1 : return; 69 : } 70 6 : void* data = this->mTensors[0]->rawData(); 71 : 72 : // Copy the data from the first tensor into all the tensors 73 13 : for (size_t i = 1; i < this->mTensors.size(); i++) { 74 7 : if (this->mTensors[i]->tensorType() == kp::Tensor::TensorTypes::eStorage) { 75 2 : KP_LOG_DEBUG("Kompute OpTensorCopy not copying to tensor dest given it's of eStorage type"); 76 1 : continue; 77 : } 78 6 : this->mTensors[i]->setRawData(data); 79 : } 80 : } 81 : 82 : }