Skip to content

Commit 5cfd066

Browse files
d4l3kfacebook-github-bot
authored andcommitted
torchx/runner: log events to torch.monitor
Summary: This logs the `torchx.runner.events.Events` to `torch.monitor` as well as the existing event handlers. Once monitor is stable the existing ones will be removed entirely in favor of the new interface. `torch.monitor` is only available with pytorch 1.11 (or main) so it's a no-op if it's not available. Differential Revision: D33928333 fbshipit-source-id: 632a7e096f46e5c1946932d4a765be179e1e53a8
1 parent 787bb8f commit 5cfd066

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

torchx/runner/events/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
5959
def record(event: TorchxEvent, destination: str = "null") -> None:
6060
_get_or_create_logger(destination).info(event.serialize())
6161

62+
if destination != "console":
63+
# if using torch>1.11 log the event to torch.monitor
64+
try:
65+
from torch import monitor
66+
67+
monitor.log_event(event.to_monitor_event())
68+
except ImportError:
69+
pass
70+
6271

6372
class log_event:
6473
"""

torchx/runner/events/api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
import json
99
from dataclasses import asdict, dataclass
10+
from datetime import datetime
1011
from enum import Enum
11-
from typing import Optional, Union
12+
from typing import Optional, Union, TYPE_CHECKING
13+
14+
if TYPE_CHECKING:
15+
from torch import monitor
1216

1317

1418
class SourceType(str, Enum):
@@ -60,3 +64,12 @@ def deserialize(data: Union[str, "TorchxEvent"]) -> "TorchxEvent":
6064

6165
def serialize(self) -> str:
6266
return json.dumps(asdict(self))
67+
68+
def to_monitor_event(self) -> "monitor.Event":
69+
from torch import monitor
70+
71+
return monitor.Event(
72+
name="torch.runner.Event",
73+
timestamp=datetime.now(),
74+
data={k: v for k, v in self.__dict__.items() if v is not None},
75+
)

torchx/runner/events/test/lib_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,24 @@
88
import json
99
import logging
1010
import unittest
11+
from typing import List
1112
from unittest.mock import patch, MagicMock
1213

1314
from torchx.runner.events import (
1415
_get_or_create_logger,
1516
SourceType,
1617
TorchxEvent,
1718
log_event,
19+
record,
1820
)
1921

22+
try:
23+
from torch import monitor
24+
25+
SKIP_MONITOR: bool = False
26+
except ImportError:
27+
SKIP_MONITOR: bool = True
28+
2029

2130
class TorchxEventLibTest(unittest.TestCase):
2231
def assert_event(
@@ -57,6 +66,51 @@ def test_event_deser(self) -> None:
5766
deser_event = TorchxEvent.deserialize(json_event)
5867
self.assert_event(event, deser_event)
5968

69+
@unittest.skipIf(SKIP_MONITOR, "no torch.monitor available")
70+
def test_monitor(self) -> None:
71+
event = TorchxEvent(
72+
session="test_session",
73+
scheduler="test_scheduler",
74+
api="test_api",
75+
source=SourceType.EXTERNAL,
76+
)
77+
monitor_event = event.to_monitor_event()
78+
self.assertEqual(
79+
monitor_event.data,
80+
{
81+
"session": "test_session",
82+
"scheduler": "test_scheduler",
83+
"api": "test_api",
84+
"source": "EXTERNAL",
85+
},
86+
)
87+
self.assertEqual(monitor_event.name, "torch.runner.Event")
88+
89+
@unittest.skipIf(SKIP_MONITOR, "no torch.monitor available")
90+
@patch("torchx.runner.events._get_or_create_logger")
91+
def test_monitor_record(self, get_logging_handler: MagicMock) -> None:
92+
event = TorchxEvent(
93+
session="test_session",
94+
scheduler="test_scheduler",
95+
api="test_api",
96+
source=SourceType.EXTERNAL,
97+
)
98+
events: List[monitor.Event] = []
99+
100+
def handler(e: monitor.Event) -> None:
101+
events.append(e)
102+
103+
handle = monitor.register_event_handler(handler)
104+
105+
try:
106+
record(event)
107+
finally:
108+
monitor.unregister_event_handler(handle)
109+
110+
self.assertEqual(get_logging_handler.call_count, 1)
111+
self.assertEqual(len(events), 1)
112+
self.assertEqual(events[0].data["session"], "test_session")
113+
60114

61115
@patch("torchx.runner.events.record")
62116
class LogEventTest(unittest.TestCase):

0 commit comments

Comments
 (0)