@@ -191,7 +191,7 @@ def release(self):
         self.shape = None
 
 
-def allocate_buffers(context, stream=None, sync_mode=True):
+def allocate_buffers(context, stream=None, sync_mode=True, shared_mem={}):
     """
     Read bindings' information in ExecutionContext, create pagelocked np.ndarray in CPU,
     allocate corresponding memory in GPU.
@@ -222,7 +222,10 @@ def allocate_buffers(context, stream=None, sync_mode=True):
 
     inputs = []
     outputs = []
+    out_pointer = 0
     has_dynamic_axes = False
+    inv_shared_mem = {v: k for k, v in shared_mem.items()}
+
     if stream is None and not sync_mode:
         stream = cuda.Stream()
     for binding in context.engine:
@@ -237,14 +240,27 @@ def allocate_buffers(context, stream=None, sync_mode=True):
         else:
             size = trt.volume(shape) * context.engine.max_batch_size
         # Allocate host and device buffers
-        host_mem = cuda.pagelocked_empty(size, dtype)
-        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        if not context.engine.binding_is_input(binding):
+            if out_pointer in shared_mem.values():
+                # Avoid allocating new GPU memory: reuse the host and device buffers of the input that shares memory with this output.
+                input_idx = inv_shared_mem[out_pointer]
+                device_mem = inputs[input_idx].device
+                host_mem = inputs[input_idx].host
+            else:
+                host_mem = cuda.pagelocked_empty(size, dtype)
+                device_mem = cuda.mem_alloc(host_mem.nbytes)
+            out_pointer += 1
+
+        else:
+            host_mem = cuda.pagelocked_empty(size, dtype)
+            device_mem = cuda.mem_alloc(host_mem.nbytes)
         mem_obj = HostDeviceMem(host_mem, device_mem, shape, dtype, binding)
         # Append to the appropriate list.
         if context.engine.binding_is_input(binding):
             inputs.append(mem_obj)
         else:
             outputs.append(mem_obj)
+
     return inputs, outputs, stream, has_dynamic_axes
 
 
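For context on the new `shared_mem` argument: it appears to map the index of an input binding (its position in the returned `inputs` list) to the index of the output binding that should reuse the same pagelocked host array and device allocation. A minimal, illustrative sketch of calling `allocate_buffers` this way; the `{0: 1}` mapping and the recurrent-state use case are assumptions, not part of the PR:

```python
# Illustrative only: assumes `context` (a trt.IExecutionContext) already exists and
# that output binding 1 of this engine feeds input binding 0 on the next call.
shared_mem = {0: 1}  # input list index -> output list index

inputs, outputs, stream, has_dynamic_axes = allocate_buffers(
    context, sync_mode=True, shared_mem=shared_mem
)
# outputs[1] now reuses inputs[0]'s host array and device allocation,
# so no extra GPU memory is allocated for that output.
```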
@@ -305,7 +321,7 @@ def get_bindings(context, dict_inputs, dict_outputs):
     return bindings
 
 
-def execute_async(context, bindings, inputs, outputs, stream):
+def execute_async(context, bindings, inputs, outputs, stream, shared_mem, inputs_from_cpu, outputs_to_cpu):
     """
     Execute a TensorRT engine.
 
@@ -318,19 +334,32 @@ def execute_async(context, bindings, inputs, outputs, stream):
     outputs: list[HostDeviceMem]
     stream: pycuda.driver.Stream
         used for memory transfers between CPU-GPU
+    inputs_from_cpu: bool, reload all inputs from CPU to GPU.
+    outputs_to_cpu: bool, transfer all outputs back from GPU to CPU.
 
     Returns
     -------
     list : np.ndarray
         For each output of the engine
     """
     # Transfer input data to the GPU.
-    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    if inputs_from_cpu:
+        # Reload all inputs from "inputs".
+        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    else:
+        # Reload all inputs from "inputs" except the ones with shared memory.
+        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for i, inp in enumerate(inputs) if i not in shared_mem.keys()]
+
     # Run inference.
     check = context.execute_async(bindings=bindings, stream_handle=stream.handle)
     assert check, "Kernel execution failed"
     # Transfer predictions back from the GPU.
-    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    if outputs_to_cpu:
+        # Transfer all outputs.
+        [cuda.memcpy_dtoh(out.host, out.device) for out in outputs]
+    else:
+        # Only outputs whose memory is not shared.
+        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for i, out in enumerate(outputs) if i not in shared_mem.values()]
     # Synchronize the stream
     stream.synchronize()
     # Return only the host outputs.
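A sketch of how the new flags might combine in an iterative inference loop, under the same illustrative `{0: 1}` mapping as above; `N_STEPS`, `preprocess`, and the assumption that input bindings precede output bindings (when building `bindings` by hand) are all hypothetical:

```python
shared_mem = {0: 1}  # output 1 stays on the GPU and is reused as input 0
inputs, outputs, stream, _ = allocate_buffers(context, sync_mode=False,
                                              shared_mem=shared_mem)
# Assumes input bindings come before output bindings in engine order.
bindings = [int(inp.device) for inp in inputs] + [int(out.device) for out in outputs]

for step in range(N_STEPS):                        # N_STEPS: placeholder
    inputs[1].host[:] = preprocess(step).ravel()   # refresh the non-shared input on the host
    results = execute_async(
        context, bindings, inputs, outputs, stream,
        shared_mem=shared_mem,
        inputs_from_cpu=(step == 0),  # upload the shared input only on the first step
        outputs_to_cpu=False,         # skip the host copy of the shared output
    )
```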
@@ -339,7 +368,7 @@ def execute_async(context, bindings, inputs, outputs, stream):
     return [out.host for out in outputs]
 
 
-def execute_sync(context, bindings, inputs, outputs):
+def execute_sync(context, bindings, inputs, outputs, shared_mem, inputs_from_cpu, outputs_to_cpu):
     """
     Execute a TensorRT engine.
 
@@ -352,25 +381,36 @@ def execute_sync(context, bindings, inputs, outputs):
     outputs: list[HostDeviceMem]
     stream: pycuda.driver.Stream
         used for memory transfers between CPU-GPU
+    inputs_from_cpu: bool, reload all inputs from CPU to GPU.
+    outputs_to_cpu: bool, transfer all outputs back from GPU to CPU.
 
     Parameters
     ----------
     list[np.ndarray] for each output of the engine
     """
     # Transfer input data to the GPU.
-    [cuda.memcpy_htod(inp.device, inp.host) for inp in inputs]
+    if inputs_from_cpu:
+        # Reload all inputs from "inputs".
+        [cuda.memcpy_htod(inp.device, inp.host) for inp in inputs]
+    else:
+        # Reload all inputs from "inputs" except the ones with shared memory.
+        [cuda.memcpy_htod(inp.device, inp.host) for i, inp in enumerate(inputs) if i not in shared_mem.keys()]
     # Run inference.
     check = context.execute_v2(bindings=bindings)
     assert check, "Kernel execution failed"
     # Transfer predictions back from the GPU.
-    [cuda.memcpy_dtoh(out.host, out.device) for out in outputs]
+    if outputs_to_cpu:
+        # Transfer all outputs.
+        [cuda.memcpy_dtoh(out.host, out.device) for out in outputs]
+    else:
+        # Only outputs whose memory is not shared.
+        [cuda.memcpy_dtoh(out.host, out.device) for i, out in enumerate(outputs) if i not in shared_mem.values()]
     # Return only the host outputs.
     for out in outputs:
         out.host = out.host.reshape(out.shape)
     return [out.host for out in outputs]
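Because both execution paths simply skip copies for the shared bindings, correctness relies on each shared input/output pair having the same volume and dtype. A small, hypothetical sanity check a caller could run before enabling `shared_mem` (not part of the PR):

```python
import tensorrt as trt

def check_shared_mem(engine, shared_mem):
    """Assert that every shared input/output binding pair matches in volume and dtype."""
    input_names = [b for b in engine if engine.binding_is_input(b)]
    output_names = [b for b in engine if not engine.binding_is_input(b)]
    for in_idx, out_idx in shared_mem.items():
        in_b, out_b = input_names[in_idx], output_names[out_idx]
        assert trt.volume(engine.get_binding_shape(in_b)) == \
            trt.volume(engine.get_binding_shape(out_b)), f"volume mismatch: {in_b} vs {out_b}"
        assert engine.get_binding_dtype(in_b) == engine.get_binding_dtype(out_b), \
            f"dtype mismatch: {in_b} vs {out_b}"
```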
 
 
-
 def rename_nodes_(graph, verbose=False):
 
     dont_rename = [v.name for v in graph.inputs + graph.outputs]