Skip to content

Commit e72956e

Browse files
committed
merge
2 parents ca03dd4 + 9182523 commit e72956e

11 files changed

+1018
-298
lines changed

launch.sh

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@ done
1616

1717
# Set default log directory if not provided
1818
if [ -z "$LOG_PATH" ]; then
19-
LOG_PATH="$PWD/output/output_$EXP_NAME.log"
19+
LOG_PATH="$PWD/output/"
2020
fi
2121

2222
export HF_TOKEN=""
2323
export HF_HOME="/app/hf_home/"
2424

25-
# export ROCR_VISIBLE_DEVICES="4,5,6,7"
26-
2725
export MIOPEN_CUSTOM_CACHE_DIR="/app/.cache/miopen/"
2826
export JAX_COMPILATION_CACHE_DIR="/app/.cache/jax/"
2927
export JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES="all"
@@ -54,7 +52,7 @@ export NVTE_CK_HOW_V3_BF16_CVT=1 # default
5452
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
5553

5654
export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7
57-
export NCCL_SOCKET_IFNAME=ens51f1np1
55+
export NCCL_SOCKET_IFNAME=enp159s0np0
5856
export NCCL_IB_GID_INDEX=3
5957
export NCCL_PROTO=Simple
6058

@@ -65,15 +63,14 @@ export GPU_MAX_HW_QUEUES=2
6563
export HIP_FORCE_DEV_KERNARG=1
6664
export HSA_NO_SCRATCH_RECLAIM=1
6765
# NCCL flags
68-
export NCCL_DEBUG=INFO #WARN, INFO
66+
export NCCL_DEBUG=WARN #WARN, INFO
6967
# export NCCL_DEBUG_SUBSYS=ALL
70-
# export RCCL_REPLAY_FILE=/shared_nfs/jianhan/slurm_logs-${SCALING_EXP}/cohere-${SLURM_JOB_NUM_NODES}N-8x22B-${SLURM_JOB_ID}-${timestamp}/mixtral_8x-22b_128N_run.bin
7168
export NCCL_PROTO=Simple
7269
export NCCL_IB_TIMEOUT=20
7370
export NCCL_IB_TC=41
7471
export NCCL_IB_SL=0
7572

76-
export GLOO_SOCKET_IFNAME=ens51f1np1
73+
export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME}
7774
export NCCL_CROSS_NIC=0
7875
export NCCL_CHECKS_DISABLE=1
7976
export NCCL_IB_QPS_PER_CONNECTION=1
@@ -100,22 +97,23 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enabl
10097

10198
rm -rf /app/.cache/*
10299
python3 setup.py develop
100+
ulimit -n 4096
103101

104-
EXP_NAME="WAN_train"
102+
EXP_NAME="train"
105103
LOG_FILE="$LOG_PATH/output_$HOST_NAME.log"
106104

105+
107106
# python -m src.maxdiffusion.train_flux src/maxdiffusion/configs/base_flux_dev.yml \
108107
python -m src.maxdiffusion.train_wan src/maxdiffusion/configs/base_wan_14b.yml \
109-
run_name="run_$EXP_NAME" output_dir="$PWD/output" \
108+
run_name="run_$EXP_NAME" output_dir="$LOG_PATH" \
110109
hardware=gpu \
111110
attention=cudnn_flash_te \
112-
max_train_steps=10 \
113-
dcn_data_parallelism=-1 \
114-
dcn_fsdp_batch_parallelism=1 \
111+
max_train_steps=20 \
112+
dcn_data_parallelism=1 \
113+
dcn_fsdp_parallelism=-1 \
115114
ici_data_parallelism=1 \
116115
ici_fsdp_parallelism=8 \
117116
per_device_batch_size=1 \
118-
enable_ssim=False \
119117
"${FILTERED_ARGS[@]}" |& tee -a "$LOG_FILE"
120118

121119

multi_node/README.md

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Multi-Node WAN Training Guide
2+
3+
Distributed WAN model training across multiple nodes with AMD ROCm GPUs.
4+
5+
## Quick Start
6+
7+
### Option 1: Using Helper Script (Recommended)
8+
9+
```bash
10+
cd /home/amd/jianhan/github/maxdiffusion/multi_node
11+
12+
# Edit run_multinode_train.sh to set configuration and enable/disable steps
13+
bash run_multinode_train.sh
14+
```
15+
16+
### Option 2: Manual Execution
17+
18+
```bash
19+
# Set ALL required environment variables (no defaults)
20+
export COORDINATOR_IP="172.29.0.73"
21+
export IMAGE_TAG="maxdiffusion-multinode-train:v1"
22+
export MULTI_NODES_LOG_DIR="/home/amd/jianhan/multi_node_log"
23+
export SHARE_DOCKERFILE_PATH="/home/amd/jianhan/github/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile"
24+
export SHARED_CODE_BASE_PATH="/home/amd/jianhan/github/maxdiffusion"
25+
export MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion"
26+
export RUN_NAME="WAN_14B_FSDP8"
27+
export REMOVE_IMAGES="n"
28+
export CHMOD_RUN="n"
29+
export REGISTRY_USERNAME="rocmshared"
30+
export REGISTRY_TOKEN="your_token"
31+
32+
# Run commands
33+
bash wan_multinode_train.sh "node1,node2,node3,node4" clean
34+
bash wan_multinode_train.sh "node1,node2,node3,node4" build
35+
bash wan_multinode_train.sh "node1,node2,node3,node4" launch
36+
37+
# Monitor training
38+
tail -f ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_*/node_*_rank_0.log
39+
```
40+
41+
## Prerequisites
42+
43+
- **Password-less SSH**: Set up SSH keys for all nodes
44+
```bash
45+
ssh-keygen -t ed25519 -C "multinode-training"
46+
for node in node1 node2 node3; do ssh-copy-id $node; done
47+
```
48+
- **Docker 20.10+** on all nodes
49+
- **AMD ROCm 5.7+** with MI250/MI300 GPUs
50+
- **Port 12345 open** between nodes (JAX coordinator)
51+
- **50GB+ disk space** per node
52+
53+
## Environment Variables
54+
55+
**All variables are required (no defaults):**
56+
57+
| Variable | Description | Example |
58+
|----------|-------------|---------|
59+
| `COORDINATOR_IP` | JAX coordinator IP | `172.29.0.73` |
60+
| `IMAGE_TAG` | Docker image name | `maxdiffusion-multinode-train:v1` |
61+
| `MULTI_NODES_LOG_DIR` | Base log directory | `/home/amd/jianhan/multi_node_log` |
62+
| `SHARE_DOCKERFILE_PATH` | Path to Dockerfile | `/home/amd/.../jax_maxdiffusion_wan2.1...Dockerfile` |
63+
| `SHARED_CODE_BASE_PATH` | Codebase path | `/home/amd/jianhan/github/maxdiffusion` |
64+
| `MAXDIFFUSION_DIR_IN_DOCKER` | Docker mount path | `/app/maxdiffusion` |
65+
| `RUN_NAME` | Experiment name | `WAN_14B_FSDP8` or `WAN_1_3B_FSDP8` |
66+
| `REMOVE_IMAGES` | Remove images on clean | `y` or `n` |
67+
| `CHMOD_RUN` | Only fix permissions (skip training) | `y` or `n` (use `n` for a normal training run) |
68+
| `REGISTRY_USERNAME` | Docker Hub username | `rocmshared` |
69+
| `REGISTRY_TOKEN` | Docker Hub token | Your token |
70+
71+
## Scripts Overview
72+
73+
- **`run_multinode_train.sh`**: Helper script with pre-configured variables. Edit to set config and enable/disable steps
74+
- **`wan_multinode_train.sh`**: Main wrapper for clean/build/launch operations (requires all env vars)
75+
- **`wan_multinode_train_clean.sh`**: Cleans containers and syncs codebase via rsync
76+
- **`wan_multinode_train_build_docker.sh`**: Builds Docker images in parallel (5 retries)
77+
- **`wan_multinode_train_launch.sh`**: Launches distributed training with JAX
78+
79+
## Directory Structure
80+
81+
```
82+
multi_node_log/
83+
├── slurm_logs/
84+
│ ├── CLEAN_*N_*/ # Cleanup logs
85+
│ ├── BUILD_DOCKER_*N_*/ # Build logs
86+
│ └── ${RUN_NAME}_*N_*/ # Training logs (e.g., WAN_14B_FSDP8_4N_20260204-141300)
87+
│ ├── node_*_rank_0.log # Primary logs
88+
│ └── host_output.{out,err}
89+
└── output/
90+
└── ${RUN_NAME}_*N_*/ # Checkpoints
91+
```
92+
93+
## Typical Workflow
94+
95+
```bash
96+
# First time: Run all steps
97+
bash run_multinode_train.sh
98+
99+
# Code changes: Skip build (edit run_multinode_train.sh, comment out build line)
100+
# Dockerfile changes: Run build only (comment out clean and launch)
101+
# Quick iteration: Run clean + launch only (comment out build)
102+
```
103+
104+
## Common Commands
105+
106+
```bash
107+
# Change model
108+
export RUN_NAME="WAN_1_3B_FSDP8" # or WAN_14B_FSDP8
109+
110+
# Remove Docker images (free disk space)
111+
export REMOVE_IMAGES="y"
112+
113+
# Fix permissions only (no training) - useful for permission issues
114+
export CHMOD_RUN="y"
115+
bash wan_multinode_train.sh "node1,node2,node3,node4" launch
116+
117+
# Monitor latest run
118+
LATEST=$(ls -td ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_* | head -1)
119+
tail -f ${LATEST}/node_*_rank_0.log
120+
121+
# Average step time (exclude warmup)
122+
grep "seconds:" ${LATEST}/node_*_rank_0.log | tail -n +2 | \
123+
awk -F'seconds: ' '{print $2}' | awk '{sum+=$1; count++} END {printf "Avg: %.2fs\n", sum/count}'
124+
125+
# Check GPU utilization
126+
for node in core42-5-a08u01 core42-1-a08u07 core42-3-a08u19 core42-4-a08u25; do
127+
ssh $node "rocm-smi --showuse"
128+
done
129+
130+
# Check containers
131+
for node in core42-5-a08u01 core42-1-a08u07; do
132+
ssh $node "docker ps"
133+
done
134+
```
135+
136+
## Performance (WAN 14B, 4 nodes × 8 GPUs)
137+
138+
- **Batch size/device**: 1
139+
- **Resolution**: 1280×720 × 85 frames
140+
- **Speed**: ~82-83s/step (after warmup)
141+
- **Throughput**: ~255 TFLOP/s/device
142+
- **FPS/device**: ~1.03
143+
- **First step**: ~300s (JIT compilation)
144+
145+
**Single Node**: For testing, pass an empty node list (defaults to single node): `bash wan_multinode_train.sh "" launch`
146+
Or specify: `bash wan_multinode_train.sh "core42-4-a08u25" launch`
147+
148+
## Troubleshooting
149+
150+
```bash
151+
# SSH issues
152+
ssh -vvv node1 # Test connectivity
153+
eval "$(ssh-agent -s)" && ssh-add ~/.ssh/id_ed25519
154+
155+
# Docker issues
156+
ssh node1 "docker ps" # Check Docker
157+
ssh node1 "sudo usermod -aG docker $USER" # Add to docker group
158+
159+
# JAX timeout
160+
ssh node1 "hostname -I" # Get coordinator IP
161+
export COORDINATOR_IP="172.29.0.XX"
162+
ssh node2 "nc -zv $COORDINATOR_IP 12345" # Test port
163+
164+
# GPU not visible
165+
ssh node1 "rocm-smi" # Check GPUs
166+
ssh node1 "docker run --rm --privileged -e HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ${IMAGE_TAG} rocm-smi"
167+
168+
# Build failures
169+
cat ${MULTI_NODES_LOG_DIR}/slurm_logs/BUILD_DOCKER_*/build_*.log
170+
ssh node1 "docker system prune -af" # Clean cache
171+
172+
# Permission issues (codebase not writable)
173+
export CHMOD_RUN="y"
174+
bash wan_multinode_train.sh "node1,node2" launch # Fix perms only
175+
176+
# Debug mode
177+
bash -x wan_multinode_train.sh "node1,node2" clean # Verbose
178+
```
179+
180+
## Log Analysis
181+
182+
```bash
183+
# Find latest run
184+
LOG_DIR=$(ls -td ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_* | head -1)
185+
186+
# View metrics
187+
grep "seconds:\|loss:\|TFLOP/s" ${LOG_DIR}/node_*_rank_0.log
188+
189+
# Calculate stats
190+
grep "seconds:" ${LOG_DIR}/node_*_rank_0.log | tail -n +2 | \
191+
awk -F'seconds: ' '{print $2}' | awk '{sum+=$1; count++} END {printf "Mean: %.2fs, Total: %d steps\n", sum/count, count}'
192+
```
193+
194+
## Resources
195+
196+
- [JAX Distributed](https://jax.readthedocs.io/en/latest/multi_process.html)
197+
- [AMD ROCm](https://rocmdocs.amd.com/)
198+
- [MaxDiffusion](https://github.com/google/maxdiffusion)

multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile

100755100644
Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
#
2626
#################################################################################
2727

28-
ARG BASE_DOCKER=rocm/pyt-megatron-lm-jax-nightly-private:jax_rocm7.1_jax_0.7.1_20251215
29-
# ARG BASE_DOCKER=rocm/jax-training:maxtext-v25.11
28+
ARG BASE_DOCKER=rocm/jax-training:maxtext-v25.11
3029
FROM $BASE_DOCKER
3130
USER root
3231
ENV WORKSPACE_DIR=/workspace
@@ -65,44 +64,10 @@ RUN pip install \
6564
typeguard==2.13.3 \
6665
qwix==0.1.5 --no-deps
6766

68-
#Download MaxDiffusion
69-
# RUN cd ${WORKSPACE_DIR} && \
70-
# git clone https://github.com/AI-Hypercomputer/maxdiffusion.git && \
71-
# cd maxdiffusion && \
72-
# git reset --hard "07b4d29c4a9bbdaafa501299275dcb15b5365034" && \
73-
# python3 setup.py develop
74-
# RUN cd ${WORKSPACE_DIR} && \
75-
# git clone https://github.com/cpersson-amd/maxdiffusion.git && \
76-
# cd maxdiffusion && \
77-
# git reset --hard "07b4d29c4a9bbdaafa501299275dcb15b5365034" && \
78-
# python3 setup.py develop
79-
8067
# Display installed packages for verification
8168
RUN pip list
8269

83-
# libaries for IB fabric
84-
RUN apt-get update
85-
RUN apt-get install -y libelf-dev unzip
86-
RUN apt-get install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev
87-
88-
WORKDIR $WORKSPACE_DIR/
89-
90-
# The drivers should upgrade with each release and match the host version
91-
RUN wget https://docs.broadcom.com/docs-and-downloads/ethernet-network-adapters/NXE/Thor2/GCA1/bcm5760x_230.2.52.0a.zip
92-
RUN unzip bcm5760x_230.2.52.0a.zip
93-
RUN cd bcm5760x_230.2.52.0a/drivers_linux/bnxt_rocelib/ && \
94-
results=$(find -name "libbnxt*.tar.gz") && tar -xf $results && \
95-
untar_dir=$(find . -maxdepth 1 -type d -name "libbnxt*" ! -name "*.tar.gz" | head -n 1) && \
96-
cd $untar_dir && sh autogen.sh && ./configure && make && \
97-
find /usr/lib64/ /usr/lib -name "libbnxt_re-rdmav*.so" -exec mv {} {}.inbox \; && \
98-
make install all && sudo sh -c "echo /usr/local/lib >> /etc/ld.so.conf" && \
99-
sudo ldconfig && \
100-
cp -f bnxt_re.driver /etc/libibverbs.d/ && \
101-
find . -name "*.so" -exec md5sum {} \; && \
102-
BUILT_MD5SUM=$(find . -name "libbnxt_re-rdmav*.so" -exec md5sum {} \; | cut -d " " -f 1) && \
103-
echo -e "\n\nmd5sum of the built libbnxt_re is $BUILT_MD5SUM"
10470

105-
RUN ibv_devices
10671

10772

10873

multi_node/run_multinode_train.sh

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/bin/bash
2+
3+
# Required environment variables for wan_multinode_train.sh
4+
5+
# core42-4-a08u25:172.29.0.73
6+
export COORDINATOR_IP=172.29.0.73
7+
export IMAGE_TAG=your-name-wan-multinode-train:v1
8+
# Keep MULTI_NODES_LOG_DIR outside SHARED_CODE_BASE_PATH, because the whole SHARED_CODE_BASE_PATH directory is synced during the clean step.
9+
export MULTI_NODES_LOG_DIR=/home/amd/your_dir/multi_node_log
10+
export SHARE_DOCKERFILE_PATH=/home/amd/your_dir/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile
11+
export SHARED_CODE_BASE_PATH=/home/amd/your_dir/maxdiffusion
12+
export MAXDIFFUSION_DIR_IN_DOCKER=/app/maxdiffusion
13+
export RUN_NAME=WAN_14B_FSDP8
14+
export REMOVE_IMAGES=n
15+
export REGISTRY_USERNAME=""
16+
export REGISTRY_TOKEN=""
17+
export CHMOD_RUN=n
18+
19+
# Define node list
20+
# Put the JAX coordinator node first in the list. The coordinator node is launched before the others so that every node can connect to the JAX coordinator service.
21+
# core42-4-a08u25:172.29.0.73
22+
NODES="core42-4-a08u25,core42-1-a08u07,core42-3-a08u19,core42-5-a08u01"
23+
24+
# 1. Clean and sync codebase
25+
# To also remove Docker images during cleanup, set REMOVE_IMAGES=y above.
26+
bash wan_multinode_train.sh "$NODES" clean
27+
28+
# 2. Build Docker images (only needed when the Dockerfile changes)
29+
bash wan_multinode_train.sh "$NODES" build
30+
31+
# 3. Launch training
32+
bash wan_multinode_train.sh "$NODES" launch

multi_node/wan_multinode_train.sbatch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ echo "Building the container image on all nodes"
104104
srun bash -c '
105105
MAX_RETRIES=5
106106
INITIAL_DELAY=30 # seconds
107-
MAX_DELAY=1800 # seconds
107+
MAX_DELAY=180 # seconds
108108
RETRY_COUNT=0
109109
110110
while true; do

0 commit comments

Comments
 (0)