Skip to content

Commit 0b94b60

Browse files
committed
Add support for dynamic loading (dlopen) of NVML symbols
1 parent e1e32bc commit 0b94b60

File tree

5 files changed

+159
-46
lines changed

5 files changed

+159
-46
lines changed

cpp/src/nvml_wrap.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "nvml_wrap.h"
16+
#include <dlfcn.h>
17+
#include <mutex>
18+
#include <stdio.h>
19+
20+
namespace {

// Handle returned by dlopen() for the NVML shared library; nullptr while the
// library is not loaded.
void* nvml_handle = nullptr;
// Guards the load sequence performed by NvmlFabricSymbolLoaded().
std::mutex nvml_mutex;
// True once the library is open and every required symbol has been resolved.
bool nvml_loaded = false;

// Opens the NVML shared library and stores the handle in nvml_handle.
//
// The versioned soname is tried first, then the unversioned one. Returns true
// on success; on failure nvml_handle is nullptr, the last loader error is
// printed to stderr and false is returned.
bool LoadNvmlLibrary()
{
  static const char* const kCandidates[] = {"libnvidia-ml.so.1", "libnvidia-ml.so"};

  for (const char* soname : kCandidates) {
    nvml_handle = dlopen(soname, RTLD_NOW);
    if (nvml_handle != nullptr) { return true; }
  }

  // dlerror() may return a null pointer; never hand that to "%s".
  const char* const reason = dlerror();
  fprintf(stderr, "Failed to load NVML library: %s\n", reason != nullptr ? reason : "unknown error");
  return false;
}

// Resolves `name` inside the NVML library and returns it converted to T
// (expected to be a function-pointer type). Returns nullptr when the library
// is not loaded or the symbol does not exist.
template <typename T>
T LoadNvmlSymbol(const char* name)
{
  // A null handle must not reach dlsym(): on glibc it is RTLD_DEFAULT and
  // would resolve the name against the whole process instead of NVML.
  if (nvml_handle == nullptr) { return nullptr; }

  if (void* const address = dlsym(nvml_handle, name)) { return reinterpret_cast<T>(address); }
  return nullptr;
}

}  // namespace
48+
49+
// Global function pointers
50+
nvmlDeviceGetHandleByIndexFunc nvmlDeviceGetHandleByIndexPtr = nullptr;
51+
nvmlDeviceGetGpuFabricInfoFunc nvmlDeviceGetGpuFabricInfoPtr = nullptr;
52+
53+
// Ensure NVML is loaded and symbols are initialized
54+
bool NvmlFabricSymbolLoaded()
55+
{
56+
std::lock_guard<std::mutex> lock(nvml_mutex);
57+
if (nvml_loaded) {
58+
return true; // Already loaded
59+
}
60+
61+
if (LoadNvmlLibrary()) {
62+
nvmlDeviceGetHandleByIndexPtr =
63+
LoadNvmlSymbol<nvmlDeviceGetHandleByIndexFunc>("nvmlDeviceGetHandleByIndex");
64+
nvmlDeviceGetGpuFabricInfoPtr =
65+
LoadNvmlSymbol<nvmlDeviceGetGpuFabricInfoFunc>("nvmlDeviceGetGpuFabricInfo");
66+
67+
if (!nvmlDeviceGetHandleByIndexPtr || !nvmlDeviceGetGpuFabricInfoPtr) {
68+
dlclose(nvml_handle);
69+
nvml_handle = nullptr;
70+
} else {
71+
nvml_loaded = true;
72+
}
73+
}
74+
return nvml_loaded;
75+
}

cpp/src/nvml_wrap.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <nvml.h>
18+
19+
bool NvmlFabricSymbolLoaded();
20+
21+
typedef nvmlReturn_t (*nvmlDeviceGetHandleByIndexFunc)(unsigned int, nvmlDevice_t*);
22+
typedef nvmlReturn_t (*nvmlDeviceGetGpuFabricInfoFunc)(nvmlDevice_t, nvmlGpuFabricInfo_t*);
23+
24+
extern nvmlDeviceGetHandleByIndexFunc nvmlDeviceGetHandleByIndexPtr;
25+
extern nvmlDeviceGetGpuFabricInfoFunc nvmlDeviceGetGpuFabricInfoPtr;

cpp/src/wholememory/communicator.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ void get_host_info(host_info* phi)
497497
bool comm_support_mnnvl(wholememory_comm_t wm_comm, const std::unique_ptr<rank_info[]>& p_rank_info)
498498
{
499499
#if CUDA_VERSION >= 12030
500+
if (!nvmlFabricSymbolLoaded) return 0;
500501
int flag = 0;
501502
CUdevice currentDev;
502503
WM_CU_CHECK_NO_THROW(cuDeviceGet(&currentDev, wm_comm->dev_id));
@@ -534,16 +535,22 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
534535
wm_comm->clique_info.is_in_clique = 0;
535536

536537
#if CUDA_VERSION >= 12030
537-
memset(&ri.fabric_info, 0, sizeof(ri.fabric_info));
538-
WHOLEMEMORY_CHECK_NOTHROW(GetGpuFabricInfo(wm_comm->dev_id, &ri.fabric_info) ==
539-
WHOLEMEMORY_SUCCESS);
538+
if (nvmlFabricSymbolLoaded) {
539+
memset(&ri.fabric_info, 0, sizeof(ri.fabric_info));
540+
WHOLEMEMORY_CHECK_NOTHROW(GetGpuFabricInfo(wm_comm->dev_id, &ri.fabric_info) ==
541+
WHOLEMEMORY_SUCCESS);
540542

541-
// // A zero UUID means we don't have MNNVL fabric info
542-
if (((((long*)ri.fabric_info.clusterUuid)[0] | ((long*)ri.fabric_info.clusterUuid)[1]) == 0)) {
543-
wm_comm->clique_info.is_in_clique = 0;
543+
// // A zero UUID means we don't have MNNVL fabric info
544+
if (((((long*)ri.fabric_info.clusterUuid)[0] | ((long*)ri.fabric_info.clusterUuid)[1]) == 0)) {
545+
wm_comm->clique_info.is_in_clique = 0;
544546

547+
} else {
548+
wm_comm->clique_info.is_in_clique = 1;
549+
}
545550
} else {
546-
wm_comm->clique_info.is_in_clique = 1;
551+
WHOLEMEMORY_WARN(
552+
"Some required NVML symbols are missing, likely due to an outdated GPU display driver. MNNVL "
553+
"support will be disabled.");
547554
}
548555

549556
#endif
@@ -573,38 +580,41 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
573580
}
574581

575582
#if CUDA_VERSION >= 12030
576-
577-
if ((memcmp(ri.fabric_info.clusterUuid,
578-
p_rank_info.get()[r].fabric_info.clusterUuid,
579-
NVML_GPU_FABRIC_UUID_LEN) == 0) &&
580-
(ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) {
581-
if (r == wm_comm->world_rank) {
582-
wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num;
583+
if (nvmlFabricSymbolLoaded) {
584+
if ((memcmp(ri.fabric_info.clusterUuid,
585+
p_rank_info.get()[r].fabric_info.clusterUuid,
586+
NVML_GPU_FABRIC_UUID_LEN) == 0) &&
587+
(ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) {
588+
if (r == wm_comm->world_rank) {
589+
wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num;
590+
}
591+
if (wm_comm->clique_info.clique_rank_num == 0) {
592+
wm_comm->clique_info.clique_first_rank = r;
593+
}
594+
wm_comm->clique_info.clique_rank_num++;
583595
}
584-
if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; }
585-
wm_comm->clique_info.clique_rank_num++;
596+
clique_uuids.insert(
597+
std::string(reinterpret_cast<const char*>(p_rank_info.get()[r].fabric_info.clusterUuid),
598+
NVML_GPU_FABRIC_UUID_LEN));
586599
}
587-
clique_uuids.insert(
588-
std::string(reinterpret_cast<const char*>(p_rank_info.get()[r].fabric_info.clusterUuid),
589-
NVML_GPU_FABRIC_UUID_LEN));
590-
591600
#endif
592601
}
593602

594603
#if CUDA_VERSION >= 12030
595-
wm_comm->clique_info.clique_num = clique_uuids.size();
596-
597-
std::string uuid = std::string(reinterpret_cast<const char*>(ri.fabric_info.clusterUuid),
598-
NVML_GPU_FABRIC_UUID_LEN);
599-
int id = 0;
600-
for (auto clique_uuid : clique_uuids) {
601-
if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; }
602-
id++;
603-
}
604-
605-
wm_comm->support_mnnvl = (comm_support_mnnvl(wm_comm, p_rank_info)) &&
606-
(wm_comm->clique_info.clique_rank_num == wm_comm->world_size);
604+
if (nvmlFabricSymbolLoaded) {
605+
wm_comm->clique_info.clique_num = clique_uuids.size();
606+
607+
std::string uuid = std::string(reinterpret_cast<const char*>(ri.fabric_info.clusterUuid),
608+
NVML_GPU_FABRIC_UUID_LEN);
609+
int id = 0;
610+
for (auto clique_uuid : clique_uuids) {
611+
if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; }
612+
id++;
613+
}
607614

615+
wm_comm->support_mnnvl = (comm_support_mnnvl(wm_comm, p_rank_info)) &&
616+
(wm_comm->clique_info.clique_rank_num == wm_comm->world_size);
617+
}
608618
#endif
609619
}
610620

cpp/src/wholememory/system_info.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#include "system_info.hpp"
17-
1816
#include <string>
1917

2018
#include "cuda_macros.hpp"
@@ -140,17 +138,19 @@ wholememory_error_code_t NvmlEnsureInitialized()
140138
wholememory_error_code_t GetGpuFabricInfo(int dev, nvmlGpuFabricInfo_t* gpuFabricInfo)
141139
{
142140
WHOLEMEMORY_CHECK_NOTHROW(NvmlEnsureInitialized() == WHOLEMEMORY_SUCCESS);
143-
std::lock_guard<std::mutex> locked(lock);
144-
// gpuFabricInfo->version = nvmlGpuFabricInfo_v2;
145-
nvmlDevice_t nvml_device;
146-
nvmlReturn_t ret = nvmlDeviceGetHandleByIndex(dev, &nvml_device);
147-
WHOLEMEMORY_EXPECTS_NOTHROW(
148-
ret == NVML_SUCCESS, "nvmlDeviceGetHandleByIndex error:%s", nvmlErrorString(ret));
149-
ret = nvmlDeviceGetGpuFabricInfo(nvml_device, gpuFabricInfo);
150-
WHOLEMEMORY_EXPECTS_NOTHROW(
151-
ret == NVML_SUCCESS, "nvmlDeviceGetGpuFabricInfo error:%s", nvmlErrorString(ret));
152-
153-
return WHOLEMEMORY_SUCCESS;
141+
if (wholememory::nvmlFabricSymbolLoaded) {
142+
std::lock_guard<std::mutex> locked(lock);
143+
// gpuFabricInfo->version = nvmlGpuFabricInfo_v2;
144+
nvmlDevice_t nvml_device;
145+
nvmlReturn_t ret = nvmlDeviceGetHandleByIndexPtr(dev, &nvml_device);
146+
WHOLEMEMORY_EXPECTS_NOTHROW(
147+
ret == NVML_SUCCESS, "nvmlDeviceGetHandleByIndex error:%s", nvmlErrorString(ret));
148+
ret = nvmlDeviceGetGpuFabricInfoPtr(nvml_device, gpuFabricInfo);
149+
WHOLEMEMORY_EXPECTS_NOTHROW(
150+
ret == NVML_SUCCESS, "nvmlDeviceGetGpuFabricInfo error:%s", nvmlErrorString(ret));
151+
return WHOLEMEMORY_SUCCESS;
152+
}
153+
return WHOLEMEMORY_SYSTEM_ERROR;
154154
}
155155

156156
}; // namespace wholememory

cpp/src/wholememory/system_info.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "wholememory/wholememory.h"
1919

2020
#if CUDA_VERSION >= 12030
21+
#include "nvml_wrap.h"
2122
#include <nvml.h>
2223
#endif
2324
bool DevAttrPagebleMemoryAccess();
@@ -37,7 +38,9 @@ bool SupportEGM();
3738
// bool SupportMNNVLForEGM();
3839
#if CUDA_VERSION >= 12030
3940
namespace wholememory {
41+
42+
inline bool nvmlFabricSymbolLoaded = NvmlFabricSymbolLoaded();
4043
wholememory_error_code_t GetGpuFabricInfo(int dev, nvmlGpuFabricInfo_t* gpuFabricInfo);
41-
}
44+
} // namespace wholememory
4245

4346
#endif

0 commit comments

Comments
 (0)