|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 | import os
|
| 17 | +import shutil |
17 | 18 | from typing import Dict, List
|
18 | 19 |
|
19 | 20 | import torch
|
20 | 21 | import wget
|
21 | 22 | from torch.hub import _get_torch_home
|
22 | 23 |
|
23 |
| -from nemo.utils import logging |
| 24 | +from nemo.utils import get_rank, logging |
24 | 25 |
|
25 | 26 | __all__ = [
|
26 | 27 | "get_megatron_lm_model",
|
@@ -202,16 +203,14 @@ def _download(path: str, url: str):
|
202 | 203 | if url is None:
|
203 | 204 | return None
|
204 | 205 |
|
205 |
| - if not os.path.exists(path): |
206 |
| - master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 |
207 |
| - if not os.path.exists(path): |
208 |
| - if master_device: |
209 |
| - os.makedirs(MEGATRON_CACHE, exist_ok=True) |
210 |
| - logging.info(f"Downloading from {url}") |
211 |
| - wget.download(url, path) |
212 |
| - # wait until the master process downloads the file and writes it to the cache dir |
213 |
| - if torch.distributed.is_initialized(): |
214 |
| - torch.distributed.barrier() |
| 206 | + if get_rank.is_global_rank_zero() and not os.path.exists(path): |
| 207 | + os.makedirs(MEGATRON_CACHE, exist_ok=True) |
| 208 | + logging.info(f"Downloading from {url} to {path}") |
| 209 | + downloaded_path = wget.download(url) |
| 210 | + shutil.move(downloaded_path, path) |
| 211 | + # wait until the master process downloads the file and writes it to the cache dir |
| 212 | + if torch.distributed.is_initialized(): |
| 213 | + torch.distributed.barrier() |
215 | 214 |
|
216 | 215 | return path
|
217 | 216 |
|
|
0 commit comments