|
| 1 | +//===-- include/flang-rt/runtime/work-queue.h -------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +// Internal runtime utilities for work queues that replace the use of recursion |
| 10 | +// for better GPU device support. |
| 11 | +// |
| 12 | +// A work queue is a list of tickets. Each ticket class has a Begin() |
| 13 | +// member function that is called once, and a Continue() member function |
| 14 | +// that can be called zero or more times. A ticket's execution terminates |
| 15 | +// when either of these member functions returns a status other than |
| 16 | +// StatOkContinue, and if that status is not StatOk, then the whole queue |
| 17 | +// is shut down. |
| 18 | +// |
| 19 | +// By returning StatOkContinue from its Begin() or Continue() member function, |
| 20 | +// a ticket suspends its execution so that any nested tickets that it |
| 21 | +// may have created can be run to completion. It is the responsibility |
| 22 | +// of each ticket class to maintain resumption information in its state |
| 23 | +// and manage its own progress. Most ticket classes inherit from |
| 24 | +// class ComponentTicketBase, which implements an outer loop over all |
| 25 | +// components of a derived type, and an inner loop over all elements |
| 26 | +// of a descriptor, possibly with multiple phases of execution per element. |
| 27 | +// |
| 28 | +// Tickets are created by WorkQueue::Begin...() member functions. |
| 29 | +// There is one of these for each "top level" recursive function in the |
| 30 | +// Fortran runtime support library that has been restructured into this |
| 31 | +// ticket framework. |
| 32 | +// |
| 33 | +// When the work queue is running tickets, it always selects the last ticket |
| 34 | +// on the list for execution -- "work stack" might have been a more accurate |
| 35 | +// name for this framework. This ticket may, while doing its job, create |
| 36 | +// new tickets, and since those are pushed after the active one, the first |
| 37 | +// such nested ticket will be the next one executed to completion -- i.e., |
| 38 | +// the order of nested WorkQueue::Begin...() calls is respected. |
| 39 | +// Note that a ticket's Continue() member function won't be called again |
| 40 | +// until all nested tickets have run to completion and it is once again |
| 41 | +// the last ticket on the queue. |
| 42 | +// |
| 43 | +// Example for an assignment to a derived type: |
| 44 | +// 1. Assign() is called, and its work queue is created. It calls |
| 45 | +// WorkQueue::BeginAssign() and then WorkQueue::Run(). |
| 46 | +// 2. Run calls AssignTicket::Begin(), which pushes a ticket via |
| 47 | +// BeginFinalize() and returns StatOkContinue. |
| 48 | +// 3. FinalizeTicket::Begin() and FinalizeTicket::Continue() are called |
| 49 | +// until one of them returns StatOk, which ends the finalization ticket. |
| 50 | +// 4. AssignTicket::Continue() is then called; it creates a DerivedAssignTicket |
| 51 | +// and then returns StatOk, which ends the ticket. |
| 52 | +// 5. At this point, only one ticket remains. DerivedAssignTicket::Begin() |
| 53 | +// and ::Continue() are called until they are done (not StatOkContinue). |
| 54 | +// Along the way, it may create nested AssignTickets for components, |
| 55 | +// and suspend itself so that they may each run to completion. |
| 56 | + |
| 57 | +#ifndef FLANG_RT_RUNTIME_WORK_QUEUE_H_ |
| 58 | +#define FLANG_RT_RUNTIME_WORK_QUEUE_H_ |
| 59 | + |
| 60 | +#include "flang-rt/runtime/descriptor.h" |
| 61 | +#include "flang-rt/runtime/stat.h" |
| 62 | +#include "flang/Common/api-attrs.h" |
| 63 | +#include "flang/Runtime/freestanding-tools.h" |
| 64 | +#include <flang/Common/variant.h> |
| 65 | + |
| 66 | +namespace Fortran::runtime { |
| 67 | +class Terminator; |
| 68 | +class WorkQueue; |
| 69 | +namespace typeInfo { |
| 70 | +class DerivedType; |
| 71 | +class Component; |
| 72 | +} // namespace typeInfo |
| 73 | + |
| 74 | +// Ticket workers |
| 75 | + |
| 76 | +// Ticket workers return status codes from Begin() and Continue(). Returning |
| 77 | +// StatOkContinue means that the ticket is incomplete and must be resumed; any |
| 78 | +// other value means that the ticket is complete, and if that value is not |
| 79 | +// StatOk, the whole queue can be shut down due to an error. NullTicket (below) is the trivial worker: both of its member functions return StatOk immediately. NOTE(review): 1234 is presumably picked to differ from every Stat code in stat.h -- confirm. |
| 80 | +static constexpr int StatOkContinue{1234}; |
| 81 | + |
| 82 | +struct NullTicket { |
| 83 | + RT_API_ATTRS int Begin(WorkQueue &) const { return StatOk; } |
| 84 | + RT_API_ATTRS int Continue(WorkQueue &) const { return StatOk; } |
| 85 | +}; |
| 86 | + |
| 87 | +// Base class for ticket workers that operate elementwise over descriptors. It holds a reference to the descriptor (which must therefore outlive the ticket), the element count read from Elements(), the index and subscripts of the current element, and a per-element phase counter. |
| 88 | +// CueUpNextItem() is true while elementAt_ is below elements_; AdvanceToNextElement() zeroes the phase and steps the index and subscripts; Reset() returns to element 0 at the lower bounds. TODO: if ComponentTicketBase remains this class's only client, |
| 89 | +// merge them for better comprehensibility. |
| 90 | +class ElementalTicketBase { |
| 91 | +protected: |
| 92 | + RT_API_ATTRS ElementalTicketBase(const Descriptor &instance) |
| 93 | + : instance_{instance} { |
| 94 | + instance_.GetLowerBounds(subscripts_); |
| 95 | + } |
| 96 | + RT_API_ATTRS bool CueUpNextItem() const { return elementAt_ < elements_; } |
| 97 | + RT_API_ATTRS void AdvanceToNextElement() { |
| 98 | + phase_ = 0; |
| 99 | + ++elementAt_; |
| 100 | + instance_.IncrementSubscripts(subscripts_); |
| 101 | + } |
| 102 | + RT_API_ATTRS void Reset() { |
| 103 | + phase_ = 0; |
| 104 | + elementAt_ = 0; |
| 105 | + instance_.GetLowerBounds(subscripts_); |
| 106 | + } |
| 107 | + |
| 108 | + const Descriptor &instance_; |
| 109 | + std::size_t elements_{instance_.Elements()}; |
| 110 | + std::size_t elementAt_{0}; |
| 111 | + int phase_{0}; |
| 112 | + SubscriptValue subscripts_[common::maxRank]; |
| 113 | +}; |
| 114 | + |
| 115 | +// Base class for ticket workers that operate over derived type components |
| 116 | +// in an outer loop, and elements in an inner loop. The constructor and CueUpNextItem() are defined out of line; this non-const CueUpNextItem() hides the const one inherited from ElementalTicketBase. AdvanceToNextComponent() sets elementAt_ to elements_, so the inherited element test (elementAt_ < elements_) becomes false. Reset() rewinds the element state and also clears component_ and componentAt_. NOTE(review): componentDescriptor_ is presumably scratch storage for a descriptor of the current component -- confirm in the implementation. |
| 117 | +class ComponentTicketBase : protected ElementalTicketBase { |
| 118 | +protected: |
| 119 | + RT_API_ATTRS ComponentTicketBase( |
| 120 | + const Descriptor &instance, const typeInfo::DerivedType &derived); |
| 121 | + RT_API_ATTRS bool CueUpNextItem(); |
| 122 | + RT_API_ATTRS void AdvanceToNextComponent() { elementAt_ = elements_; } |
| 123 | + RT_API_ATTRS void Reset() { |
| 124 | + ElementalTicketBase::Reset(); |
| 125 | + component_ = nullptr; |
| 126 | + componentAt_ = 0; |
| 127 | + } |
| 128 | + |
| 129 | + const typeInfo::DerivedType &derived_; |
| 130 | + std::size_t components_{0}, componentAt_{0}; |
| 131 | + const typeInfo::Component *component_{nullptr}; |
| 132 | + StaticDescriptor<common::maxRank, true, 0> componentDescriptor_; |
| 133 | +}; |
| 134 | + |
| 135 | +// Implements derived type instance initialization. The constructor only forwards the instance descriptor and its derived type to ComponentTicketBase (inherited privately); Begin() and Continue() are declared here and defined out of line, and follow the ticket status convention (StatOkContinue means "resume me later"). |
| 136 | +class InitializeTicket : private ComponentTicketBase { |
| 137 | +public: |
| 138 | + RT_API_ATTRS InitializeTicket( |
| 139 | + const Descriptor &instance, const typeInfo::DerivedType &derived) |
| 140 | + : ComponentTicketBase{instance, derived} {} |
| 141 | + RT_API_ATTRS int Begin(WorkQueue &); |
| 142 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 143 | +}; |
| 144 | + |
| 145 | +// Initializes one derived type instance from the value of another. The ComponentTicketBase part iterates over the original descriptor, while the clone is held by reference in clone_, so both descriptors must outlive the ticket. Begin() does no work and just returns StatOkContinue, so everything happens in Continue() (defined out of line). NOTE(review): hasStat and errMsg are only stored here; presumably they are passed on to allocation/error reporting -- confirm in Continue(). |
| 146 | +class InitializeCloneTicket : private ComponentTicketBase { |
| 147 | +public: |
| 148 | + RT_API_ATTRS InitializeCloneTicket(const Descriptor &clone, |
| 149 | + const Descriptor &original, const typeInfo::DerivedType &derived, |
| 150 | + bool hasStat, const Descriptor *errMsg) |
| 151 | + : ComponentTicketBase{original, derived}, clone_{clone}, |
| 152 | + hasStat_{hasStat}, errMsg_{errMsg} {} |
| 153 | + RT_API_ATTRS int Begin(WorkQueue &) { return StatOkContinue; } |
| 154 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 155 | + |
| 156 | +private: |
| 157 | + const Descriptor &clone_; |
| 158 | + bool hasStat_{false}; |
| 159 | + const Descriptor *errMsg_{nullptr}; |
| 160 | + StaticDescriptor<common::maxRank, true, 0> cloneComponentDescriptor_; |
| 161 | +}; |
| 162 | + |
| 163 | +// Implements derived type instance finalization. The constructor forwards the instance descriptor and derived type to ComponentTicketBase; Begin() and Continue() are defined out of line. finalizableParentType_ starts out null. NOTE(review): it is presumably set when a parent type also needs finalization -- confirm in the definitions. |
| 164 | +class FinalizeTicket : private ComponentTicketBase { |
| 165 | +public: |
| 166 | + RT_API_ATTRS FinalizeTicket( |
| 167 | + const Descriptor &instance, const typeInfo::DerivedType &derived) |
| 168 | + : ComponentTicketBase{instance, derived} {} |
| 169 | + RT_API_ATTRS int Begin(WorkQueue &); |
| 170 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 171 | + |
| 172 | +private: |
| 173 | + const typeInfo::DerivedType *finalizableParentType_{nullptr}; |
| 174 | +}; |
| 175 | + |
| 176 | +// Implements derived type instance destruction. The constructor forwards the instance descriptor and derived type to ComponentTicketBase and records the caller's finalize flag in finalize_; Begin() and Continue() are defined out of line. NOTE(review): finalize_ presumably selects whether finalization runs as part of destruction -- confirm in the definitions. |
| 177 | +class DestroyTicket : private ComponentTicketBase { |
| 178 | +public: |
| 179 | + RT_API_ATTRS DestroyTicket(const Descriptor &instance, |
| 180 | + const typeInfo::DerivedType &derived, bool finalize) |
| 181 | + : ComponentTicketBase{instance, derived}, finalize_{finalize} {} |
| 182 | + RT_API_ATTRS int Begin(WorkQueue &); |
| 183 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 184 | + |
| 185 | +private: |
| 186 | + bool finalize_{false}; |
| 187 | +}; |
| 188 | + |
| 189 | +// Implements general intrinsic assignment. Unlike the other workers, this class does not derive from ComponentTicketBase. The destination is held by reference and the source by pointer (the constructor stores the address of its from argument), so both must outlive the ticket. IsSimpleMemmove() is true only when no derived type has been recorded in toDerived_, the ranks match, both sides are contiguous, and the element byte sizes are equal. NOTE(review): tempDescriptor_, toDeallocate_, persist_ and done_ are state for the out-of-line Begin()/Continue(); their exact roles are not visible here. |
| 190 | +class AssignTicket { |
| 191 | +public: |
| 192 | + RT_API_ATTRS AssignTicket( |
| 193 | + Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct) |
| 194 | + : to_{to}, from_{&from}, flags_{flags}, memmoveFct_{memmoveFct} {} |
| 195 | + RT_API_ATTRS int Begin(WorkQueue &); |
| 196 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 197 | + |
| 198 | +private: |
| 199 | + RT_API_ATTRS bool IsSimpleMemmove() const { |
| 200 | + return !toDerived_ && to_.rank() == from_->rank() && to_.IsContiguous() && |
| 201 | + from_->IsContiguous() && to_.ElementBytes() == from_->ElementBytes(); |
| 202 | + } |
| 203 | + RT_API_ATTRS Descriptor &GetTempDescriptor(); |
| 204 | + |
| 205 | + Descriptor &to_; |
| 206 | + const Descriptor *from_{nullptr}; |
| 207 | + int flags_{0}; // enum AssignFlags |
| 208 | + MemmoveFct memmoveFct_{nullptr}; |
| 209 | + StaticDescriptor<common::maxRank, true, 0> tempDescriptor_; |
| 210 | + const typeInfo::DerivedType *toDerived_{nullptr}; |
| 211 | + Descriptor *toDeallocate_{nullptr}; |
| 212 | + bool persist_{false}; |
| 213 | + bool done_{false}; |
| 214 | +}; |
| 215 | + |
| 216 | +// Implements derived type intrinsic assignment. The ComponentTicketBase part iterates over the destination descriptor (to), while the source is held by reference in from_. This class declares its own AdvanceToNextElement() and Reset(), which hide the inherited versions and are defined out of line. NOTE(review): they presumably keep fromSubscripts_ in step with the destination subscripts, and deallocateAfter_ is presumably a descriptor to free once the assignment completes -- confirm in the definitions. |
| 217 | +class DerivedAssignTicket : private ComponentTicketBase { |
| 218 | +public: |
| 219 | + RT_API_ATTRS DerivedAssignTicket(const Descriptor &to, const Descriptor &from, |
| 220 | + const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct, |
| 221 | + Descriptor *deallocateAfter) |
| 222 | + : ComponentTicketBase{to, derived}, from_{from}, flags_{flags}, |
| 223 | + memmoveFct_{memmoveFct}, deallocateAfter_{deallocateAfter} {} |
| 224 | + RT_API_ATTRS int Begin(WorkQueue &); |
| 225 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 226 | + RT_API_ATTRS void AdvanceToNextElement(); |
| 227 | + RT_API_ATTRS void Reset(); |
| 228 | + |
| 229 | +private: |
| 230 | + const Descriptor &from_; |
| 231 | + int flags_{0}; |
| 232 | + MemmoveFct memmoveFct_{nullptr}; |
| 233 | + Descriptor *deallocateAfter_{nullptr}; |
| 234 | + SubscriptValue fromSubscripts_[common::maxRank]; |
| 235 | + StaticDescriptor<common::maxRank, true, 0> fromComponentDescriptor_; |
| 236 | +}; |
| 237 | + |
| 238 | +struct Ticket { |
| 239 | + RT_API_ATTRS int Continue(WorkQueue &); |
| 240 | + bool begun{false}; // NOTE(review): presumably set once the worker's Begin() has been called -- confirm in Ticket::Continue() |
| 241 | + std::variant<NullTicket, InitializeTicket, InitializeCloneTicket, |
| 242 | + FinalizeTicket, DestroyTicket, AssignTicket, DerivedAssignTicket> |
| 243 | + u; // the worker carried by this ticket; NullTicket is the first alternative |
| 244 | +}; |
| 245 | + |
| 246 | +class WorkQueue { |
| 247 | +public: |
| 248 | + RT_API_ATTRS explicit WorkQueue(Terminator &terminator) |
| 249 | + : terminator_{terminator} { |
| 250 | + for (int j{1}; j < numStatic_; ++j) { |
| 251 | + static_[j].previous = &static_[j - 1]; |
| 252 | + static_[j - 1].next = &static_[j]; |
| 253 | + } |
| 254 | + } |
| 255 | + RT_API_ATTRS ~WorkQueue(); // NOTE(review): static_ nodes and firstFree_ point into this object, yet copying is not disabled in this header -- consider deleting copy/move |
| 256 | + RT_API_ATTRS Terminator &terminator() { return terminator_; } |
| 257 | + |
| 258 | + RT_API_ATTRS void BeginInitialize( |
| 259 | + const Descriptor &descriptor, const typeInfo::DerivedType &derived); |
| 260 | + RT_API_ATTRS void BeginInitializeClone(const Descriptor &clone, |
| 261 | + const Descriptor &original, const typeInfo::DerivedType &derived, |
| 262 | + bool hasStat, const Descriptor *errMsg); |
| 263 | + RT_API_ATTRS void BeginFinalize( |
| 264 | + const Descriptor &descriptor, const typeInfo::DerivedType &derived); |
| 265 | + RT_API_ATTRS void BeginDestroy(const Descriptor &descriptor, |
| 266 | + const typeInfo::DerivedType &derived, bool finalize); |
| 267 | + RT_API_ATTRS void BeginAssign( |
| 268 | + Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct); |
| 269 | + RT_API_ATTRS void BeginDerivedAssign(Descriptor &to, const Descriptor &from, |
| 270 | + const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct, |
| 271 | + Descriptor *deallocateAfter); |
| 272 | + |
| 273 | + RT_API_ATTRS int Run(); |
| 274 | + |
| 275 | +private: |
| 276 | + // Most uses of the work queue won't go very deep, so this many list nodes are embedded in the queue object itself (static_) and chained together by the constructor. NOTE(review): isStatic presumably tells these embedded nodes apart from dynamically allocated ones -- confirm in the implementation. |
| 277 | + static constexpr int numStatic_{2}; |
| 278 | + |
| 279 | + struct TicketList { |
| 280 | + bool isStatic{true}; |
| 281 | + Ticket ticket; |
| 282 | + TicketList *previous{nullptr}, *next{nullptr}; |
| 283 | + }; |
| 284 | + |
| 285 | + RT_API_ATTRS Ticket &StartTicket(); |
| 286 | + RT_API_ATTRS void Stop(); |
| 287 | + |
| 288 | + Terminator &terminator_; |
| 289 | + TicketList *first_{nullptr}, *last_{nullptr}, *insertAfter_{nullptr}; |
| 290 | + TicketList static_[numStatic_]; |
| 291 | + TicketList *firstFree_{static_}; // starts at the head of the embedded node chain |
| 292 | +}; |
| 293 | + |
| 294 | +} // namespace Fortran::runtime |
| 295 | +#endif // FLANG_RT_RUNTIME_WORK_QUEUE_H_ |
0 commit comments