@@ -17,6 +17,9 @@ using atomic_int_sys = cuda::std::atomic<int>;
1717static inline atomic_uint64_sys *as_atomic_u64 (void *p) {
1818 return static_cast <atomic_uint64_sys *>(p);
1919}
20+ static inline atomic_uint64_sys *as_atomic_u64 (volatile uint64_t *p) {
21+ return reinterpret_cast <atomic_uint64_sys *>(const_cast <uint64_t *>(p));
22+ }
2023static inline atomic_int_sys *as_atomic_int (void *p) {
2124 return static_cast <atomic_int_sys *>(p);
2225}
@@ -36,13 +39,13 @@ lookup_function(cudaq_function_entry_t *table, size_t count,
3639}
3740
3841static int
39- find_idle_graph_worker_for_function (const cudaq_host_dispatch_loop_ctx_t *config ,
42+ find_idle_graph_worker_for_function (const cudaq_host_dispatch_loop_ctx_t *ctx ,
4043 uint32_t function_id) {
41- uint64_t mask = as_atomic_u64 (config ->idle_mask )->load (
44+ uint64_t mask = as_atomic_u64 (ctx ->idle_mask )->load (
4245 cuda::std::memory_order_acquire);
4346 while (mask != 0 ) {
4447 int worker_id = __builtin_ffsll (static_cast <long long >(mask)) - 1 ;
45- if (config ->workers [static_cast <size_t >(worker_id)].function_id ==
48+ if (ctx ->workers [static_cast <size_t >(worker_id)].function_id ==
4649 function_id)
4750 return worker_id;
4851 mask &= ~(1ULL << worker_id);
@@ -58,104 +61,101 @@ struct ParsedSlot {
5861
5962static ParsedSlot
6063parse_slot_with_function_table (void *slot_host,
61- const cudaq_host_dispatch_loop_ctx_t *config ) {
64+ const cudaq_host_dispatch_loop_ctx_t *ctx ) {
6265 ParsedSlot out;
6366 const RPCHeader *header = static_cast <const RPCHeader *>(slot_host);
6467 if (header->magic != RPC_MAGIC_REQUEST) {
6568 out.drop = true ;
6669 return out;
6770 }
6871 out.function_id = header->function_id ;
69- out.entry = lookup_function (config ->function_table ,
70- config-> function_table_count , out.function_id );
72+ out.entry = lookup_function (ctx ->function_table . entries ,
73+ ctx-> function_table . count , out.function_id );
7174 if (!out.entry )
7275 out.drop = true ;
7376 return out;
7477}
7578
76- static void finish_slot_and_advance (const cudaq_host_dispatch_loop_ctx_t *config ,
79+ static void finish_slot_and_advance (const cudaq_host_dispatch_loop_ctx_t *ctx ,
7780 size_t ¤t_slot, size_t num_slots,
7881 uint64_t &packets_dispatched) {
79- as_atomic_u64 (config-> rx_flags )[current_slot].store (
82+ as_atomic_u64 (ctx-> ringbuffer . rx_flags_host )[current_slot].store (
8083 0 , cuda::std::memory_order_release);
8184 packets_dispatched++;
82- if (config ->live_dispatched )
83- as_atomic_u64 (config ->live_dispatched )
85+ if (ctx ->live_dispatched )
86+ as_atomic_u64 (ctx ->live_dispatched )
8487 ->fetch_add (1 , cuda::std::memory_order_relaxed);
8588 current_slot = (current_slot + 1 ) % num_slots;
8689}
8790
88- static int acquire_graph_worker (const cudaq_host_dispatch_loop_ctx_t *config ,
91+ static int acquire_graph_worker (const cudaq_host_dispatch_loop_ctx_t *ctx ,
8992 bool use_function_table,
9093 const cudaq_function_entry_t *entry,
9194 uint32_t function_id) {
9295 if (use_function_table && entry &&
9396 entry->dispatch_mode == CUDAQ_DISPATCH_GRAPH_LAUNCH)
94- return find_idle_graph_worker_for_function (config , function_id);
97+ return find_idle_graph_worker_for_function (ctx , function_id);
9598 uint64_t mask =
96- as_atomic_u64 (config ->idle_mask )->load (cuda::std::memory_order_acquire);
99+ as_atomic_u64 (ctx ->idle_mask )->load (cuda::std::memory_order_acquire);
97100 if (mask == 0 )
98101 return -1 ;
99102 return __builtin_ffsll (static_cast <long long >(mask)) - 1 ;
100103}
101104
102- static void launch_graph_worker (const cudaq_host_dispatch_loop_ctx_t *config ,
105+ static void launch_graph_worker (const cudaq_host_dispatch_loop_ctx_t *ctx ,
103106 int worker_id, void *slot_host,
104107 size_t current_slot) {
105- as_atomic_u64 (config ->idle_mask )
108+ as_atomic_u64 (ctx ->idle_mask )
106109 ->fetch_and (~(1ULL << worker_id), cuda::std::memory_order_release);
107- config ->inflight_slot_tags [worker_id] = static_cast <int >(current_slot);
110+ ctx ->inflight_slot_tags [worker_id] = static_cast <int >(current_slot);
108111
109112 ptrdiff_t offset =
110- static_cast <uint8_t *>(slot_host) - config-> rx_data_host ;
111- void *data_dev = static_cast <void *>(config-> rx_data_dev + offset);
113+ static_cast <uint8_t *>(slot_host) - ctx-> ringbuffer . rx_data_host ;
114+ void *data_dev = static_cast <void *>(ctx-> ringbuffer . rx_data + offset);
112115
113- if (config->io_ctxs_host != nullptr ) {
114- // GraphIOContext mode: fill per-worker context with separate RX/TX info.
115- auto *h_ctxs = static_cast <GraphIOContext *>(config->io_ctxs_host );
116- auto *d_ctxs = static_cast <uint8_t *>(config->io_ctxs_dev );
116+ if (ctx->io_ctxs_host != nullptr ) {
117+ auto *h_ctxs = static_cast <GraphIOContext *>(ctx->io_ctxs_host );
118+ auto *d_ctxs = static_cast <uint8_t *>(ctx->io_ctxs_dev );
117119 GraphIOContext *h_ctx = &h_ctxs[worker_id];
118120
119121 h_ctx->rx_slot = data_dev;
120- h_ctx->tx_slot = config->tx_data_dev + current_slot * config->tx_stride_sz ;
121- h_ctx->tx_flag = &config->tx_flags_dev [current_slot];
122+ h_ctx->tx_slot = ctx->ringbuffer .tx_data +
123+ current_slot * ctx->ringbuffer .tx_stride_sz ;
124+ h_ctx->tx_flag = &ctx->ringbuffer .tx_flags [current_slot];
122125 h_ctx->tx_flag_value =
123126 reinterpret_cast <uint64_t >(h_ctx->tx_slot );
124- h_ctx->tx_stride_sz = config-> tx_stride_sz ;
127+ h_ctx->tx_stride_sz = ctx-> ringbuffer . tx_stride_sz ;
125128
126129 void *d_ctx = d_ctxs + worker_id * sizeof (GraphIOContext);
127- config ->h_mailbox_bank [worker_id] = d_ctx;
130+ ctx ->h_mailbox_bank [worker_id] = d_ctx;
128131
129- // In GraphIOContext mode the graph kernel writes tx_flag_value (READY)
130- // to tx_flags from the GPU. Set the in-flight marker BEFORE launch so
131- // the kernel's READY write is never clobbered by a late host write.
132- as_atomic_u64 (config->tx_flags )[current_slot].store (
132+ as_atomic_u64 (ctx->ringbuffer .tx_flags_host )[current_slot].store (
133133 CUDAQ_TX_FLAG_IN_FLIGHT, cuda::std::memory_order_release);
134134 __sync_synchronize ();
135135 } else {
136- config ->h_mailbox_bank [worker_id] = data_dev;
136+ ctx ->h_mailbox_bank [worker_id] = data_dev;
137137 }
138138 __sync_synchronize ();
139139
140140 const size_t w = static_cast <size_t >(worker_id);
141- if (config ->workers [w].pre_launch_fn )
142- config ->workers [w].pre_launch_fn (config ->workers [w].pre_launch_data ,
143- data_dev, config ->workers [w].stream );
144- cudaError_t err = cudaGraphLaunch (config ->workers [w].graph_exec ,
145- config ->workers [w].stream );
141+ if (ctx ->workers [w].pre_launch_fn )
142+ ctx ->workers [w].pre_launch_fn (ctx ->workers [w].pre_launch_data ,
143+ data_dev, ctx ->workers [w].stream );
144+ cudaError_t err = cudaGraphLaunch (ctx ->workers [w].graph_exec ,
145+ ctx ->workers [w].stream );
146146
147147 if (err != cudaSuccess) {
148148 uint64_t error_val = CUDAQ_TX_FLAG_ERROR_TAG << 48 | (uint64_t )err;
149- as_atomic_u64 (config-> tx_flags )[current_slot].store (
149+ as_atomic_u64 (ctx-> ringbuffer . tx_flags_host )[current_slot].store (
150150 error_val, cuda::std::memory_order_release);
151- as_atomic_u64 (config ->idle_mask )
151+ as_atomic_u64 (ctx ->idle_mask )
152152 ->fetch_or (1ULL << worker_id, cuda::std::memory_order_release);
153153 } else {
154- if (config ->workers [w].post_launch_fn )
155- config ->workers [w].post_launch_fn (config ->workers [w].post_launch_data ,
156- data_dev, config ->workers [w].stream );
157- if (config ->io_ctxs_host == nullptr ) {
158- as_atomic_u64 (config-> tx_flags )[current_slot].store (
154+ if (ctx ->workers [w].post_launch_fn )
155+ ctx ->workers [w].post_launch_fn (ctx ->workers [w].post_launch_data ,
156+ data_dev, ctx ->workers [w].stream );
157+ if (ctx ->io_ctxs_host == nullptr ) {
158+ as_atomic_u64 (ctx-> ringbuffer . tx_flags_host )[current_slot].store (
159159 CUDAQ_TX_FLAG_IN_FLIGHT, cuda::std::memory_order_release);
160160 }
161161 }
@@ -164,17 +164,18 @@ static void launch_graph_worker(const cudaq_host_dispatch_loop_ctx_t *config,
164164} // anonymous namespace
165165
166166extern " C" void
167- cudaq_host_dispatcher_loop (const cudaq_host_dispatch_loop_ctx_t *config ) {
167+ cudaq_host_dispatcher_loop (const cudaq_host_dispatch_loop_ctx_t *ctx ) {
168168 size_t current_slot = 0 ;
169- const size_t num_slots = config-> num_slots ;
169+ const size_t num_slots = ctx-> config . num_slots ;
170170 uint64_t packets_dispatched = 0 ;
171171 const bool use_function_table =
172- (config ->function_table != nullptr && config-> function_table_count > 0 );
172+ (ctx ->function_table . entries != nullptr && ctx-> function_table . count > 0 );
173173
174- while (as_atomic_int (config ->shutdown_flag )
174+ while (as_atomic_int (ctx ->shutdown_flag )
175175 ->load (cuda::std::memory_order_acquire) == 0 ) {
176- uint64_t rx_value = as_atomic_u64 (config->rx_flags )[current_slot].load (
177- cuda::std::memory_order_acquire);
176+ uint64_t rx_value =
177+ as_atomic_u64 (ctx->ringbuffer .rx_flags_host )[current_slot].load (
178+ cuda::std::memory_order_acquire);
178179
179180 if (rx_value == 0 ) {
180181 CUDAQ_REALTIME_CPU_RELAX ();
@@ -187,9 +188,9 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *config) {
187188
188189 // TODO: Remove non-function-table path; RPC framing is always required.
189190 if (use_function_table) {
190- ParsedSlot parsed = parse_slot_with_function_table (slot_host, config );
191+ ParsedSlot parsed = parse_slot_with_function_table (slot_host, ctx );
191192 if (parsed.drop ) {
192- as_atomic_u64 (config-> rx_flags )[current_slot].store (
193+ as_atomic_u64 (ctx-> ringbuffer . rx_flags_host )[current_slot].store (
193194 0 , cuda::std::memory_order_release);
194195 current_slot = (current_slot + 1 ) % num_slots;
195196 continue ;
@@ -199,29 +200,29 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *config) {
199200 }
200201
201202 if (entry && entry->dispatch_mode != CUDAQ_DISPATCH_GRAPH_LAUNCH) {
202- as_atomic_u64 (config-> rx_flags )[current_slot].store (
203+ as_atomic_u64 (ctx-> ringbuffer . rx_flags_host )[current_slot].store (
203204 0 , cuda::std::memory_order_release);
204205 current_slot = (current_slot + 1 ) % num_slots;
205206 continue ;
206207 }
207208
208209 int worker_id =
209- acquire_graph_worker (config , use_function_table, entry, function_id);
210+ acquire_graph_worker (ctx , use_function_table, entry, function_id);
210211 if (worker_id < 0 ) {
211212 CUDAQ_REALTIME_CPU_RELAX ();
212213 continue ;
213214 }
214215
215- launch_graph_worker (config , worker_id, slot_host, current_slot);
216- finish_slot_and_advance (config , current_slot, num_slots,
216+ launch_graph_worker (ctx , worker_id, slot_host, current_slot);
217+ finish_slot_and_advance (ctx , current_slot, num_slots,
217218 packets_dispatched);
218219 }
219220
220- for (size_t i = 0 ; i < config ->num_workers ; ++i) {
221- cudaStreamSynchronize (config ->workers [i].stream );
221+ for (size_t i = 0 ; i < ctx ->num_workers ; ++i) {
222+ cudaStreamSynchronize (ctx ->workers [i].stream );
222223 }
223224
224- if (config ->stats_counter ) {
225- *config ->stats_counter = packets_dispatched;
225+ if (ctx ->stats_counter ) {
226+ *ctx ->stats_counter = packets_dispatched;
226227 }
227228}
0 commit comments