|
| 1 | +# Multi-Node WAN Training Guide |
| 2 | + |
| 3 | +Distributed WAN model training across multiple nodes with AMD ROCm GPUs. |
| 4 | + |
| 5 | +## Quick Start |
| 6 | + |
| 7 | +### Option 1: Using Helper Script (Recommended) |
| 8 | + |
| 9 | +```bash |
| 10 | +cd /home/amd/jianhan/github/maxdiffusion/multi_node |
| 11 | + |
| 12 | +# Edit run_multinode_train.sh to set configuration and enable/disable steps |
| 13 | +bash run_multinode_train.sh |
| 14 | +``` |
| 15 | + |
| 16 | +### Option 2: Manual Execution |
| 17 | + |
| 18 | +```bash |
| 19 | +# Set ALL required environment variables (no defaults) |
| 20 | +export COORDINATOR_IP="172.29.0.73" |
| 21 | +export IMAGE_TAG="maxdiffusion-multinode-train:v1" |
| 22 | +export MULTI_NODES_LOG_DIR="/home/amd/jianhan/multi_node_log" |
| 23 | +export SHARE_DOCKERFILE_PATH="/home/amd/jianhan/github/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile" |
| 24 | +export SHARED_CODE_BASE_PATH="/home/amd/jianhan/github/maxdiffusion" |
| 25 | +export MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion" |
| 26 | +export RUN_NAME="WAN_14B_FSDP8" |
| 27 | +export REMOVE_IMAGES="n" |
| 28 | +export CHMOD_RUN="n" |
| 29 | +export REGISTRY_USERNAME="rocmshared" |
| 30 | +export REGISTRY_TOKEN="your_token" |
| 31 | + |
| 32 | +# Run commands |
| 33 | +bash wan_multinode_train.sh "node1,node2,node3,node4" clean |
| 34 | +bash wan_multinode_train.sh "node1,node2,node3,node4" build |
| 35 | +bash wan_multinode_train.sh "node1,node2,node3,node4" launch |
| 36 | + |
| 37 | +# Monitor training |
| 38 | +tail -f ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_*/node_*_rank_0.log |
| 39 | +``` |
| 40 | + |
| 41 | +## Prerequisites |
| 42 | + |
| 43 | +- **Password-less SSH**: Set up SSH keys for all nodes |
| 44 | + ```bash |
| 45 | + ssh-keygen -t ed25519 -C "multinode-training" |
| 46 | + for node in node1 node2 node3; do ssh-copy-id $node; done |
| 47 | + ``` |
| 48 | +- **Docker 20.10+** on all nodes |
| 49 | +- **AMD ROCm 5.7+** with MI250/MI300 GPUs |
| 50 | +- **Port 12345 open** between nodes (JAX coordinator) |
| 51 | +- **50GB+ disk space** per node |
| 52 | + |
| 53 | +## Environment Variables |
| 54 | + |
| 55 | +**All variables are required (no defaults):** |
| 56 | + |
| 57 | +| Variable | Description | Example | |
| 58 | +|----------|-------------|---------| |
| 59 | +| `COORDINATOR_IP` | JAX coordinator IP | `172.29.0.73` | |
| 60 | +| `IMAGE_TAG` | Docker image name | `maxdiffusion-multinode-train:v1` | |
| 61 | +| `MULTI_NODES_LOG_DIR` | Base log directory | `/home/amd/jianhan/multi_node_log` | |
| 62 | +| `SHARE_DOCKERFILE_PATH` | Path to Dockerfile | `/home/amd/.../jax_maxdiffusion_wan2.1...Dockerfile` | |
| 63 | +| `SHARED_CODE_BASE_PATH` | Codebase path | `/home/amd/jianhan/github/maxdiffusion` | |
| 64 | +| `MAXDIFFUSION_DIR_IN_DOCKER` | Docker mount path | `/app/maxdiffusion` | |
| 65 | +| `RUN_NAME` | Experiment name | `WAN_14B_FSDP8` or `WAN_1_3B_FSDP8` | |
| 66 | +| `REMOVE_IMAGES` | Remove images on clean | `y` or `n` | |
| 67 | +| `CHMOD_RUN` | Only fix permissions (skip training) | `y` or `n` (default: `n`) | |
| 68 | +| `REGISTRY_USERNAME` | Docker Hub username | `rocmshared` | |
| 69 | +| `REGISTRY_TOKEN` | Docker Hub token | Your token | |
| 70 | + |
| 71 | +## Scripts Overview |
| 72 | + |
| 73 | +- **`run_multinode_train.sh`**: Helper script with pre-configured variables. Edit to set config and enable/disable steps |
| 74 | +- **`wan_multinode_train.sh`**: Main wrapper for clean/build/launch operations (requires all env vars) |
| 75 | +- **`wan_multinode_train_clean.sh`**: Cleans containers and syncs codebase via rsync |
| 76 | +- **`wan_multinode_train_build_docker.sh`**: Builds Docker images in parallel (5 retries) |
| 77 | +- **`wan_multinode_train_launch.sh`**: Launches distributed training with JAX |
| 78 | + |
| 79 | +## Directory Structure |
| 80 | + |
| 81 | +``` |
| 82 | +multi_node_log/ |
| 83 | +├── slurm_logs/ |
| 84 | +│ ├── CLEAN_*N_*/ # Cleanup logs |
| 85 | +│ ├── BUILD_DOCKER_*N_*/ # Build logs |
| 86 | +│ └── ${RUN_NAME}_*N_*/ # Training logs (e.g., WAN_14B_FSDP8_4N_20260204-141300) |
| 87 | +│ ├── node_*_rank_0.log # Primary logs |
| 88 | +│ └── host_output.{out,err} |
| 89 | +└── output/ |
| 90 | + └── ${RUN_NAME}_*N_*/ # Checkpoints |
| 91 | +``` |
| 92 | + |
| 93 | +## Typical Workflow |
| 94 | + |
| 95 | +```bash |
| 96 | +# First time: Run all steps |
| 97 | +bash run_multinode_train.sh |
| 98 | + |
| 99 | +# Code changes: Skip build (edit run_multinode_train.sh, comment out build line) |
| 100 | +# Dockerfile changes: Run build only (comment out clean and launch) |
| 101 | +# Quick iteration: Run clean + launch only (comment out build) |
| 102 | +``` |
| 103 | + |
| 104 | +## Common Commands |
| 105 | + |
| 106 | +```bash |
| 107 | +# Change model |
| 108 | +export RUN_NAME="WAN_1_3B_FSDP8" # or WAN_14B_FSDP8 |
| 109 | + |
| 110 | +# Remove Docker images (free disk space) |
| 111 | +export REMOVE_IMAGES="y" |
| 112 | + |
| 113 | +# Fix permissions only (no training) - useful for permission issues |
| 114 | +export CHMOD_RUN="y" |
| 115 | +bash wan_multinode_train.sh "node1,node2,node3,node4" launch |
| 116 | + |
| 117 | +# Monitor latest run |
| 118 | +LATEST=$(ls -td ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_* | head -1) |
| 119 | +tail -f ${LATEST}/node_*_rank_0.log |
| 120 | + |
| 121 | +# Average step time (exclude warmup) |
| 122 | +grep "seconds:" ${LATEST}/node_*_rank_0.log | tail -n +2 | \ |
| 123 | + awk -F'seconds: ' '{print $2}' | awk '{sum+=$1; count++} END {printf "Avg: %.2fs\n", sum/count}' |
| 124 | + |
| 125 | +# Check GPU utilization |
| 126 | +for node in core42-5-a08u01 core42-1-a08u07 core42-3-a08u19 core42-4-a08u25; do |
| 127 | + ssh $node "rocm-smi --showuse" |
| 128 | +done |
| 129 | + |
| 130 | +# Check containers |
| 131 | +for node in core42-5-a08u01 core42-1-a08u07; do |
| 132 | + ssh $node "docker ps" |
| 133 | +done |
| 134 | +``` |
| 135 | + |
| 136 | +## Performance (WAN 14B, 4 nodes × 8 GPUs) |
| 137 | + |
| 138 | +- **Batch size/device**: 1 |
| 139 | +- **Resolution**: 1280×720 × 85 frames |
| 140 | +- **Speed**: ~82-83s/step (after warmup) |
| 141 | +- **Throughput**: ~255 TFLOP/s/device |
| 142 | +- **FPS/device**: ~1.03 |
| 143 | +- **First step**: ~300s (JIT compilation) |
| 144 | + |
| 145 | +**Single Node**: For testing, omit node list (defaults to single node): `bash wan_multinode_train.sh "" launch` |
| 146 | +Or specify: `bash wan_multinode_train.sh "core42-4-a08u25" launch` |
| 147 | + |
| 148 | +## Troubleshooting |
| 149 | + |
| 150 | +```bash |
| 151 | +# SSH issues |
| 152 | +ssh -vvv node1 # Test connectivity |
| 153 | +eval "$(ssh-agent -s)" && ssh-add ~/.ssh/id_ed25519 |
| 154 | + |
| 155 | +# Docker issues |
| 156 | +ssh node1 "docker ps" # Check Docker |
| 157 | +ssh node1 "sudo usermod -aG docker $USER" # Add to docker group |
| 158 | + |
| 159 | +# JAX timeout |
| 160 | +ssh node1 "hostname -I" # Get coordinator IP |
| 161 | +export COORDINATOR_IP="172.29.0.XX" |
| 162 | +ssh node2 "nc -zv $COORDINATOR_IP 12345" # Test port |
| 163 | + |
| 164 | +# GPU not visible |
| 165 | +ssh node1 "rocm-smi" # Check GPUs |
| 166 | +ssh node1 "docker run --rm --privileged -e HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ${IMAGE_TAG} rocm-smi" |
| 167 | + |
| 168 | +# Build failures |
| 169 | +cat ${MULTI_NODES_LOG_DIR}/slurm_logs/BUILD_DOCKER_*/build_*.log |
| 170 | +ssh node1 "docker system prune -af" # Clean cache |
| 171 | + |
| 172 | +# Permission issues (codebase not writable) |
| 173 | +export CHMOD_RUN="y" |
| 174 | +bash wan_multinode_train.sh "node1,node2" launch # Fix perms only |
| 175 | + |
| 176 | +# Debug mode |
| 177 | +bash -x wan_multinode_train.sh "node1,node2" clean # Verbose |
| 178 | +``` |
| 179 | + |
| 180 | +## Log Analysis |
| 181 | + |
| 182 | +```bash |
| 183 | +# Find latest run |
| 184 | +LOG_DIR=$(ls -td ${MULTI_NODES_LOG_DIR}/slurm_logs/${RUN_NAME}_* | head -1) |
| 185 | + |
| 186 | +# View metrics |
| 187 | +grep "seconds:\|loss:\|TFLOP/s" ${LOG_DIR}/node_*_rank_0.log |
| 188 | + |
| 189 | +# Calculate stats |
| 190 | +grep "seconds:" ${LOG_DIR}/node_*_rank_0.log | tail -n +2 | \ |
| 191 | + awk -F'seconds: ' '{print $2}' | awk '{sum+=$1; count++} END {printf "Mean: %.2fs, Total: %d steps\n", sum/count, count}' |
| 192 | +``` |
| 193 | + |
| 194 | +## Resources |
| 195 | + |
| 196 | +- [JAX Distributed](https://jax.readthedocs.io/en/latest/multi_process.html) |
| 197 | +- [AMD ROCm](https://rocmdocs.amd.com/) |
| 198 | +- [MaxDiffusion](https://github.com/google/maxdiffusion) |
0 commit comments