@@ -497,6 +497,7 @@ void get_host_info(host_info* phi)
497497bool comm_support_mnnvl (wholememory_comm_t wm_comm, const std::unique_ptr<rank_info[]>& p_rank_info)
498498{
499499#if CUDA_VERSION >= 12030
500+ if (!nvmlFabricSymbolLoaded) return 0 ;
500501 int flag = 0 ;
501502 CUdevice currentDev;
502503 WM_CU_CHECK_NO_THROW (cuDeviceGet (&currentDev, wm_comm->dev_id ));
@@ -534,16 +535,22 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
534535 wm_comm->clique_info .is_in_clique = 0 ;
535536
536537#if CUDA_VERSION >= 12030
537- memset (&ri.fabric_info , 0 , sizeof (ri.fabric_info ));
538- WHOLEMEMORY_CHECK_NOTHROW (GetGpuFabricInfo (wm_comm->dev_id , &ri.fabric_info ) ==
539- WHOLEMEMORY_SUCCESS);
538+ if (nvmlFabricSymbolLoaded) {
539+ memset (&ri.fabric_info , 0 , sizeof (ri.fabric_info ));
540+ WHOLEMEMORY_CHECK_NOTHROW (GetGpuFabricInfo (wm_comm->dev_id , &ri.fabric_info ) ==
541+ WHOLEMEMORY_SUCCESS);
540542
541- // // A zero UUID means we don't have MNNVL fabric info
542- if (((((long *)ri.fabric_info .clusterUuid )[0 ] | ((long *)ri.fabric_info .clusterUuid )[1 ]) == 0 )) {
543- wm_comm->clique_info .is_in_clique = 0 ;
543+ // // A zero UUID means we don't have MNNVL fabric info
544+ if (((((long *)ri.fabric_info .clusterUuid )[0 ] | ((long *)ri.fabric_info .clusterUuid )[1 ]) == 0 )) {
545+ wm_comm->clique_info .is_in_clique = 0 ;
544546
547+ } else {
548+ wm_comm->clique_info .is_in_clique = 1 ;
549+ }
545550 } else {
546- wm_comm->clique_info .is_in_clique = 1 ;
551+ WHOLEMEMORY_WARN (
552+ " Some required NVML symbols are missing, likely due to an outdated GPU display driver. MNNVL "
553+ " support will be disabled." );
547554 }
548555
549556#endif
@@ -573,38 +580,41 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
573580 }
574581
575582#if CUDA_VERSION >= 12030
576-
577- if ((memcmp (ri.fabric_info .clusterUuid ,
578- p_rank_info.get ()[r].fabric_info .clusterUuid ,
579- NVML_GPU_FABRIC_UUID_LEN) == 0 ) &&
580- (ri.fabric_info .cliqueId == p_rank_info.get ()[r].fabric_info .cliqueId )) {
581- if (r == wm_comm->world_rank ) {
582- wm_comm->clique_info .clique_rank = wm_comm->clique_info .clique_rank_num ;
583+ if (nvmlFabricSymbolLoaded) {
584+ if ((memcmp (ri.fabric_info .clusterUuid ,
585+ p_rank_info.get ()[r].fabric_info .clusterUuid ,
586+ NVML_GPU_FABRIC_UUID_LEN) == 0 ) &&
587+ (ri.fabric_info .cliqueId == p_rank_info.get ()[r].fabric_info .cliqueId )) {
588+ if (r == wm_comm->world_rank ) {
589+ wm_comm->clique_info .clique_rank = wm_comm->clique_info .clique_rank_num ;
590+ }
591+ if (wm_comm->clique_info .clique_rank_num == 0 ) {
592+ wm_comm->clique_info .clique_first_rank = r;
593+ }
594+ wm_comm->clique_info .clique_rank_num ++;
583595 }
584- if (wm_comm->clique_info .clique_rank_num == 0 ) { wm_comm->clique_info .clique_first_rank = r; }
585- wm_comm->clique_info .clique_rank_num ++;
596+ clique_uuids.insert (
597+ std::string (reinterpret_cast <const char *>(p_rank_info.get ()[r].fabric_info .clusterUuid ),
598+ NVML_GPU_FABRIC_UUID_LEN));
586599 }
587- clique_uuids.insert (
588- std::string (reinterpret_cast <const char *>(p_rank_info.get ()[r].fabric_info .clusterUuid ),
589- NVML_GPU_FABRIC_UUID_LEN));
590-
591600#endif
592601 }
593602
594603#if CUDA_VERSION >= 12030
595- wm_comm->clique_info .clique_num = clique_uuids.size ();
596-
597- std::string uuid = std::string (reinterpret_cast <const char *>(ri.fabric_info .clusterUuid ),
598- NVML_GPU_FABRIC_UUID_LEN);
599- int id = 0 ;
600- for (auto clique_uuid : clique_uuids) {
601- if (clique_uuid == uuid) { wm_comm->clique_info .clique_id = id; }
602- id++;
603- }
604-
605- wm_comm->support_mnnvl = (comm_support_mnnvl (wm_comm, p_rank_info)) &&
606- (wm_comm->clique_info .clique_rank_num == wm_comm->world_size );
604+ if (nvmlFabricSymbolLoaded) {
605+ wm_comm->clique_info .clique_num = clique_uuids.size ();
606+
607+ std::string uuid = std::string (reinterpret_cast <const char *>(ri.fabric_info .clusterUuid ),
608+ NVML_GPU_FABRIC_UUID_LEN);
609+ int id = 0 ;
610+ for (auto clique_uuid : clique_uuids) {
611+ if (clique_uuid == uuid) { wm_comm->clique_info .clique_id = id; }
612+ id++;
613+ }
607614
615+ wm_comm->support_mnnvl = (comm_support_mnnvl (wm_comm, p_rank_info)) &&
616+ (wm_comm->clique_info .clique_rank_num == wm_comm->world_size );
617+ }
608618#endif
609619}
610620
0 commit comments