28
28
#include < raft/core/handle.hpp>
29
29
30
30
#include < rmm/mr/device/cuda_async_memory_resource.hpp>
31
+ #include < rmm/mr/device/limiting_resource_adaptor.hpp>
32
+ #include < rmm/mr/device/logging_resource_adaptor.hpp>
31
33
#include < rmm/mr/device/pool_memory_resource.hpp>
34
+ #include < rmm/mr/device/tracking_resource_adaptor.hpp>
32
35
33
36
#include < rmm/mr/device/owning_wrapper.hpp>
34
37
@@ -256,7 +259,9 @@ void run_single_file_mp(std::string file_path,
256
259
{
257
260
std::cout << " running file " << file_path << " on gpu : " << device << std::endl;
258
261
auto memory_resource = make_async ();
259
- rmm::mr::set_current_device_resource (memory_resource.get ());
262
+ auto limiting_adaptor =
263
+ rmm::mr::limiting_resource_adaptor (memory_resource.get (), 6ULL * 1024ULL * 1024ULL * 1024ULL );
264
+ rmm::mr::set_current_device_resource (&limiting_adaptor);
260
265
int sol_found = run_single_file (file_path,
261
266
device,
262
267
batch_id,
@@ -340,6 +345,15 @@ int main(int argc, char* argv[])
340
345
.scan <' g' , double >()
341
346
.default_value (std::numeric_limits<double >::max ());
342
347
348
+ program.add_argument (" --memory-limit" )
349
+ .help (" memory limit in MB" )
350
+ .scan <' g' , double >()
351
+ .default_value (0.0 );
352
+
353
+ program.add_argument (" --track-allocations" )
354
+ .help (" track allocations (t/f)" )
355
+ .default_value (std::string (" f" ));
356
+
343
357
// Parse arguments
344
358
try {
345
359
program.parse_args (argc, argv);
@@ -362,10 +376,12 @@ int main(int argc, char* argv[])
362
376
std::string result_file;
363
377
int batch_num = -1 ;
364
378
365
- bool heuristics_only = program.get <std::string>(" --heuristics-only" )[0 ] == ' t' ;
366
- int num_cpu_threads = program.get <int >(" --num-cpu-threads" );
367
- bool write_log_file = program.get <std::string>(" --write-log-file" )[0 ] == ' t' ;
368
- bool log_to_console = program.get <std::string>(" --log-to-console" )[0 ] == ' t' ;
379
+ bool heuristics_only = program.get <std::string>(" --heuristics-only" )[0 ] == ' t' ;
380
+ int num_cpu_threads = program.get <int >(" --num-cpu-threads" );
381
+ bool write_log_file = program.get <std::string>(" --write-log-file" )[0 ] == ' t' ;
382
+ bool log_to_console = program.get <std::string>(" --log-to-console" )[0 ] == ' t' ;
383
+ double memory_limit = program.get <double >(" --memory-limit" );
384
+ bool track_allocations = program.get <std::string>(" --track-allocations" )[0 ] == ' t' ;
369
385
370
386
if (program.is_used (" --out-dir" )) {
371
387
out_dir = program.get <std::string>(" --out-dir" );
@@ -469,7 +485,17 @@ int main(int argc, char* argv[])
469
485
merge_result_files (out_dir, result_file, n_gpus, batch_num);
470
486
} else {
471
487
auto memory_resource = make_async ();
472
- rmm::mr::set_current_device_resource (memory_resource.get ());
488
+ if (memory_limit > 0 ) {
489
+ auto limiting_adaptor =
490
+ rmm::mr::limiting_resource_adaptor (memory_resource.get (), memory_limit * 1024ULL * 1024ULL );
491
+ rmm::mr::set_current_device_resource (&limiting_adaptor);
492
+ } else if (track_allocations) {
493
+ rmm::mr::tracking_resource_adaptor tracking_adaptor (memory_resource.get (),
494
+ /* capture_stacks=*/ true );
495
+ rmm::mr::set_current_device_resource (&tracking_adaptor);
496
+ } else {
497
+ rmm::mr::set_current_device_resource (memory_resource.get ());
498
+ }
473
499
run_single_file (path,
474
500
0 ,
475
501
0 ,
0 commit comments