Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : #pragma once 3 : 4 : #include <fstream> 5 : 6 : #include "kompute/Core.hpp" 7 : 8 : #include "ShaderOpMult.hpp" 9 : 10 : #include "kompute/Algorithm.hpp" 11 : #include "kompute/Tensor.hpp" 12 : 13 : #include "kompute/operations/OpAlgoDispatch.hpp" 14 : 15 : namespace kp { 16 : 17 : /** 18 : * Operation that performs multiplication on two tensors and outpus on third 19 : * tensor. 20 : */ 21 : class OpMult : public OpAlgoDispatch 22 : { 23 : public: 24 : /** 25 : * Default constructor with parameters that provides the bare minimum 26 : * requirements for the operations to be able to create and manage their 27 : * sub-components. 28 : * 29 : * @param tensors Tensors that are to be used in this operation 30 : * @param algorithm An algorithm that will be overridden with the OpMult 31 : * shader data and the tensors provided which are expected to be 3 32 : */ 33 4 : OpMult(std::vector<std::shared_ptr<Tensor>> tensors, 34 : std::shared_ptr<Algorithm> algorithm) 35 4 : : OpAlgoDispatch(algorithm) 36 : { 37 8 : KP_LOG_DEBUG("Kompute OpMult constructor with params"); 38 : 39 4 : if (tensors.size() != 3) { 40 0 : throw std::runtime_error( 41 0 : "Kompute OpMult expected 3 tensors but got " + 42 0 : std::to_string(tensors.size())); 43 : } 44 : 45 : const std::vector<uint32_t> spirv = std::vector<uint32_t>( 46 4 : SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end()); 47 : 48 4 : algorithm->rebuild<>(tensors, spirv); 49 4 : } 50 : 51 : /** 52 : * Default destructor, which is in charge of destroying the algorithm 53 : * components but does not destroy the underlying tensors 54 : */ 55 16 : ~OpMult() override { KP_LOG_DEBUG("Kompute OpMult destructor started"); } 56 : }; 57 : 58 : } // End namespace kp