6 #include "kompute/Core.hpp"
8 #include "ShaderOpMult.hpp"
10 #include "kompute/Algorithm.hpp"
11 #include "kompute/Tensor.hpp"
13 #include "kompute/operations/OpAlgoDispatch.hpp"
// Constructs the OpMult operation from a list of tensors and an algorithm.
//
// NOTE(review): this listing carries the original file's line numbers as a
// leading token and skips some of them (e.g. 35-36 and 43-44), so the braces
// and any base-class initializer are not visible here.
//
// tensors   - must contain exactly 3 entries; their individual roles are not
//             visible here (presumably two inputs and one output - TODO
//             confirm against the shader).
// algorithm - receives the tensors and the shader words through rebuild<>().
//
// Throws std::runtime_error when tensors.size() != 3; the message includes
// the actual count.
33 OpMult(std::vector<std::shared_ptr<Tensor>> tensors,
34 std::shared_ptr<Algorithm> algorithm)
37 KP_LOG_DEBUG(
"Kompute OpMult constructor with params");
// Reject any tensor count other than 3 before the algorithm is touched.
39 if (tensors.size() != 3) {
40 throw std::runtime_error(
41 "Kompute OpMult expected 3 tensors but got " +
42 std::to_string(tensors.size()));
// Copy the SHADEROPMULT_COMP_SPV range into a std::vector<uint32_t>
// (presumably the compiled SPIR-V from ShaderOpMult.hpp - TODO confirm).
45 const std::vector<uint32_t> spirv = std::vector<uint32_t>(
46 SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end());
// Hand the tensors and the shader words to the algorithm.
48 algorithm->rebuild<>(tensors, spirv);
// Destructor, declared with override. The visible body only emits a debug
// log message; no other cleanup appears in this listing.
55 ~OpMult()
override { KP_LOG_DEBUG(
"Kompute OpMult destructor started"); }
Definition: OpAlgoDispatch.hpp:18
Definition: OpMult.hpp:22
~OpMult() override
Definition: OpMult.hpp:55
OpMult(std::vector< std::shared_ptr< Tensor >> tensors, std::shared_ptr< Algorithm > algorithm)
Definition: OpMult.hpp:33