fix splitk triton kernel int32 overflow (fairinternal/xformers#1462) #434
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# GPU test workflow: builds xFormers against a nightly PyTorch wheel and runs
# the unit tests on every GPU runner listed in the job matrix.
name: gpu_test_gh

on:
  workflow_dispatch: {}
  pull_request:
    # Only run for PRs touching these paths; "!" entries are exclusions.
    paths:
      - "xformers/**"
      - "!xformers/benchmarks/**"
      - "!xformers/version.txt"
      - ".github/workflows/gpu_test_gh*"
      - "tests/**"
      - "setup.py"
      - "requirements*.txt"
      - "third_party/**"
  push:
    branches:
      - main

env:
  XFORMERS_BUILD_TYPE: "Release"
  CI: "1"
  TORCHINDUCTOR_COMPILE_THREADS: "1"

jobs:
  gpu_test_gh:
    strategy:
      fail-fast: false
      matrix:
        gpu:
          - runner: "h100-256GB"
            sm: "9.0a"
          - runner: "4-core-ubuntu-gpu-t4"
            sm: "7.5"
        # Quoted so the version is a string, not a YAML float
        # (an unquoted 3.10 would be read as 3.1).
        python: ["3.11"]
    name: test_sm${{ matrix.gpu.sm }}
    runs-on: ${{ matrix.gpu.runner }}
    timeout-minutes: 360
    defaults:
      run:
        # Login shell, so the lines the steps below append to ~/.profile
        # (micromamba hook, environment activation) take effect in later steps.
        # This custom template does not enable `-e`; multi-command steps
        # must call `set -e` themselves.
        shell: bash -l {0}
    steps:
      - name: Recursive checkout
        uses: actions/checkout@v3
        with:
          submodules: recursive
          path: "."
          fetch-depth: 0  # We need commits history as well
      - run: nvidia-smi
      - name: Install micromamba
        run: |
          set -ex
          curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xvj bin/micromamba
          echo "eval \"\$($(pwd)/bin/micromamba shell hook --shell bash)\"" >> ~/.profile
          cat ~/.profile
      - name: Create environment
        run: |
          set -ex
          micromamba config set channel_priority strict
          micromamba create -n env python=${{ matrix.python }} \
            zlib pip ninja ccache=4.8 cuda-toolkit \
            -c "nvidia/label/cuda-12.6" -c conda-forge -q -y
      - name: Activate environment
        shell: bash -l {0}
        run: |
          echo "micromamba activate env" >> ~/.profile
          echo "==== .profile ====="
          cat ~/.profile
      - name: Selective build/tests
        if: github.event_name == 'pull_request'
        run: |
          # Stop at the first failing command instead of running the script
          # with missing requirements.
          set -e
          pip install -r .github/selective_ci/requirements.txt
          python .github/selective_ci/selective_ci.py --base_commit ${{ github.event.pull_request.base.sha }}
      - name: Setup test requirements
        run: |
          # Without this, a failed torch install would be masked by the exit
          # status of the last `pip install`.
          set -e
          which python
          which nvcc
          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
          pip install -r requirements-test.txt --progress-bar off
      - run: TORCH_CUDA_ARCH_LIST=${{ matrix.gpu.sm }} python -m pip install -v --no-build-isolation -e .
      - run: python -m xformers.info
      - name: xFormers import should not init cuda context
        run: |
          # NOTE: we check GPU version by default to determine if triton should be used
          # and this initializes CUDA context, unless we set `XFORMERS_ENABLE_TRITON`
          XFORMERS_ENABLE_TRITON=1 python -c "import xformers; import xformers.ops; import torch; assert not torch.cuda.is_initialized()"
      - name: Check for PyTorch stable symbols
        run: |
          # `|| true` keeps an empty grep result (no matches) from being an error.
          bad_symbols=$(nm --dynamic --undefined-only --demangle xformers/_C.so | grep --extended-regexp "(torch|at|c10|c10d)::" || true)
          # Quote the variable when printing so each symbol stays on its own line.
          if [[ $bad_symbols != "" ]]; then echo "These non-stable PyTorch symbols made it into the xFormers shared library:"; echo "$bad_symbols"; exit 1; fi
      - name: FA3 was built
        run: |
          # Skipped on the sm 7.5 runner and when XFORMERS_DISABLE_FLASH_ATTN is set.
          if [[ "${{ matrix.gpu.sm }}" != "7.5" && -z "$XFORMERS_DISABLE_FLASH_ATTN" ]]
          then
            python -c "from xformers.ops.fmha.flash3 import _C_flashattention3 ; assert _C_flashattention3 is not None, 'FA3 not built'"
          fi
      - name: Unit tests
        run: |
          python -m pytest --verbose --random-order-bucket=global --maxfail=20 --junitxml=test-results/junit.xml --cov-report=xml --cov=./ tests
      - name: Publish Test Report
        uses: mikepenz/action-junit-report@v3
        if: success() || failure()  # always run even if the previous step fails
        with:
          report_paths: 'test-results/*.xml'