Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : 3 : #include "kompute/Sequence.hpp" 4 : 5 : namespace kp { 6 : 7 62 : Sequence::Sequence(std::shared_ptr<vk::PhysicalDevice> physicalDevice, 8 : std::shared_ptr<vk::Device> device, 9 : std::shared_ptr<vk::Queue> computeQueue, 10 : uint32_t queueIndex, 11 62 : uint32_t totalTimestamps) 12 : { 13 124 : KP_LOG_DEBUG("Kompute Sequence Constructor with existing device & queue"); 14 : 15 62 : this->mPhysicalDevice = physicalDevice; 16 62 : this->mDevice = device; 17 62 : this->mComputeQueue = computeQueue; 18 62 : this->mQueueIndex = queueIndex; 19 : 20 62 : this->createCommandPool(); 21 62 : this->createCommandBuffer(); 22 62 : if (totalTimestamps > 0) 23 0 : this->createTimestampQueryPool(totalTimestamps + 24 : 1); //+1 for the first one 25 62 : } 26 : 27 62 : Sequence::~Sequence() 28 : { 29 124 : KP_LOG_DEBUG("Kompute Sequence Destructor started"); 30 : 31 62 : if (this->mDevice) { 32 50 : this->destroy(); 33 : } 34 62 : } 35 : 36 : void 37 134 : Sequence::begin() 38 : { 39 268 : KP_LOG_DEBUG("Kompute sequence called BEGIN"); 40 : 41 134 : if (this->isRecording()) { 42 58 : KP_LOG_DEBUG("Kompute Sequence begin called when already recording"); 43 29 : return; 44 : } 45 : 46 105 : if (this->isRunning()) { 47 1 : throw std::runtime_error( 48 2 : "Kompute Sequence begin called when sequence still running"); 49 : } 50 : 51 208 : KP_LOG_INFO("Kompute Sequence command now started recording"); 52 104 : this->mCommandBuffer->begin(vk::CommandBufferBeginInfo()); 53 104 : this->mRecording = true; 54 : 55 : // latch the first timestamp before any commands are submitted 56 104 : if (this->timestampQueryPool) 57 0 : this->mCommandBuffer->writeTimestamp( 58 : vk::PipelineStageFlagBits::eAllCommands, 59 0 : *this->timestampQueryPool, 60 : 0); 61 : } 62 : 63 : void 64 105 : Sequence::end() 65 : { 66 210 : KP_LOG_DEBUG("Kompute Sequence calling END"); 67 : 68 105 : if (this->isRunning()) { 69 1 : throw std::runtime_error( 70 2 : "Kompute Sequence begin called when sequence still running"); 71 : } 72 : 73 104 : if (!this->isRecording()) { 74 2 : KP_LOG_WARN("Kompute Sequence end called when not recording"); 75 1 : return; 76 : } else { 77 206 : KP_LOG_INFO("Kompute Sequence command recording END"); 78 103 : this->mCommandBuffer->end(); 79 103 : this->mRecording = false; 80 : } 81 : } 82 : 83 : void 84 74 : Sequence::clear() 85 : { 86 148 : KP_LOG_DEBUG("Kompute Sequence calling clear"); 87 74 : this->mOperations.clear(); 88 74 : if (this->isRecording()) { 89 1 : this->end(); 90 : } 91 74 : } 92 : 93 : std::shared_ptr<Sequence> 94 294 : Sequence::eval() 95 : { 96 588 : KP_LOG_DEBUG("Kompute sequence EVAL BEGIN"); 97 : 98 588 : return this->evalAsync()->evalAwait(); 99 : } 100 : 101 : std::shared_ptr<Sequence> 102 66 : Sequence::eval(std::shared_ptr<OpBase> op) 103 : { 104 66 : this->clear(); 105 132 : return this->record(op)->eval(); 106 : } 107 : 108 : std::shared_ptr<Sequence> 109 304 : Sequence::evalAsync() 110 : { 111 304 : if (this->isRecording()) { 112 102 : this->end(); 113 : } 114 : 115 304 : if (this->mIsRunning) { 116 1 : throw std::runtime_error( 117 : "Kompute Sequence evalAsync called when an eval async was " 118 2 : "called without successful wait"); 119 : } 120 : 121 303 : this->mIsRunning = true; 122 : 123 1047 : for (size_t i = 0; i < this->mOperations.size(); i++) { 124 744 : this->mOperations[i]->preEval(*this->mCommandBuffer); 125 : } 126 : 127 : vk::SubmitInfo submitInfo( 128 303 : 0, nullptr, nullptr, 1, this->mCommandBuffer.get()); 129 : 130 606 : this->mFence = this->mDevice->createFence(vk::FenceCreateInfo()); 131 : 132 606 : KP_LOG_DEBUG( 133 : "Kompute sequence submitting command buffer into compute queue"); 134 : 135 303 : this->mComputeQueue->submit(1, &submitInfo, this->mFence); 136 : 137 606 : return shared_from_this(); 138 : } 139 : 140 : std::shared_ptr<Sequence> 141 7 : Sequence::evalAsync(std::shared_ptr<OpBase> op) 142 : { 143 7 : this->clear(); 144 7 : this->record(op); 145 7 : this->evalAsync(); 146 7 : return shared_from_this(); 147 : } 148 : 149 : std::shared_ptr<Sequence> 150 305 : Sequence::evalAwait(uint64_t waitFor) 151 : { 152 305 : if (!this->mIsRunning) { 153 4 : KP_LOG_WARN("Kompute Sequence evalAwait called without existing eval"); 154 2 : return shared_from_this(); 155 : } 156 : 157 : vk::Result result = 158 303 : this->mDevice->waitForFences(1, &this->mFence, VK_TRUE, waitFor); 159 303 : this->mDevice->destroy( 160 : this->mFence, (vk::Optional<const vk::AllocationCallbacks>)nullptr); 161 : 162 303 : this->mIsRunning = false; 163 : 164 303 : if (result == vk::Result::eTimeout) { 165 4 : KP_LOG_WARN("Kompute Sequence evalAwait reached timeout of {}", 166 : waitFor); 167 2 : return shared_from_this(); 168 : } 169 : 170 1043 : for (size_t i = 0; i < this->mOperations.size(); i++) { 171 742 : this->mOperations[i]->postEval(*this->mCommandBuffer); 172 : } 173 : 174 301 : return shared_from_this(); 175 : } 176 : 177 : bool 178 214 : Sequence::isRunning() const 179 : { 180 214 : return this->mIsRunning; 181 : } 182 : 183 : bool 184 619 : Sequence::isRecording() const 185 : { 186 619 : return this->mRecording; 187 : } 188 : 189 : bool 190 7 : Sequence::isInit() const 191 : { 192 10 : return this->mDevice && this->mCommandPool && this->mCommandBuffer && 193 10 : this->mComputeQueue; 194 : } 195 : 196 : void 197 1 : Sequence::rerecord() 198 : { 199 1 : this->end(); 200 1 : std::vector<std::shared_ptr<OpBase>> ops = this->mOperations; 201 1 : this->mOperations.clear(); 202 4 : for (const std::shared_ptr<kp::OpBase>& op : ops) { 203 3 : this->record(op); 204 : } 205 1 : } 206 : 207 : void 208 64 : Sequence::destroy() 209 : { 210 128 : KP_LOG_DEBUG("Kompute Sequence destroy called"); 211 : 212 64 : if (!this->mDevice) { 213 4 : KP_LOG_WARN("Kompute Sequence destroy called " 214 : "with null Device pointer"); 215 2 : return; 216 : } 217 : 218 62 : if (this->mFreeCommandBuffer) { 219 124 : KP_LOG_INFO("Freeing CommandBuffer"); 220 62 : if (!this->mCommandBuffer) { 221 0 : KP_LOG_WARN("Kompute Sequence destroy called with null " 222 : "CommandPool pointer"); 223 0 : return; 224 : } 225 62 : this->mDevice->freeCommandBuffers( 226 62 : *this->mCommandPool, 1, this->mCommandBuffer.get()); 227 : 228 62 : this->mCommandBuffer = nullptr; 229 62 : this->mFreeCommandBuffer = false; 230 : 231 124 : KP_LOG_DEBUG("Kompute Sequence Freed CommandBuffer"); 232 : } 233 : 234 62 : if (this->mFreeCommandPool) { 235 124 : KP_LOG_INFO("Destroying CommandPool"); 236 62 : if (this->mCommandPool == nullptr) { 237 0 : KP_LOG_WARN("Kompute Sequence destroy called with null " 238 : "CommandPool pointer"); 239 0 : return; 240 : } 241 62 : this->mDevice->destroy( 242 62 : *this->mCommandPool, 243 : (vk::Optional<const vk::AllocationCallbacks>)nullptr); 244 : 245 62 : this->mCommandPool = nullptr; 246 62 : this->mFreeCommandPool = false; 247 : 248 124 : KP_LOG_DEBUG("Kompute Sequence Destroyed CommandPool"); 249 : } 250 : 251 62 : if (this->mOperations.size()) { 252 112 : KP_LOG_INFO("Kompute Sequence clearing operations buffer"); 253 56 : this->mOperations.clear(); 254 : } 255 : 256 62 : if (this->timestampQueryPool) { 257 0 : KP_LOG_INFO("Destroying QueryPool"); 258 0 : this->mDevice->destroy( 259 0 : *this->timestampQueryPool, 260 : (vk::Optional<const vk::AllocationCallbacks>)nullptr); 261 : 262 0 : this->timestampQueryPool = nullptr; 263 0 : KP_LOG_DEBUG("Kompute Sequence Destroyed QueryPool"); 264 : } 265 : 266 62 : if (this->mDevice) { 267 62 : this->mDevice = nullptr; 268 : } 269 62 : if (this->mPhysicalDevice) { 270 62 : this->mPhysicalDevice = nullptr; 271 : } 272 62 : if (this->mComputeQueue) { 273 62 : this->mComputeQueue = nullptr; 274 : } 275 : } 276 : 277 : std::shared_ptr<Sequence> 278 133 : Sequence::record(std::shared_ptr<OpBase> op) 279 : { 280 266 : KP_LOG_DEBUG("Kompute Sequence record function started"); 281 : 282 133 : this->begin(); 283 : 284 266 : KP_LOG_DEBUG( 285 : "Kompute Sequence running record on OpBase derived class instance"); 286 : 287 133 : op->record(*this->mCommandBuffer); 288 : 289 132 : this->mOperations.push_back(op); 290 : 291 132 : if (this->timestampQueryPool) 292 0 : this->mCommandBuffer->writeTimestamp( 293 : vk::PipelineStageFlagBits::eAllCommands, 294 0 : *this->timestampQueryPool, 295 0 : this->mOperations.size()); 296 : 297 132 : return shared_from_this(); 298 : } 299 : 300 : void 301 62 : Sequence::createCommandPool() 302 : { 303 124 : KP_LOG_DEBUG("Kompute Sequence creating command pool"); 304 : 305 62 : if (!this->mDevice) { 306 0 : throw std::runtime_error("Kompute Sequence device is null"); 307 : } 308 : 309 62 : this->mFreeCommandPool = true; 310 : 311 : vk::CommandPoolCreateInfo commandPoolInfo(vk::CommandPoolCreateFlags(), 312 62 : this->mQueueIndex); 313 62 : this->mCommandPool = std::make_shared<vk::CommandPool>(); 314 62 : this->mDevice->createCommandPool( 315 : &commandPoolInfo, nullptr, this->mCommandPool.get()); 316 124 : KP_LOG_DEBUG("Kompute Sequence Command Pool Created"); 317 62 : } 318 : 319 : void 320 62 : Sequence::createCommandBuffer() 321 : { 322 124 : KP_LOG_DEBUG("Kompute Sequence creating command buffer"); 323 62 : if (!this->mDevice) { 324 0 : throw std::runtime_error("Kompute Sequence device is null"); 325 : } 326 62 : if (!this->mCommandPool) { 327 0 : throw std::runtime_error("Kompute Sequence command pool is null"); 328 : } 329 : 330 62 : this->mFreeCommandBuffer = true; 331 : 332 : vk::CommandBufferAllocateInfo commandBufferAllocateInfo( 333 62 : *this->mCommandPool, vk::CommandBufferLevel::ePrimary, 1); 334 : 335 62 : this->mCommandBuffer = std::make_shared<vk::CommandBuffer>(); 336 62 : this->mDevice->allocateCommandBuffers(&commandBufferAllocateInfo, 337 : this->mCommandBuffer.get()); 338 124 : KP_LOG_DEBUG("Kompute Sequence Command Buffer Created"); 339 62 : } 340 : 341 : void 342 0 : Sequence::createTimestampQueryPool(uint32_t totalTimestamps) 343 : { 344 0 : KP_LOG_DEBUG("Kompute Sequence creating query pool"); 345 0 : if (!this->isInit()) { 346 0 : throw std::runtime_error( 347 0 : "createTimestampQueryPool() called on uninitialized Sequence"); 348 : } 349 0 : if (!this->mPhysicalDevice) { 350 0 : throw std::runtime_error("Kompute Sequence physical device is null"); 351 : } 352 : 353 : vk::PhysicalDeviceProperties physicalDeviceProperties = 354 0 : this->mPhysicalDevice->getProperties(); 355 : 356 0 : if (physicalDeviceProperties.limits.timestampComputeAndGraphics) { 357 0 : vk::QueryPoolCreateInfo queryPoolInfo; 358 0 : queryPoolInfo.setQueryCount(totalTimestamps); 359 0 : queryPoolInfo.setQueryType(vk::QueryType::eTimestamp); 360 0 : this->timestampQueryPool = std::make_shared<vk::QueryPool>( 361 0 : this->mDevice->createQueryPool(queryPoolInfo)); 362 : 363 0 : KP_LOG_DEBUG("Query pool for timestamps created"); 364 : } else { 365 0 : throw std::runtime_error("Device does not support timestamps"); 366 : } 367 0 : } 368 : 369 : std::vector<std::uint64_t> 370 0 : Sequence::getTimestamps() 371 : { 372 0 : if (!this->timestampQueryPool) 373 0 : throw std::runtime_error("Timestamp latching not enabled"); 374 : 375 0 : const auto n = this->mOperations.size() + 1; 376 0 : std::vector<std::uint64_t> timestamps(n, 0); 377 0 : this->mDevice->getQueryPoolResults( 378 0 : *this->timestampQueryPool, 379 : 0, 380 : n, 381 0 : timestamps.size() * sizeof(std::uint64_t), 382 0 : timestamps.data(), 383 : sizeof(uint64_t), 384 : vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait); 385 : 386 0 : return timestamps; 387 : } 388 : 389 : }