Commit effaef8

Refactor legacy tools to use Accelerator HAL Adapter
1 parent ccc9129 commit effaef8

7 files changed (+179, -8 lines)


TEST_PLAN.md

Lines changed: 59 additions & 0 deletions
# Test Plan: Accelerator HAL Migration

This document outlines the test plan to verify that the migration to the Accelerator HAL (Hardware Abstraction Layer) preserves existing functionality for NVML-based monitoring and health checks.

## Objective

Ensure that all existing NVML paths (`nvml_monitor` and `check_nvidia_smi`) continue to function identically after being refactored to use the `AcceleratorManager` and `NVMLBackend` interface.

## Coverage Areas

1. **Metric Collection (`nvml_monitor`)**: Verifying GPU metrics (utilization, memory, power, temperature, clocks, ECC) are collected correctly.
2. **Health Checks (`check_nvidia_smi`)**: Verifying GPU presence, running processes, and error detection.
3. **Error Handling**: Ensuring that backend unavailability or device errors are handled gracefully and logged appropriately.

## Test Cases

### 1. Unit Tests

Run the existing unit tests to verify there are no regressions in logic.

```bash
pytest gcm/tests/test_accelerator_hal.py
pytest gcm/tests/health_checks_tests/test_check_nvidia_smi.py
pytest gcm/tests/test_nvml_monitor.py
```

### 2. Manual Verification (Stubbed)

Since we cannot run on actual GPU hardware in this environment, we rely on the stubbed NVML library used in tests.

#### A. NVML Monitor

**Refactored Logic:**
`nvml_monitor` now instantiates `AcceleratorManager`, probes backends, and uses `AcceleratorTelemetryAdapter` to interact with device handles provided by `NVMLBackend`.

**Verification Step:**
Verify that `nvml_monitor.py` correctly fetches the device count and metrics via the adapter. The adapter ensures that underlying `pynvml` calls are routed through the `AcceleratorManager`'s backend instance.

#### B. Health Checks

**Refactored Logic:**
`check_nvidia_smi` now instantiates `AcceleratorManager` and uses `AcceleratorTelemetryAdapter` to perform checks.

**Verification Step:**
Verify that `check_nvidia_smi.py` correctly detects the GPU count and running processes via the adapter.

## Refactoring Status

- **`gcm/accelerator`**: Core HAL interfaces and the NVML backend implementation are complete.
- **`nvml_monitor.py`**: Refactored to use `AcceleratorManager` via `AcceleratorTelemetryAdapter`.
- **`check_nvidia_smi.py`**: Refactored to use `AcceleratorManager` via `AcceleratorTelemetryAdapter`.
- **Legacy Shim**: Added `gcm/monitoring/accelerator_adapter.py` to bridge `DeviceTelemetryClient` calls to the HAL backend, ensuring 100% backward compatibility for methods not yet fully exposed in `MetricSet` (e.g., specific ECC error counts).

## Rollout Strategy

1. **Phase 1 (Current PR)**: Introduce the HAL and migrate all NVML usage to `AcceleratorManager` via the adapter shim.
2. **Phase 2 (Future)**: Update the `nvml_monitor` logic to use `AcceleratorManager.read_metrics()` directly, removing the dependency on the `DeviceTelemetryClient` interface once `MetricSet` is expanded to cover all needs.

This incremental approach ensures that the new architecture is active immediately while minimizing risk to existing business logic.
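The "Legacy Shim" described above is an instance of the classic adapter pattern: existing callers keep the interface they depend on while the implementation delegates to the new manager. As a minimal sketch of that idea (all class names here are illustrative stand-ins, not the real gcm classes):

```python
# Hypothetical sketch of the legacy-shim idea; names are illustrative.
from abc import ABC, abstractmethod


class LegacyTelemetryClient(ABC):
    """Interface that existing callers already depend on."""

    @abstractmethod
    def get_device_count(self) -> int: ...


class NewManager:
    """Stand-in for the new HAL manager with a different API surface."""

    def __init__(self, device_ids: list) -> None:
        self._device_ids = device_ids

    def enumerate_devices(self) -> list:
        return list(self._device_ids)


class LegacyShim(LegacyTelemetryClient):
    """Adapter: satisfies the old interface by delegating to the new manager."""

    def __init__(self, manager: NewManager) -> None:
        self._manager = manager

    def get_device_count(self) -> int:
        # Old vocabulary (a count) expressed via the new API (enumeration).
        return len(self._manager.enumerate_devices())


shim = LegacyShim(NewManager(["0", "1"]))
assert shim.get_device_count() == 2  # old call sites keep working unchanged
```

The payoff is that Phase 2 can delete the shim without touching call sites, since they only ever saw the legacy interface.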

gcm/accelerator/backends/nvml.py

Lines changed: 15 additions & 0 deletions
@@ -160,6 +160,21 @@ def read_metrics(self, device: DeviceHandle, _request: MetricRequest) -> MetricS
             ),
         )
 
+    def get_raw_handle(self, device_id: str) -> Any:
+        client = self._ensure_client()
+        if device_id in self._handles:
+            return self._handles[device_id]
+
+        try:
+            index = int(device_id)
+            handle = client.get_device_by_index(index)
+            self._handles[device_id] = handle
+            return handle
+        except (ValueError, DeviceTelemetryException) as e:
+            raise UnsupportedOperationError(
+                f"invalid NVML device id: {device_id}"
+            ) from e
+
     def close(self) -> None:
         client = self._client
         self._client = None
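The `get_raw_handle` hunk above combines two small patterns: memoizing handles per device id, and translating low-level failures into a single HAL-level error. A standalone sketch of the same shape (toy types, not the real `NVMLBackend`):

```python
class UnsupportedOperationError(Exception):
    """Toy stand-in for the HAL's unsupported-operation error."""


class HandleCache:
    def __init__(self) -> None:
        self._handles: dict = {}

    def _fetch_handle(self, index: int) -> int:
        # Pretend this is an expensive driver call returning an opaque handle.
        return 1000 + index

    def get_raw_handle(self, device_id: str) -> int:
        # Serve from the cache when the handle was already resolved.
        if device_id in self._handles:
            return self._handles[device_id]
        try:
            index = int(device_id)  # non-numeric ids raise ValueError
            handle = self._fetch_handle(index)
            self._handles[device_id] = handle
            return handle
        except ValueError as e:
            # Translate low-level failures into one HAL-level error type,
            # chaining the original exception for debuggability.
            raise UnsupportedOperationError(
                f"invalid device id: {device_id}"
            ) from e
```

Caching matters here because legacy callers may request the same handle on every polling cycle, while the error translation keeps driver-specific exception types from leaking through the HAL boundary.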

gcm/health_checks/checks/check_nvidia_smi.py

Lines changed: 7 additions & 4 deletions
@@ -22,6 +22,8 @@
 import click
 import gni_lib
 import psutil
+from gcm.accelerator.manager import AcceleratorManager
+from gcm.accelerator.registry import default_backend_factories
 from gcm.health_checks.check_utils.output_context_manager import OutputContext
 from gcm.health_checks.check_utils.telem import TelemetryContext
 from gcm.health_checks.click import common_arguments, telemetry_argument
@@ -32,6 +34,7 @@
 from gcm.health_checks.env_variables import EnvCtx
 from gcm.health_checks.measurement_units import convert_bytes
 from gcm.health_checks.types import CHECK_TYPE, CheckEnv, ExitCode
+from gcm.monitoring.accelerator_adapter import AcceleratorTelemetryAdapter
 from gcm.monitoring.click import heterogeneous_cluster_v1_option
 from gcm.monitoring.device_telemetry_client import (
     DeviceTelemetryClient,
@@ -60,10 +63,10 @@ class NvidiaSmiCliImpl:
     log_folder: str
 
     def get_device_telemetry(self) -> DeviceTelemetryClient:
-        # Fallback to direct NVML client until check_nvidia_smi is refactored
-        from gcm.monitoring.device_telemetry_nvml import NVMLDeviceTelemetryClient
-
-        return NVMLDeviceTelemetryClient()
+        # Use Accelerator Manager + Adapter for legacy support
+        # This ensures all paths go through the new accelerator interface
+        manager = AcceleratorManager(factories=default_backend_factories())
+        return AcceleratorTelemetryAdapter(manager)
 
 
 def check_gpu_num(
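The replacement body constructs the manager from `default_backend_factories()`, which suggests a registry of backend factories that the manager later probes for availability. A hedged sketch of that shape (hypothetical toy classes, not the real gcm API):

```python
# Toy sketch of a factory-registry + probe design; names are hypothetical.
from typing import Callable, Optional


class FakeBackend:
    """Toy backend; 'available' mimics whether the driver/library is present."""

    def __init__(self, name: str, available: bool) -> None:
        self.name = name
        self.available = available

    def probe(self) -> bool:
        return self.available


class ToyManager:
    """Builds backends from factories and keeps only those that probe OK."""

    def __init__(self, factories: list) -> None:
        self._factories = factories
        self._backends: dict = {}

    def probe_all(self) -> None:
        for factory in self._factories:
            backend = factory()
            if backend.probe():
                self._backends[backend.name] = backend

    def get_backend(self, name: str) -> Optional[FakeBackend]:
        # Returns None when the backend never probed successfully.
        return self._backends.get(name)


manager = ToyManager([
    lambda: FakeBackend("nvml", available=True),
    lambda: FakeBackend("other", available=False),
])
manager.probe_all()
```

Deferring construction to factories keeps import-time side effects out of the registry and lets each caller decide when the (potentially slow) hardware probe happens.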
gcm/monitoring/accelerator_adapter.py

Lines changed: 41 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from gcm.accelerator.backend import BackendName
from gcm.accelerator.manager import AcceleratorManager
from gcm.monitoring.device_telemetry_client import DeviceTelemetryClient, GPUDevice


class AcceleratorTelemetryAdapter(DeviceTelemetryClient):
    """
    Adapter that allows legacy code expecting DeviceTelemetryClient/GPUDevice
    to function using AcceleratorManager.
    """

    def __init__(self, manager: AcceleratorManager):
        self._manager = manager
        # Ensure backends have been probed before any queries.
        self._manager.probe_all()

    def get_device_count(self) -> int:
        backend = self._manager.get_backend(BackendName.NVML)
        # If the NVML backend isn't available, the count is 0.
        if not backend:
            return 0

        # Enumerate devices to get the count.
        return len(backend.enumerate_devices())

    def get_device_by_index(self, index: int) -> GPUDevice:
        backend = self._manager.get_backend(BackendName.NVML)
        if not backend:
            raise IndexError("NVML backend not available")

        # get_raw_handle is added to NVMLBackend in this commit but is not
        # part of the base backend interface, so detect it dynamically.
        if hasattr(backend, "get_raw_handle"):
            return backend.get_raw_handle(str(index))  # type: ignore[attr-defined]

        raise NotImplementedError(
            "Backend does not support raw handle access needed for legacy code"
        )
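The adapter's `hasattr(backend, "get_raw_handle")` check is a capability-detection idiom: an optional method that is not part of the base interface is discovered at runtime, and a clear error is raised when it is absent. The same idiom in isolation (toy classes, not the real backends):

```python
class BaseBackend:
    """Minimal base interface without raw-handle access."""


class RawHandleBackend(BaseBackend):
    """A backend that additionally exposes the optional capability."""

    def get_raw_handle(self, device_id: str) -> str:
        return f"handle:{device_id}"


def fetch_handle(backend: BaseBackend, device_id: str) -> str:
    # Detect the optional capability dynamically rather than widening
    # the base interface for every backend that will never support it.
    if hasattr(backend, "get_raw_handle"):
        return backend.get_raw_handle(device_id)
    raise NotImplementedError(
        "Backend does not support raw handle access needed for legacy code"
    )
```

The trade-off is that the capability is invisible to static type checkers (hence the `type: ignore` in the real adapter); a `Protocol` with `isinstance` checks would be a stricter alternative.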

gcm/monitoring/cli/nvml_monitor.py

Lines changed: 5 additions & 4 deletions
@@ -29,6 +29,7 @@
 from gcm.accelerator.manager import AcceleratorManager
 from gcm.accelerator.registry import default_backend_factories
 from gcm.exporters import registry
+from gcm.monitoring.accelerator_adapter import AcceleratorTelemetryAdapter
 from gcm.monitoring.accumulate import Accumulator
 from gcm.monitoring.click import (
     click_default_cmd,
@@ -278,10 +279,10 @@ class CliObjectImpl:
     clock: Clock = field(default_factory=ClockImpl)
 
     def get_device_telemetry(self) -> DeviceTelemetryClient:
-        # Fallback to direct NVML client if needed, or update to use HAL
-        from gcm.monitoring.device_telemetry_nvml import NVMLDeviceTelemetryClient
-
-        return NVMLDeviceTelemetryClient()
+        # Use Accelerator Manager + Adapter for legacy support
+        # This ensures all paths go through the new accelerator interface
+        manager = AcceleratorManager(factories=default_backend_factories())
+        return AcceleratorTelemetryAdapter(manager)
 
     def read_env(self, process_id: int) -> Env:
         return read_environ_from_proc(process_id)
Lines changed: 28 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from unittest.mock import patch

from gcm.health_checks.checks.check_nvidia_smi import NvidiaSmiCliImpl
from gcm.monitoring.accelerator_adapter import AcceleratorTelemetryAdapter


def test_nvidia_smi_cli_impl_uses_hal_adapter() -> None:
    # Patch default_backend_factories to avoid actual registry access
    with (
        patch("gcm.health_checks.checks.check_nvidia_smi.default_backend_factories"),
        patch(
            "gcm.health_checks.checks.check_nvidia_smi.AcceleratorManager"
        ) as MockManager,
    ):
        # Mocked manager instance returned by the patched class
        manager_instance = MockManager.return_value

        cli = NvidiaSmiCliImpl(
            cluster="test_cluster", type="test_type", log_level="INFO", log_folder="."
        )
        telemetry = cli.get_device_telemetry()

        assert isinstance(telemetry, AcceleratorTelemetryAdapter)
        # Verify the manager was initialized and probed
        MockManager.assert_called()
        manager_instance.probe_all.assert_called()
Lines changed: 24 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from unittest.mock import patch

from gcm.monitoring.accelerator_adapter import AcceleratorTelemetryAdapter
from gcm.monitoring.cli.nvml_monitor import CliObjectImpl


def test_cli_object_impl_uses_hal_adapter() -> None:
    # Patch default_backend_factories to avoid actual registry access
    with (
        patch("gcm.monitoring.cli.nvml_monitor.default_backend_factories"),
        patch("gcm.monitoring.cli.nvml_monitor.AcceleratorManager") as MockManager,
    ):
        # Mocked manager instance returned by the patched class
        manager_instance = MockManager.return_value

        cli = CliObjectImpl()
        telemetry = cli.get_device_telemetry()

        assert isinstance(telemetry, AcceleratorTelemetryAdapter)
        # Verify the manager was initialized and probed
        MockManager.assert_called()
        manager_instance.probe_all.assert_called()
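Both new tests follow the same `unittest.mock.patch` recipe: patch the collaborator class where the code under test looks it up, then assert on `MockManager.return_value` (the instance the code actually constructed). The same recipe on a self-contained toy (hypothetical classes, not the gcm modules):

```python
from unittest.mock import patch


class Manager:
    def probe_all(self) -> None:
        raise RuntimeError("would touch real hardware in production")


class Cli:
    # Class attribute so a test can patch the collaborator in place.
    manager_cls = Manager

    def get_device_telemetry(self) -> object:
        manager = self.manager_cls()
        manager.probe_all()
        return manager


def test_cli_probes_manager() -> None:
    with patch.object(Cli, "manager_cls") as MockManager:
        telemetry = Cli().get_device_telemetry()
        # The CLI returned the mocked instance and probed it exactly once.
        assert telemetry is MockManager.return_value
        MockManager.assert_called_once()
        MockManager.return_value.probe_all.assert_called_once()


test_cli_probes_manager()
```

The key detail is patching where the name is *looked up* (the module or class under test), not where it is defined; patching the defining module would leave the already-imported reference untouched.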
