Kompute
OpMult.hpp
1 // SPDX-License-Identifier: Apache-2.0
2 #pragma once
3 
4 #include <fstream>
5 
6 #include "kompute/Core.hpp"
7 
8 #include "ShaderOpMult.hpp"
9 
10 #include "kompute/Algorithm.hpp"
11 #include "kompute/Tensor.hpp"
12 
13 #include "kompute/operations/OpAlgoDispatch.hpp"
14 
15 namespace kp {
16 
21 class OpMult : public OpAlgoDispatch
22 {
23  public:
33  OpMult(std::vector<std::shared_ptr<Tensor>> tensors,
34  std::shared_ptr<Algorithm> algorithm)
35  : OpAlgoDispatch(algorithm)
36  {
37  KP_LOG_DEBUG("Kompute OpMult constructor with params");
38 
39  if (tensors.size() != 3) {
40  throw std::runtime_error(
41  "Kompute OpMult expected 3 tensors but got " +
42  std::to_string(tensors.size()));
43  }
44 
45  const std::vector<uint32_t> spirv = std::vector<uint32_t>(
46  SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end());
47 
48  algorithm->rebuild<>(tensors, spirv);
49  }
50 
55  ~OpMult() override { KP_LOG_DEBUG("Kompute OpMult destructor started"); }
56 };
57 
58 } // End namespace kp
kp::OpAlgoDispatch — Definition: OpAlgoDispatch.hpp:18
kp::OpMult — Definition: OpMult.hpp:22
kp::OpMult::~OpMult() override — Definition: OpMult.hpp:55
kp::OpMult::OpMult(std::vector<std::shared_ptr<Tensor>> tensors, std::shared_ptr<Algorithm> algorithm) — Definition: OpMult.hpp:33