|
38 | 38 | import threading
|
39 | 39 | from dataclasses import dataclass
|
40 | 40 | from datetime import datetime
|
41 |
| -from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple |
| 41 | +from typing import ( |
| 42 | + Dict, |
| 43 | + Iterable, |
| 44 | + Mapping, |
| 45 | + Optional, |
| 46 | + Any, |
| 47 | + TYPE_CHECKING, |
| 48 | + Tuple, |
| 49 | + TypeVar, |
| 50 | + Callable, |
| 51 | +) |
42 | 52 |
|
43 | 53 | import torchx
|
44 | 54 | import yaml
|
@@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
|
86 | 96 | reqs.append({"type": "VCPU", "value": str(cpu)})
|
87 | 97 |
|
88 | 98 | mem = resource.memMB
|
89 |
| - if mem <= 0: |
90 |
| - mem = 1000 |
| 99 | + if mem < 0: |
| 100 | + raise ValueError( |
| 101 | + f"AWSBatchScheduler requires memMB to be set to a positive value, got {mem}" |
| 102 | + ) |
91 | 103 | reqs.append({"type": "MEMORY", "value": str(mem)})
|
92 | 104 |
|
93 | 105 | if resource.gpu > 0:
|
@@ -157,17 +169,29 @@ def __repr__(self) -> str:
|
157 | 169 | return str(self)
|
158 | 170 |
|
159 | 171 |
|
160 |
| -def _thread_local_session() -> "boto3.session.Session": |
161 |
| - KEY = "torchx_aws_batch_session" |
162 |
| - local = threading.local() |
163 |
| - if hasattr(local, KEY): |
164 |
| - # pyre-ignore[16] |
165 |
| - return getattr(local, KEY) |
| 172 | +T = TypeVar("T") |
| 173 | + |
| 174 | + |
| 175 | +def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]: |
| 176 | + local: threading.local = threading.local() |
| 177 | + key: str = "value" |
| 178 | + |
| 179 | + def wrapper() -> T: |
| 180 | + if key in local.__dict__: |
| 181 | + return local.__dict__[key] |
| 182 | + |
| 183 | + v = f() |
| 184 | + local.__dict__[key] = v |
| 185 | + return v |
| 186 | + |
| 187 | + return wrapper |
| 188 | + |
| 189 | + |
| 190 | +@_thread_local_cache |
| 191 | +def _local_session() -> "boto3.session.Session": |
166 | 192 | import boto3.session
|
167 | 193 |
|
168 |
| - session = boto3.session.Session() |
169 |
| - setattr(local, KEY, session) |
170 |
| - return session |
| 194 | + return boto3.session.Session() |
171 | 195 |
|
172 | 196 |
|
173 | 197 | class AWSBatchScheduler(Scheduler, DockerWorkspace):
|
@@ -239,14 +263,14 @@ def __init__(
|
239 | 263 | def _client(self) -> Any:
|
240 | 264 | if self.__client:
|
241 | 265 | return self.__client
|
242 |
| - return _thread_local_session().client("batch") |
| 266 | + return _local_session().client("batch") |
243 | 267 |
|
244 | 268 | @property
|
245 | 269 | # pyre-fixme[3]: Return annotation cannot be `Any`.
|
246 | 270 | def _log_client(self) -> Any:
|
247 | 271 | if self.__log_client:
|
248 | 272 | return self.__log_client
|
249 |
| - return _thread_local_session().client("logs") |
| 273 | + return _local_session().client("logs") |
250 | 274 |
|
251 | 275 | def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
|
252 | 276 | cfg = dryrun_info._cfg
|
|
0 commit comments