Commit 16944de

[flang][runtime] Replace recursion with iterative work queue
Recursion, both direct and indirect, prevents accurate stack size calculation at link time for GPU device code. Restructure these recursive (often mutually so) routines in the Fortran runtime with new implementations based on an iterative work queue with suspendable/resumable work tickets: Assign, Initialize, InitializeClone, Finalize, and Destroy.

Default derived type I/O is also recursive, but already disabled. It can be added to this new framework later if the overall approach succeeds.

Note that derived type FINAL subroutine calls, defined assignments, and defined I/O procedures all perform callbacks into user code, which may well reenter the runtime library. This kind of recursion is not handled by this change, although it may be possible to do so in the future using thread-local work queues.

The effects of this restructuring on CPU performance are yet to be measured.
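As a rough illustration of the general technique this commit message describes (not code from the commit), the sketch below converts a recursive traversal into a loop over an explicit stack of resumable frames. `Node`, `Frame`, `SumRecursive`, `SumIterative`, and `MakeSampleTree` are hypothetical names invented for this example. The point is only that the iterative form keeps its "where do I resume" state in data rather than in call frames, so the call stack depth no longer depends on the input.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tree type, for illustration only; not part of the Fortran runtime.
struct Node {
  int value{0};
  std::vector<Node> children;
};

// Recursive form: call stack depth grows with the depth of the data,
// so a linker cannot bound the stack size statically.
int SumRecursive(const Node &node) {
  int sum{node.value};
  for (const Node &child : node.children) {
    sum += SumRecursive(child);
  }
  return sum;
}

// One suspended visit: which node, and which child to descend into next.
struct Frame {
  const Node *node;
  std::size_t nextChild;
};

// Iterative form: the resumption state lives in an explicit work stack.
int SumIterative(const Node &root) {
  int sum{root.value};
  std::vector<Frame> work{Frame{&root, 0}};
  while (!work.empty()) {
    Frame &top{work.back()};
    if (top.nextChild < top.node->children.size()) {
      // Record progress in the frame, then "suspend" it by pushing a nested frame.
      const Node &child{top.node->children[top.nextChild++]};
      sum += child.value;
      work.push_back(Frame{&child, 0}); // `top` is not used after this push
    } else {
      work.pop_back(); // all children visited: this frame is complete
    }
  }
  return sum;
}

// Small sample tree: 1 -> {2, 3 -> {4}}
Node MakeSampleTree() {
  return Node{1, {Node{2, {}}, Node{3, {Node{4, {}}}}}};
}
```

Both functions compute the same result for the sample tree. The work stack here is heap-allocated for brevity; a runtime targeting GPU device code would presumably want tighter control over that storage.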
1 parent 440e510 commit 16944de

7 files changed: +1021 -476 lines changed

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
+//===-- include/flang-rt/runtime/work-queue.h -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Internal runtime utilities for work queues that replace the use of recursion
+// for better GPU device support.
+//
+// A work queue is a list of tickets.  Each ticket class has a Begin()
+// member function that is called once, and a Continue() member function
+// that can be called zero or more times.  A ticket's execution terminates
+// when either of these member functions returns a status other than
+// StatOkContinue, and if that status is not StatOk, then the whole queue
+// is shut down.
+//
+// By returning StatOkContinue from its Continue() member function,
+// a ticket suspends its execution so that any nested tickets that it
+// may have created can be run to completion.  It is the responsibility
+// of each ticket class to maintain resumption information in its state
+// and manage its own progress.  Most ticket classes inherit from
+// class ComponentTicketBase, which implements an outer loop over all
+// components of a derived type, and an inner loop over all elements
+// of a descriptor, possibly with multiple phases of execution per element.
+//
+// Tickets are created by WorkQueue::Begin...() member functions.
+// There is one of these for each "top level" recursive function in the
+// Fortran runtime support library that has been restructured into this
+// ticket framework.
+//
+// When the work queue is running tickets, it always selects the last ticket
+// on the list for execution -- "work stack" might have been a more accurate
+// name for this framework.  This ticket may, while doing its job, create
+// new tickets, and since those are pushed after the active one, the first
+// such nested ticket will be the next one executed to completion -- i.e.,
+// the order of nested WorkQueue::Begin...() calls is respected.
+// Note that a ticket's Continue() member function won't be called again
+// until all nested tickets have run to completion and it is once again
+// the last ticket on the queue.
+//
+// Example for an assignment to a derived type:
+// 1. Assign() is called, and its work queue is created.  It calls
+//    WorkQueue::BeginAssign() and then WorkQueue::Run().
+// 2. Run calls AssignTicket::Begin(), which pushes a ticket via
+//    BeginFinalize() and returns StatOkContinue.
+// 3. FinalizeTicket::Begin() and FinalizeTicket::Continue() are called
+//    until one of them returns StatOk, which ends the finalization ticket.
+// 4. AssignTicket::Continue() is then called; it creates a DerivedAssignTicket
+//    and then returns StatOk, which ends the ticket.
+// 5. At this point, only one ticket remains.  DerivedAssignTicket::Begin()
+//    and ::Continue() are called until they are done (not StatOkContinue).
+//    Along the way, it may create nested AssignTickets for components,
+//    and suspend itself so that they may each run to completion.
+
+#ifndef FLANG_RT_RUNTIME_WORK_QUEUE_H_
+#define FLANG_RT_RUNTIME_WORK_QUEUE_H_
+
+#include "flang-rt/runtime/descriptor.h"
+#include "flang-rt/runtime/stat.h"
+#include "flang/Common/api-attrs.h"
+#include "flang/Runtime/freestanding-tools.h"
+#include <flang/Common/variant.h>
+
+namespace Fortran::runtime {
+class Terminator;
+class WorkQueue;
+namespace typeInfo {
+class DerivedType;
+class Component;
+} // namespace typeInfo
+
+// Ticket workers
+
+// Ticket workers return status codes.  Returning StatOkContinue means
+// that the ticket is incomplete and must be resumed; any other value
+// means that the ticket is complete, and if not StatOk, the whole
+// queue can be shut down due to an error.
+static constexpr int StatOkContinue{1234};
+
+struct NullTicket {
+  RT_API_ATTRS int Begin(WorkQueue &) const { return StatOk; }
+  RT_API_ATTRS int Continue(WorkQueue &) const { return StatOk; }
+};
+
+// Base class for ticket workers that operate elementwise over descriptors
+// TODO: if ComponentTicketBase remains this class' only client,
+// merge them for better comprehensibility.
+class ElementalTicketBase {
+protected:
+  RT_API_ATTRS ElementalTicketBase(const Descriptor &instance)
+      : instance_{instance} {
+    instance_.GetLowerBounds(subscripts_);
+  }
+  RT_API_ATTRS bool CueUpNextItem() const { return elementAt_ < elements_; }
+  RT_API_ATTRS void AdvanceToNextElement() {
+    phase_ = 0;
+    ++elementAt_;
+    instance_.IncrementSubscripts(subscripts_);
+  }
+  RT_API_ATTRS void Reset() {
+    phase_ = 0;
+    elementAt_ = 0;
+    instance_.GetLowerBounds(subscripts_);
+  }
+
+  const Descriptor &instance_;
+  std::size_t elements_{instance_.Elements()};
+  std::size_t elementAt_{0};
+  int phase_{0};
+  SubscriptValue subscripts_[common::maxRank];
+};
+
+// Base class for ticket workers that operate over derived type components
+// in an outer loop, and elements in an inner loop.
+class ComponentTicketBase : protected ElementalTicketBase {
+protected:
+  RT_API_ATTRS ComponentTicketBase(
+      const Descriptor &instance, const typeInfo::DerivedType &derived);
+  RT_API_ATTRS bool CueUpNextItem();
+  RT_API_ATTRS void AdvanceToNextComponent() { elementAt_ = elements_; }
+  RT_API_ATTRS void Reset() {
+    ElementalTicketBase::Reset();
+    component_ = nullptr;
+    componentAt_ = 0;
+  }
+
+  const typeInfo::DerivedType &derived_;
+  std::size_t components_{0}, componentAt_{0};
+  const typeInfo::Component *component_{nullptr};
+  StaticDescriptor<common::maxRank, true, 0> componentDescriptor_;
+};
+
+// Implements derived type instance initialization
+class InitializeTicket : private ComponentTicketBase {
+public:
+  RT_API_ATTRS InitializeTicket(
+      const Descriptor &instance, const typeInfo::DerivedType &derived)
+      : ComponentTicketBase{instance, derived} {}
+  RT_API_ATTRS int Begin(WorkQueue &);
+  RT_API_ATTRS int Continue(WorkQueue &);
+};
+
+// Initializes one derived type instance from the value of another
+class InitializeCloneTicket : private ComponentTicketBase {
+public:
+  RT_API_ATTRS InitializeCloneTicket(const Descriptor &clone,
+      const Descriptor &original, const typeInfo::DerivedType &derived,
+      bool hasStat, const Descriptor *errMsg)
+      : ComponentTicketBase{original, derived}, clone_{clone},
+        hasStat_{hasStat}, errMsg_{errMsg} {}
+  RT_API_ATTRS int Begin(WorkQueue &) { return StatOkContinue; }
+  RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+  const Descriptor &clone_;
+  bool hasStat_{false};
+  const Descriptor *errMsg_{nullptr};
+  StaticDescriptor<common::maxRank, true, 0> cloneComponentDescriptor_;
+};
+
+// Implements derived type instance finalization
+class FinalizeTicket : private ComponentTicketBase {
+public:
+  RT_API_ATTRS FinalizeTicket(
+      const Descriptor &instance, const typeInfo::DerivedType &derived)
+      : ComponentTicketBase{instance, derived} {}
+  RT_API_ATTRS int Begin(WorkQueue &);
+  RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+  const typeInfo::DerivedType *finalizableParentType_{nullptr};
+};
+
+// Implements derived type instance destruction
+class DestroyTicket : private ComponentTicketBase {
+public:
+  RT_API_ATTRS DestroyTicket(const Descriptor &instance,
+      const typeInfo::DerivedType &derived, bool finalize)
+      : ComponentTicketBase{instance, derived}, finalize_{finalize} {}
+  RT_API_ATTRS int Begin(WorkQueue &);
+  RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+  bool finalize_{false};
+};
+
+// Implements general intrinsic assignment
+class AssignTicket {
+public:
+  RT_API_ATTRS AssignTicket(
+      Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct)
+      : to_{to}, from_{&from}, flags_{flags}, memmoveFct_{memmoveFct} {}
+  RT_API_ATTRS int Begin(WorkQueue &);
+  RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+  RT_API_ATTRS bool IsSimpleMemmove() const {
+    return !toDerived_ && to_.rank() == from_->rank() && to_.IsContiguous() &&
+        from_->IsContiguous() && to_.ElementBytes() == from_->ElementBytes();
+  }
+  RT_API_ATTRS Descriptor &GetTempDescriptor();
+
+  Descriptor &to_;
+  const Descriptor *from_{nullptr};
+  int flags_{0}; // enum AssignFlags
+  MemmoveFct memmoveFct_{nullptr};
+  StaticDescriptor<common::maxRank, true, 0> tempDescriptor_;
+  const typeInfo::DerivedType *toDerived_{nullptr};
+  Descriptor *toDeallocate_{nullptr};
+  bool persist_{false};
+  bool done_{false};
+};
+
+// Implements derived type intrinsic assignment
+class DerivedAssignTicket : private ComponentTicketBase {
+public:
+  RT_API_ATTRS DerivedAssignTicket(const Descriptor &to, const Descriptor &from,
+      const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct,
+      Descriptor *deallocateAfter)
+      : ComponentTicketBase{to, derived}, from_{from}, flags_{flags},
+        memmoveFct_{memmoveFct}, deallocateAfter_{deallocateAfter} {}
+  RT_API_ATTRS int Begin(WorkQueue &);
+  RT_API_ATTRS int Continue(WorkQueue &);
+  RT_API_ATTRS void AdvanceToNextElement();
+  RT_API_ATTRS void Reset();
+
+private:
+  const Descriptor &from_;
+  int flags_{0};
+  MemmoveFct memmoveFct_{nullptr};
+  Descriptor *deallocateAfter_{nullptr};
+  SubscriptValue fromSubscripts_[common::maxRank];
+  StaticDescriptor<common::maxRank, true, 0> fromComponentDescriptor_;
+};
+
+struct Ticket {
+  RT_API_ATTRS int Continue(WorkQueue &);
+  bool begun{false};
+  std::variant<NullTicket, InitializeTicket, InitializeCloneTicket,
+      FinalizeTicket, DestroyTicket, AssignTicket, DerivedAssignTicket>
+      u;
+};
+
+class WorkQueue {
+public:
+  RT_API_ATTRS explicit WorkQueue(Terminator &terminator)
+      : terminator_{terminator} {
+    for (int j{1}; j < numStatic_; ++j) {
+      static_[j].previous = &static_[j - 1];
+      static_[j - 1].next = &static_[j];
+    }
+  }
+  RT_API_ATTRS ~WorkQueue();
+  RT_API_ATTRS Terminator &terminator() { return terminator_; };
+
+  RT_API_ATTRS void BeginInitialize(
+      const Descriptor &descriptor, const typeInfo::DerivedType &derived);
+  RT_API_ATTRS void BeginInitializeClone(const Descriptor &clone,
+      const Descriptor &original, const typeInfo::DerivedType &derived,
+      bool hasStat, const Descriptor *errMsg);
+  RT_API_ATTRS void BeginFinalize(
+      const Descriptor &descriptor, const typeInfo::DerivedType &derived);
+  RT_API_ATTRS void BeginDestroy(const Descriptor &descriptor,
+      const typeInfo::DerivedType &derived, bool finalize);
+  RT_API_ATTRS void BeginAssign(
+      Descriptor &to, const Descriptor &from, int flags, MemmoveFct memmoveFct);
+  RT_API_ATTRS void BeginDerivedAssign(Descriptor &to, const Descriptor &from,
+      const typeInfo::DerivedType &derived, int flags, MemmoveFct memmoveFct,
+      Descriptor *deallocateAfter);
+
+  RT_API_ATTRS int Run();
+
+private:
+  // Most uses of the work queue won't go very deep.
+  static constexpr int numStatic_{2};
+
+  struct TicketList {
+    bool isStatic{true};
+    Ticket ticket;
+    TicketList *previous{nullptr}, *next{nullptr};
+  };
+
+  RT_API_ATTRS Ticket &StartTicket();
+  RT_API_ATTRS void Stop();
+
+  Terminator &terminator_;
+  TicketList *first_{nullptr}, *last_{nullptr}, *insertAfter_{nullptr};
+  TicketList static_[numStatic_];
+  TicketList *firstFree_{static_};
+};
+
+} // namespace Fortran::runtime
+#endif // FLANG_RT_RUNTIME_WORK_QUEUE_H_
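
The ticket protocol documented in the header comment above can be tried out on a much smaller model. The sketch below is not the flang-rt implementation, which lives in work-queue.cpp and is not shown here. `MiniQueue`, `MiniTicket`, `LeafTicket`, `ParentTicket`, `kStatOk`, `kStatOkContinue`, and `RunDemo` are hypothetical names, and a `std::vector` stands in for the linked `TicketList`. It assumes only what the header comment states: a ticket's `Begin()` runs once, `Continue()` runs on later visits, returning the continue status suspends the ticket, and nested tickets run to completion in creation order before their creator resumes.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// All names below are hypothetical stand-ins for this sketch, not the flang-rt API.
constexpr int kStatOk{0};
constexpr int kStatOkContinue{1234};

class MiniQueue;

// A ticket that completes immediately in Begin().
struct LeafTicket {
  std::string name;
  int Begin(MiniQueue &);
  int Continue(MiniQueue &) { return kStatOk; }
};

// A ticket that creates two nested tickets, suspends itself, then finishes.
struct ParentTicket {
  int Begin(MiniQueue &);
  int Continue(MiniQueue &);
};

struct MiniTicket {
  bool begun{false};
  std::variant<LeafTicket, ParentTicket> u;
};

class MiniQueue {
public:
  // New tickets are held aside until the active ticket's step returns.
  void Push(MiniTicket ticket) { pending_.push_back(std::move(ticket)); }
  void Log(const std::string &event) { log_.push_back(event); }
  const std::vector<std::string> &log() const { return log_; }

  int Run() {
    while (true) {
      // Stack pending tickets in reverse so the first-created one runs first.
      for (auto it{pending_.rbegin()}; it != pending_.rend(); ++it) {
        stack_.push_back(std::move(*it));
      }
      pending_.clear();
      if (stack_.empty()) {
        return kStatOk;
      }
      MiniTicket &active{stack_.back()}; // always run the last ticket
      int status{std::visit(
          [&](auto &t) {
            return active.begun ? t.Continue(*this) : t.Begin(*this);
          },
          active.u)};
      active.begun = true;
      if (status != kStatOkContinue) {
        stack_.pop_back(); // the ticket is complete
        if (status != kStatOk) { // an error shuts the whole queue down
          stack_.clear();
          pending_.clear();
          return status;
        }
      }
    }
  }

private:
  std::vector<MiniTicket> stack_, pending_;
  std::vector<std::string> log_;
};

int LeafTicket::Begin(MiniQueue &queue) {
  queue.Log(name + " done");
  return kStatOk;
}

int ParentTicket::Begin(MiniQueue &queue) {
  queue.Log("parent begins");
  queue.Push(MiniTicket{false, LeafTicket{"first"}});
  queue.Push(MiniTicket{false, LeafTicket{"second"}});
  return kStatOkContinue; // suspend until both nested tickets complete
}

int ParentTicket::Continue(MiniQueue &queue) {
  queue.Log("parent resumes");
  return kStatOk;
}

// Runs one ParentTicket to completion and returns the recorded event order.
std::vector<std::string> RunDemo() {
  MiniQueue queue;
  queue.Push(MiniTicket{false, ParentTicket{}});
  queue.Run();
  return queue.log();
}
```

In this model `RunDemo()` records the parent's start, then both leaf tickets in the order they were pushed, and only then the parent's resumption. Holding new tickets in a separate pending list is a choice made for this sketch; the `insertAfter_` member in the header suggests the real queue tracks the insertion point within its linked list instead.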

flang-rt/lib/runtime/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,7 @@ set(supported_sources
   type-info.cpp
   unit.cpp
   utf.cpp
+  work-queue.cpp
 )
 
 # List of source not used for GPU offloading.
@@ -130,6 +131,7 @@ set(gpu_sources
   type-code.cpp
   type-info.cpp
   utf.cpp
+  work-queue.cpp
   complex-powi.cpp
   reduce.cpp
   reduction.cpp
