LCOV - code coverage report
Current view: top level - src/include/kompute/operations - OpMult.hpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 8 11 72.7 %
Date: 2024-01-20 13:42:20 Functions: 3 3 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: Apache-2.0
       2             : #pragma once
       3             : 
       4             : #include <fstream>
       5             : 
       6             : #include "kompute/Core.hpp"
       7             : 
       8             : #include "ShaderOpMult.hpp"
       9             : 
      10             : #include "kompute/Algorithm.hpp"
      11             : #include "kompute/Tensor.hpp"
      12             : 
      13             : #include "kompute/operations/OpAlgoDispatch.hpp"
      14             : 
      15             : namespace kp {
      16             : 
      17             : /**
      18             :  * Operation that performs multiplication on two tensors and outputs the
      19             :  * result to a third tensor.
      20             :  */
      21             : class OpMult : public OpAlgoDispatch
      22             : {
      23             :   public:
      24             :     /**
      25             :      * Default constructor with parameters that provides the bare minimum
      26             :      * requirements for the operations to be able to create and manage their
      27             :      * sub-components.
      28             :      *
      29             :      * @param tensors Tensors that are to be used in this operation
      30             :      * @param algorithm An algorithm that will be overridden with the OpMult
      31             :      * shader data and the tensors provided which are expected to be 3
      32             :      */
      33           4 :     OpMult(std::vector<std::shared_ptr<Tensor>> tensors,
      34             :            std::shared_ptr<Algorithm> algorithm)
      35           4 :       : OpAlgoDispatch(algorithm)
      36             :     {
      37           8 :         KP_LOG_DEBUG("Kompute OpMult constructor with params");
      38             : 
      39           4 :         if (tensors.size() != 3) {
      40           0 :             throw std::runtime_error(
      41           0 :               "Kompute OpMult expected 3 tensors but got " +
      42           0 :               std::to_string(tensors.size()));
      43             :         }
      44             : 
      45             :         const std::vector<uint32_t> spirv = std::vector<uint32_t>(
      46           4 :           SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end());
      47             : 
      48           4 :         algorithm->rebuild<>(tensors, spirv);
      49           4 :     }
      50             : 
      51             :     /**
      52             :      * Default destructor, which is in charge of destroying the algorithm
      53             :      * components but does not destroy the underlying tensors
      54             :      */
      55          16 :     ~OpMult() override { KP_LOG_DEBUG("Kompute OpMult destructor started"); }
      56             : };
      57             : 
      58             : } // End namespace kp

Generated by: LCOV version 1.14