Skip to content

feat(megatron-bridge): add pretrain trainer, hooks, and configs #3080

feat(megatron-bridge): add pretrain trainer, hooks, and configs

feat(megatron-bridge): add pretrain trainer, hooks, and configs #3080

Workflow file for this run

# Workflow identity, triggers, and the pinned revisions / base images shared by every job.
name: Primus-CI-TAS

# Runs on manual dispatch, on pushes to `main` or to `v*` tags, and on every pull request.
on:
  workflow_dispatch:
  push:
    branches:
      - main
    tags:
      - "v*"
  pull_request:

# Workflow-level environment: visible to every step of every job below.
env:
  # fix sm_scale none bug (#263)
  PRIMUS_TURBO_COMMIT: "333b68d7c81b722b21b4aad10cd250c45f15027c"
  PRIMUS_TURBO_AITER_COMMIT: "e83f9903c07001a0ec29e85d223f6e6cdbe00859"
  # Update version to 3.2.0 for 7.2.0 rocm release (#351) (#355)
  ROCSHMEM_COMMIT: "17ff985c026f9f97f85068647e863ab541dd5645"
  # [EP]: fix fp8 error of internode_ll on amd gfx950 arch. (#710)
  UCCL_COMMIT: "5afb4117893c58cc0c8557d9286336141a301053"
  # Base images passed to `docker build` as BASE_IMAGE.
  BASE_IMAGE: "docker.io/rocm/primus:v26.2"
  MAXTEXT_BASE_IMAGE: "docker.io/rocm/jax-training:maxtext-v26.1"

jobs:
# Python lint gate (autoflake, isort, black) on a GitHub-hosted runner.
# The build and unit-test jobs in this workflow list it under `needs`.
code-lint:
  runs-on: ubuntu-latest
  strategy:
    matrix:
      python-version:
        - "3.12"
  steps:
    - run: echo "🎉 Begin Primus Python Lint."
    - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!"
    - uses: actions/checkout@v4
    - run: git config --global --add safe.directory /github/workspace
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Display Python version
      run: python -c "import sys; print(sys.version)"
    - name: Install Autoflake
      run: pip install autoflake==2.3.1
    # The three lint steps below use `if: always()` so each one still runs
    # and reports when an earlier step has failed.
    - name: Run Autoflake
      if: always()
      run: |
        # Any text printed by autoflake is treated as a finding and fails the step.
        findings=$(autoflake . --remove-all-unused-imports --remove-unused-variables --expand-star-imports --ignore-init-module-imports --recursive)
        if [[ -z "$findings" ]]; then
          echo "Autoflake check success."
        else
          echo "Autoflake check failed: $findings"
          exit 1
        fi
    - uses: isort/isort-action@v1
      if: always()
      with:
        isort-version: "5.13.2"
        configuration: "--profile black --check-only --diff"
    - uses: psf/black@24.8.0
      if: always()
      with:
        options: "--check --diff --color --verbose --line-length=110"
# Builds the Primus CI images (PyTorch, v25.09-ainic, JAX, JAX-ainic) and pushes
# them to docker.io/tasimage/primus under a tag derived from the triggering event.
build-docker:
  needs: [code-lint]
  runs-on: build-docker
  # NOTE(review): no step in this job reads matrix.python-version; the matrix is
  # kept so the job's matrix-derived name does not change — confirm before removing.
  strategy:
    matrix:
      python-version: ["3.12"]
  steps:
    - run: echo "🎉 Begin Build Primus Docker Image."
    - uses: actions/checkout@v4
      with:
        submodules: recursive
    - name: Show Environment Info
      run: |
        echo "Hostname: $(hostname)"
        echo "PWD: $(pwd)"
        echo "HOME: $HOME"
        echo "GITHUB_WORKSPACE: $GITHUB_WORKSPACE"
        echo "Runner Temp Dir: $RUNNER_TEMP"
        echo "Runner Tool Cache: $RUNNER_TOOL_CACHE"
    - name: Parse Commit Info
      env:
        # Empty for non-pull_request events.
        PR_NUMBER: ${{ github.event.pull_request.number }}
      run: |
        # Derive IMAGE_TAG from the triggering event:
        #   pull_request          -> pr-<number>
        #   push to main          -> latest
        #   push/release of a tag -> the tag name (refs/tags/ prefix stripped)
        #   anything else         -> others
        # This workflow has no `release` trigger, so `v*` tags arrive as `push`
        # events; matching only on `release` would tag those builds "others".
        # The runner-provided GITHUB_* variables are read instead of interpolating
        # expressions into the script text.
        if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then
          image_tag="pr-${PR_NUMBER}"
        elif [[ "${GITHUB_EVENT_NAME}" == "push" && "${GITHUB_REF}" == "refs/heads/main" ]]; then
          image_tag="latest"
        elif [[ ( "${GITHUB_EVENT_NAME}" == "push" || "${GITHUB_EVENT_NAME}" == "release" ) && "${GITHUB_REF}" == refs/tags/* ]]; then
          image_tag="${GITHUB_REF#refs/tags/}"
        else
          image_tag="others"
        fi
        echo "IMAGE_TAG=${image_tag}" >> "$GITHUB_ENV"
    - name: Build and Push Docker Image
      env:
        # Tokens are handed to the script as environment variables so they are
        # not expanded into the script text or a command line.
        ROCM_DOCKER_HUB_TOKEN: ${{ secrets.ROCM_DOCKER_HUB_TOKEN }}
        PRIMUS_DOCKER_HUB_TOKEN: ${{ secrets.PRIMUS_DOCKER_HUB_TOKEN }}
      run: |
        docker_dir="${GITHUB_WORKSPACE}/.github/workflows/docker"
        ainic_bundle="/apps/tas/0_public/primus_docker_ci/ainic/ainic_bundle_1.117.5-a-56.tar.gz"

        # Log in to Docker Hub, feeding the token through stdin.
        # $1 = Docker Hub user, $2 = access token.
        hub_login() {
          printf '%s' "$2" | docker login -u "$1" --password-stdin
        }
        login_rocm() { hub_login rocmshared "${ROCM_DOCKER_HUB_TOKEN}"; }
        login_tasimage() { hub_login tasimage "${PRIMUS_DOCKER_HUB_TOKEN}"; }

        # Run `docker build --network=host` and print the elapsed wall time.
        # $1 = label for the timing message; remaining args go to docker build.
        timed_build() {
          local label="$1"
          shift
          local start_time end_time elapsed
          start_time=$(date +%s)
          docker build --network=host "$@"
          end_time=$(date +%s)
          elapsed=$((end_time - start_time))
          echo "⏱️ [build ${label}] Total elapsed time: ${elapsed} seconds"
        }

        # Tag a locally built image for Docker Hub, push it as `tasimage`, then
        # switch back to the `rocmshared` login for the next build.
        # $1 = image tag.
        publish_image() {
          local tag="$1"
          docker tag "tasimage/primus:${tag}" "docker.io/tasimage/primus:${tag}"
          login_tasimage
          docker push "docker.io/tasimage/primus:${tag}"
          login_rocm
        }

        # Copy the AINIC bundle from the runner's shared storage into the build context.
        stage_ainic_bundle() {
          mkdir -p "${docker_dir}/ainic"
          cp "${ainic_bundle}" "${docker_dir}/ainic/" || { echo "Error: Failed to copy ainic bundle"; exit 1; }
        }

        echo "> Login to ROCm Docker Hub"
        login_rocm

        echo "> Build Docker Image with tag: ${IMAGE_TAG}"
        echo "> Build dependencies"
        timed_build "primus docker" \
          -f "${docker_dir}/Dockerfile" \
          -t "tasimage/primus:${IMAGE_TAG}" \
          --build-arg "BASE_IMAGE=${BASE_IMAGE}" \
          --build-arg "PRIMUS_TURBO_COMMIT=${PRIMUS_TURBO_COMMIT}" \
          --build-arg "PRIMUS_TURBO_AITER_COMMIT=${PRIMUS_TURBO_AITER_COMMIT}" \
          --build-arg "ROCSHMEM_COMMIT=${ROCSHMEM_COMMIT}" \
          --build-arg PRIMUS_TURBO_FRAMEWORK=PYTORCH \
          --build-arg "UCCL_COMMIT=${UCCL_COMMIT}" \
          "${docker_dir}"
        publish_image "${IMAGE_TAG}"

        # Primus v26.2 already includes AINIC under /workspace. Re-enable this
        # Dockerfile.ainic build only when we need to refresh the tasimage -ainic image.
        # echo "> Build Docker Image with tag: ${IMAGE_TAG}-ainic"
        # stage_ainic_bundle
        # timed_build "primus docker-ainic" \
        #   -f "${docker_dir}/Dockerfile.ainic" \
        #   -t "tasimage/primus:${IMAGE_TAG}-ainic" \
        #   --build-arg "BASE_IMAGE=docker.io/tasimage/primus:${IMAGE_TAG}" \
        #   --build-arg AINIC_BUNDLE_PATH=ainic \
        #   "${docker_dir}"
        # publish_image "${IMAGE_TAG}-ainic"

        echo "> Build Docker Image with tag: ${IMAGE_TAG}-v25.09-ainic"
        stage_ainic_bundle
        timed_build "primus docker-v25.09-ainic" \
          -f "${docker_dir}/Dockerfile_v25.09_ainic" \
          -t "tasimage/primus:${IMAGE_TAG}-v25.09-ainic" \
          --build-arg AINIC_BUNDLE_PATH=ainic \
          --build-arg "PRIMUS_TURBO_COMMIT=${PRIMUS_TURBO_COMMIT}" \
          "${docker_dir}"
        publish_image "${IMAGE_TAG}-v25.09-ainic"

        echo "> Build Docker Image with tag: ${IMAGE_TAG}-jax"
        # NOTE(review): this build uses the current directory as its context while
        # the other builds use ${docker_dir}; kept as-is — confirm whether intended.
        timed_build "primus docker-jax" \
          -f "${docker_dir}/Dockerfile" \
          -t "tasimage/primus:${IMAGE_TAG}-jax" \
          --build-arg "BASE_IMAGE=${MAXTEXT_BASE_IMAGE}" \
          --build-arg "PRIMUS_TURBO_COMMIT=${PRIMUS_TURBO_COMMIT}" \
          --build-arg "PRIMUS_TURBO_AITER_COMMIT=${PRIMUS_TURBO_AITER_COMMIT}" \
          --build-arg PRIMUS_TURBO_FRAMEWORK=JAX \
          --build-arg "UCCL_COMMIT=${UCCL_COMMIT}" \
          --build-arg "ROCSHMEM_COMMIT=${ROCSHMEM_COMMIT}" \
          .
        echo "> Docker tag image for Docker Hub"
        publish_image "${IMAGE_TAG}-jax"

        echo "> Build Docker Image with tag: ${IMAGE_TAG}-jax-ainic"
        stage_ainic_bundle
        timed_build "primus docker-jax-ainic" \
          -f "${docker_dir}/Dockerfile_jax.ainic" \
          -t "tasimage/primus:${IMAGE_TAG}-jax-ainic" \
          --build-arg "BASE_IMAGE=${MAXTEXT_BASE_IMAGE}" \
          --build-arg AINIC_BUNDLE_PATH=ainic \
          "${docker_dir}"
        publish_image "${IMAGE_TAG}-jax-ainic"

        # echo "> Docker cleanup local images"
        # docker rmi tasimage/primus:${IMAGE_TAG}
        # docker rmi tasimage/primus:${IMAGE_TAG}-v25.09-ainic
        # docker rmi tasimage/primus:${IMAGE_TAG}-jax
        # docker rmi tasimage/primus:${IMAGE_TAG}-jax-ainic
        echo "> build-docker success"
# PyTorch unit tests on the self-hosted TAS runner: installs AITER and Primus-Turbo
# at the pinned commits, then runs the CLI shell tests, the core unit tests, and
# the Megatron-LM / TorchTitan trainer tests.
run-unittest-torch:
  env:
    PRIMUS_WORKDIR: /mnt/apps_proxy/tas/0_public/primus_ci/actions-runner-torch
    # PRIMUS_WORKDIR: /wekafs/primus-data/primus_safe_ci/torch
  needs: [code-lint]
  # runs-on: [primus-lm-cicd-torch-j8knc]
  runs-on: [primus-lm-cicd-v26.2-tas8n-a16-40]
  steps:
    - run: echo "🎉 Begin Primus-Turbo Checkout."
    - name: Set commit hash to env
      run: echo "PRIMUS_TURBO_COMMIT=${PRIMUS_TURBO_COMMIT}" >> $GITHUB_ENV
    - name: Checkout Repo Primus-Turbo
      uses: actions/checkout@v4
      with:
        repository: AMD-AIG-AIMA/Primus-Turbo
        submodules: "recursive"
        path: Primus-Turbo
        ref: ${{ env.PRIMUS_TURBO_COMMIT }}
    - run: echo "Begin AITER + Primus-Turbo Install."
    - name: Install AITER
      run: |
        # Replace any installed aiter with a build of the pinned commit.
        echo "✅ [Uninstall old aiter] started at: $(date)"
        pip3 uninstall aiter amd-aiter -y || true
        rm -rf /tmp/aiter || true
        cd /tmp
        git clone --recursive https://github.com/ROCm/aiter.git
        cd aiter
        git checkout ${PRIMUS_TURBO_AITER_COMMIT}
        git submodule update --init --recursive
        start_time=$(date +%s)
        echo "✅ [Build aiter] started at: $(date)"
        PREBUILD_KERNELS=3 pip install --no-cache-dir --use-pep517 .
        end_time=$(date +%s)
        elapsed=$((end_time - start_time))
        echo "✅ [Build aiter] ended at: $(date)"
        echo "⏱️ [Build aiter] Total elapsed time: ${elapsed} seconds"
    - name: Install Primus-Turbo
      run: |
        # Move the checkout out of the workspace before the later Primus checkout step runs.
        rm -rf /tmp/Primus-Turbo || true
        mv Primus-Turbo /tmp/
        echo "Primus-Turbo dir: /tmp/Primus-Turbo"
        git config --global --add safe.directory /tmp/Primus-Turbo || true
        # A failed cd must stop the step; otherwise pip would run against
        # whatever directory the shell happens to be in.
        cd /tmp/Primus-Turbo
        start_time=$(date +%s)
        echo "✅ [Pip install requirements] started at: $(date)"
        mkdir -p ${PRIMUS_WORKDIR}/primus-cache
        MAX_JOBS=128 pip install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --no-build-isolation --no-clean -r requirements.txt
        end_time=$(date +%s)
        elapsed=$((end_time - start_time))
        echo "✅ [Pip install requirements] ended at: $(date)"
        echo "⏱️ [Pip install requirements] Total elapsed time: ${elapsed} seconds"
        start_time=$(date +%s)
        echo "✅ [build primus-turbo] started at: $(date)"
        pip3 install --no-build-isolation -e . -v
        end_time=$(date +%s)
        elapsed=$((end_time - start_time))
        echo "✅ [build primus-turbo] ended at: $(date)"
        echo "⏱️ [build primus-turbo] Total elapsed time: ${elapsed} seconds"
    - run: echo "🎉 Begin Primus Unit Test."
    - uses: actions/checkout@v4
      with:
        submodules: recursive
    - name: Show Environment Info
      run: |
        echo "Hostname: $(hostname)"
        echo "PWD: $(pwd)"
        echo "HOME: $HOME"
        echo "GITHUB_WORKSPACE: $GITHUB_WORKSPACE"
        echo "Runner Temp Dir: $RUNNER_TEMP"
        echo "Runner Tool Cache: $RUNNER_TOOL_CACHE"
    - name: Install Primus
      run: |
        pip install -r requirements.txt
    - name: Set UT_LOG_PATH
      env:
        # Empty for non-pull_request events.
        PR_NUMBER: ${{ github.event.pull_request.number }}
      run: |
        # UT_LOG_PATH = <workdir>/ut_out/<label>-<timestamp>-<short sha>, where
        # <label> is pr-<number>, main, the tag name, or "others".
        # This workflow has no `release` trigger, so `v*` tags arrive as `push`
        # events and are matched by ref instead of by event name alone.
        ts="$(date +%Y%m%d-%H%M%S)"
        commit_id="${GITHUB_SHA::7}"
        if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then
          run_label="pr-${PR_NUMBER}"
        elif [[ "${GITHUB_EVENT_NAME}" == "push" && "${GITHUB_REF}" == "refs/heads/main" ]]; then
          run_label="main"
        elif [[ ( "${GITHUB_EVENT_NAME}" == "push" || "${GITHUB_EVENT_NAME}" == "release" ) && "${GITHUB_REF}" == refs/tags/* ]]; then
          run_label="${GITHUB_REF#refs/tags/}"
        else
          run_label="others"
        fi
        echo "UT_LOG_PATH=${PRIMUS_WORKDIR}/ut_out/${run_label}-${ts}-${commit_id}" >> "$GITHUB_ENV"
    - name: Run CLI Shell Tests
      run: |
        echo "Running Primus CLI shell tests..."
        bash ./tests/runner/run_all_tests.sh
    - name: Run Primus Core Tests
      run: |
        echo "Running Primus Core tests..."
        # Note: The tests `test_fp8_te_linear` and `test_te_linear` are temporarily skipped due to intermittent failures.
        # NOTE(review): the two TestFlexDispatcher tests are also deselected; the reason is not recorded here.
        # Note HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit (TAS-8N Node), rocm ver:70125424
        export HSA_NO_SCRATCH_RECLAIM=1
        pytest --maxfail=1 -s ./tests/unit_tests/ \
          --deselect=tests/unit_tests/megatron/cco/test_tp_overlap.py::TPOverlapTestCase::test_fp8_te_linear \
          --deselect=tests/unit_tests/megatron/cco/test_tp_overlap.py::TPOverlapTestCase::test_te_linear \
          --deselect=tests/unit_tests/megatron/transformer/moe/test_token_dispatcher.py::TestFlexDispatcher::test_forward_backward \
          --deselect=tests/unit_tests/megatron/transformer/moe/test_token_dispatcher.py::TestFlexDispatcher::test_capacity_forward_backward
    - name: Run Primus Model Tests -- Megatron-LM
      env:
        HF_TOKEN: ${{secrets.HF_TOKEN}}
      run: |
        echo "Set UT_LOG_PATH: ${UT_LOG_PATH}"
        # First user of UT_LOG_PATH in this job: start from an empty directory.
        rm -rf "${UT_LOG_PATH}"
        mkdir -p "${UT_LOG_PATH}"
        # MASTER_PORT=10009 DATA_PATH=/wekafs/primus-data \
        MASTER_PORT=10009 DATA_PATH=/mnt/apps_proxy/tas/0_public/data HSA_NO_SCRATCH_RECLAIM=1 \
          pytest --maxfail=1 -s ./tests/trainer/test_megatron_trainer.py
    - name: Run Primus Model Tests -- TorchTitan
      env:
        HF_TOKEN: ${{secrets.HF_TOKEN}}
      run: |
        echo "Set UT_LOG_PATH: ${UT_LOG_PATH}"
        # Do not wipe the directory here: it is the same UT_LOG_PATH the
        # Megatron-LM step above has just used.
        mkdir -p "${UT_LOG_PATH}"
        # MASTER_PORT=10009 DATA_PATH=/wekafs/primus-data \
        MASTER_PORT=10009 DATA_PATH=/mnt/apps_proxy/tas/0_public/data HSA_NO_SCRATCH_RECLAIM=1 \
          pytest --maxfail=1 -s ./tests/trainer/test_torchtitan_trainer.py
    - name: Clean
      # Run the cleanup even when an earlier step failed.
      if: always()
      # NOTE(review): Primus-Turbo was moved to /tmp above, so these paths may not
      # exist on this runner — confirm which directories need cleaning.
      run: |
        rm -rf ${PRIMUS_WORKDIR}/Primus-Turbo
        rm -rf ${PRIMUS_WORKDIR}/Primus
# JAX unit tests on the self-hosted JAX runner: checks out Primus-Turbo at the
# pinned commit (without installing it), then runs the CLI shell tests and the
# JAX unit-test driver.
run-unittest-jax:
  env:
    PRIMUS_WORKDIR: /wekafs/primus-data/primus_safe_ci/jax
  needs: [code-lint]
  runs-on: [primus-lm-cicd-jax-m42vb]
  steps:
    - run: echo "🎉 Begin Primus-Turbo Checkout."
    - name: Set commit hash to env
      run: echo "PRIMUS_TURBO_COMMIT=${PRIMUS_TURBO_COMMIT}" >> $GITHUB_ENV
    - name: Checkout Repo Primus-Turbo
      uses: actions/checkout@v4
      with:
        repository: AMD-AIG-AIMA/Primus-Turbo
        submodules: "recursive"
        path: Primus-Turbo
        ref: ${{ env.PRIMUS_TURBO_COMMIT }}
    - run: echo "Begin Primus-Turbo Install."
    - name: Install Primus-Turbo
      run: |
        # Remove any copy left by a previous run on this runner, as the torch job
        # does; otherwise the mv below can fail on an existing /tmp/Primus-Turbo.
        rm -rf /tmp/Primus-Turbo
        mv Primus-Turbo /tmp/
        echo "Primus-Turbo dir: /tmp/Primus-Turbo"
        git config --global --add safe.directory /tmp/Primus-Turbo
        cd /tmp/Primus-Turbo
        start_time=$(date +%s)
        echo "✅ [Pip install requirements] started at: $(date)"
        mkdir -p ${PRIMUS_WORKDIR}/primus-cache
        # Only pip and setuptools are upgraded here; no requirements file is installed.
        python3 -m pip install --upgrade pip setuptools
        end_time=$(date +%s)
        elapsed=$((end_time - start_time))
        echo "✅ [Pip install requirements] ended at: $(date)"
        echo "⏱️ [Pip install requirements] Total elapsed time: ${elapsed} seconds"
        # The Primus-Turbo build is skipped in this job (see the message below),
        # so this timed section contains no work.
        start_time=$(date +%s)
        echo "✅ [build primus-turbo] started at: $(date)"
        end_time=$(date +%s)
        elapsed=$((end_time - start_time))
        echo "✅ [build primus-turbo] ended at: $(date)"
        echo "⏱️ [build primus-turbo] Torch installation causes segfault, so we skip it and actually not install turbo. Total elapsed time: ${elapsed} seconds"
    - run: echo "🎉 Begin Primus Unit Test."
    - uses: actions/checkout@v4
      with:
        submodules: recursive
    - name: Show Environment Info
      run: |
        echo "Hostname: $(hostname)"
        echo "PWD: $(pwd)"
        echo "HOME: $HOME"
        echo "GITHUB_WORKSPACE: $GITHUB_WORKSPACE"
        echo "Runner Temp Dir: $RUNNER_TEMP"
        echo "Runner Tool Cache: $RUNNER_TOOL_CACHE"
    - name: Install Primus
      run: |
        pip install -r requirements-jax.txt
    - name: Set UT_LOG_PATH
      env:
        # Empty for non-pull_request events.
        PR_NUMBER: ${{ github.event.pull_request.number }}
      run: |
        # UT_LOG_PATH = <workdir>/ut_out/<label>-<timestamp>-<short sha>, where
        # <label> is pr-<number>, main, the tag name, or "others".
        # This workflow has no `release` trigger, so `v*` tags arrive as `push`
        # events and are matched by ref instead of by event name alone.
        ts="$(date +%Y%m%d-%H%M%S)"
        commit_id="${GITHUB_SHA::7}"
        if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then
          run_label="pr-${PR_NUMBER}"
        elif [[ "${GITHUB_EVENT_NAME}" == "push" && "${GITHUB_REF}" == "refs/heads/main" ]]; then
          run_label="main"
        elif [[ ( "${GITHUB_EVENT_NAME}" == "push" || "${GITHUB_EVENT_NAME}" == "release" ) && "${GITHUB_REF}" == refs/tags/* ]]; then
          run_label="${GITHUB_REF#refs/tags/}"
        else
          run_label="others"
        fi
        echo "UT_LOG_PATH=${PRIMUS_WORKDIR}/ut_out/${run_label}-${ts}-${commit_id}" >> "$GITHUB_ENV"
    - name: Run Shell Tests
      run: |
        echo "Running Primus CLI shell tests..."
        bash ./tests/runner/run_all_tests.sh
    - name: Run Unit Tests
      env:
        HF_TOKEN: ${{secrets.HF_TOKEN}}
      run: |
        echo "Set UT_LOG_PATH: ${UT_LOG_PATH}"
        # Only user of UT_LOG_PATH in this job: start from an empty directory.
        rm -rf "${UT_LOG_PATH}"
        mkdir -p "${UT_LOG_PATH}"
        MASTER_PORT=10009 DATA_PATH=/wekafs/primus-data \
          JAX_SKIP_UT=1 python ./tests/run_unit_tests.py --jax
    - name: Clean
      # Run the cleanup even when an earlier step failed.
      if: always()
      # NOTE(review): Primus-Turbo was moved to /tmp above, so these paths may not
      # exist on this runner — confirm which directories need cleaning.
      run: |
        rm -rf ${PRIMUS_WORKDIR}/Primus-Turbo
        rm -rf ${PRIMUS_WORKDIR}/Primus