Skip to content

Commit 904536b

Browse files
findkim
authored and jubick1337 committed
Fix race condition when executing with multi-node where some ranks do not wait for setup (#7016)
Signed-off-by: Kim Ngo <[email protected]>
Signed-off-by: jubick1337 <[email protected]>
1 parent f7e33fc commit 904536b

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

nemo/collections/nlp/modules/common/megatron/megatron_utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
# limitations under the License.
1515

1616
import os
17+
import shutil
1718
from typing import Dict, List
1819

1920
import torch
2021
import wget
2122
from torch.hub import _get_torch_home
2223

23-
from nemo.utils import logging
24+
from nemo.utils import get_rank, logging
2425

2526
__all__ = [
2627
"get_megatron_lm_model",
@@ -202,16 +203,14 @@ def _download(path: str, url: str):
202203
if url is None:
203204
return None
204205

205-
if not os.path.exists(path):
206-
master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
207-
if not os.path.exists(path):
208-
if master_device:
209-
os.makedirs(MEGATRON_CACHE, exist_ok=True)
210-
logging.info(f"Downloading from {url}")
211-
wget.download(url, path)
212-
# wait until the master process downloads the file and writes it to the cache dir
213-
if torch.distributed.is_initialized():
214-
torch.distributed.barrier()
206+
if get_rank.is_global_rank_zero() and not os.path.exists(path):
207+
os.makedirs(MEGATRON_CACHE, exist_ok=True)
208+
logging.info(f"Downloading from {url} to {path}")
209+
downloaded_path = wget.download(url)
210+
shutil.move(downloaded_path, path)
211+
# wait until the master process downloads the file and writes it to the cache dir
212+
if torch.distributed.is_initialized():
213+
torch.distributed.barrier()
215214

216215
return path
217216

0 commit comments

Comments
 (0)