Skip to content

Commit 7fbb4cd

Browse files
committed
schedulers/aws_batch: fix thread local sessions + raise error on missing memory resource
1 parent 90b05b0 commit 7fbb4cd

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,17 @@
3838
import threading
3939
from dataclasses import dataclass
4040
from datetime import datetime
41-
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple
41+
from typing import (
42+
Dict,
43+
Iterable,
44+
Mapping,
45+
Optional,
46+
Any,
47+
TYPE_CHECKING,
48+
Tuple,
49+
TypeVar,
50+
Callable,
51+
)
4252

4353
import torchx
4454
import yaml
@@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
8696
reqs.append({"type": "VCPU", "value": str(cpu)})
8797

8898
mem = resource.memMB
89-
if mem <= 0:
90-
mem = 1000
99+
# The message below demands a positive value, so zero has to be rejected
# along with negatives; otherwise a MEMORY requirement of "0" is emitted.
# NOTE(review): an unset memMB presumably arrives as a non-positive
# number -- confirm against the Resource defaults.
if mem <= 0:
    raise ValueError(
        f"AWSBatchScheduler requires memMB to be set to a positive value, got {mem}"
    )
91103
reqs.append({"type": "MEMORY", "value": str(mem)})
92104

93105
if resource.gpu > 0:
@@ -157,17 +169,29 @@ def __repr__(self) -> str:
157169
return str(self)
158170

159171

160-
def _thread_local_session() -> "boto3.session.Session":
161-
KEY = "torchx_aws_batch_session"
162-
local = threading.local()
163-
if hasattr(local, KEY):
164-
# pyre-ignore[16]
165-
return getattr(local, KEY)
172+
T = TypeVar("T")


def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
    """Decorate a zero-argument factory so each thread computes its value once.

    The first call to the returned function on a given thread invokes ``f``
    and stores the result in a ``threading.local`` owned by this decorator.
    Later calls on that thread return the stored object without calling
    ``f`` again. Other threads do not see that value and populate their own.
    If ``f`` raises, nothing is stored and the next call retries.

    Args:
        f: factory that takes no arguments.

    Returns:
        A zero-argument callable with the per-thread caching described
        above. It keeps ``f``'s name and docstring via ``functools.wraps``.
    """
    # Imported here so the decorator does not depend on a module-level import.
    import functools

    # Created once per decorated function (not once per call), so the stored
    # value is still there on the next call from the same thread.
    local: threading.local = threading.local()
    key: str = "value"

    @functools.wraps(f)
    def wrapper() -> T:
        # ``local.__dict__`` is the calling thread's own attribute dict.
        slot = local.__dict__
        if key not in slot:
            slot[key] = f()
        return slot[key]

    return wrapper
188+
189+
190+
@_thread_local_cache
def _local_session() -> "boto3.session.Session":
    """Build a new boto3 ``Session``.

    The ``_thread_local_cache`` decorator stores the result per thread, so
    each thread constructs its session once and reuses it afterwards.
    """
    # boto3 is imported at call time rather than at module import time.
    from boto3 import session as boto3_session

    return boto3_session.Session()
171195

172196

173197
class AWSBatchScheduler(Scheduler, DockerWorkspace):
@@ -239,14 +263,14 @@ def __init__(
239263
def _client(self) -> Any:
    # Use the client stored on this instance when it is truthy (presumably
    # supplied through __init__ -- confirm there); otherwise ask the calling
    # thread's cached session for a "batch" client.
    return self.__client or _local_session().client("batch")
243267

244268
@property
# pyre-fixme[3]: Return annotation cannot be `Any`.
def _log_client(self) -> Any:
    """Return the client used for log access.

    Uses the client stored on this instance when it is truthy (presumably
    supplied through ``__init__`` -- confirm there); otherwise asks the
    calling thread's cached session for a ``"logs"`` client.
    """
    stored = self.__log_client
    if not stored:
        return _local_session().client("logs")
    return stored
250274

251275
def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
252276
cfg = dryrun_info._cfg

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import threading
78
import unittest
89
from contextlib import contextmanager
910
from typing import Generator
@@ -15,6 +16,7 @@
1516
create_scheduler,
1617
AWSBatchScheduler,
1718
_role_to_node_properties,
19+
_local_session,
1820
)
1921

2022

@@ -192,6 +194,11 @@ def test_volume_mounts(self) -> None:
192194
mounts=[
193195
specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True),
194196
],
197+
resource=specs.Resource(
198+
cpu=1,
199+
memMB=1000,
200+
gpu=0,
201+
),
195202
)
196203
props = _role_to_node_properties(0, role)
197204
self.assertEqual(
@@ -396,3 +403,16 @@ def test_log_iter(self) -> None:
396403
"foobar",
397404
],
398405
)
406+
407+
def test_local_session(self) -> None:
    """Check that ``_local_session`` is cached per thread, not globally."""
    a: object = _local_session()
    self.assertIs(a, _local_session())

    # An assertion that fails inside a worker thread is raised in that
    # thread and never fails the test, so the worker only records what it
    # saw and every check runs on the main thread after join().
    seen: "list[object]" = []

    def worker() -> None:
        seen.append(_local_session())
        seen.append(_local_session())

    t = threading.Thread(target=worker)
    t.start()
    t.join()

    # Two entries prove the worker ran to completion.
    self.assertEqual(len(seen), 2)
    # Same object within the worker thread...
    self.assertIs(seen[0], seen[1])
    # ...but a different object from the main thread's session.
    self.assertIsNot(a, seen[0])

0 commit comments

Comments
 (0)