Skip to content

Commit dc5b8d6

Browse files
Nikhil BansalOrbax Authors
authored andcommitted
POC cloud run
PiperOrigin-RevId: 880680159
1 parent f75444c commit dc5b8d6

File tree

8 files changed

+415
-6
lines changed

8 files changed

+415
-6
lines changed

.github/workflows/build_image.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
2+
# Copyright 2025 Google LLC
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# https://www.apache.org/licenses/LICENSE-2.0
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# This workflow will build and push MaxText Docker image to GCR.
13+
name: Build and Push Orbax-checkpoint Docker Images
14+
on:
15+
schedule:
16+
# Run the job daily at 12AM UTC
17+
- cron: '0 0 * * *'
18+
push:
19+
branches:
20+
- "test_*"
21+
permissions:
22+
contents: read
23+
jobs:
24+
build_and_push:
25+
runs-on: linux-x86-n2-16-buildkit
26+
container: google/cloud-sdk:524.0.0
27+
steps:
28+
- name: Checkout Orbax-checkpoint
29+
uses: actions/checkout@v5
30+
- name: Mark git repositories as safe
31+
run: git config --global --add safe.directory '*'
32+
- name: Configure Docker
33+
run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
34+
- name: Set up Docker BuildX
35+
uses: docker/setup-buildx-action@v3.11.1
36+
with:
37+
driver: remote
38+
endpoint: tcp://localhost:1234
39+
- name: Build and push Docker image
40+
uses: docker/build-push-action@v6
41+
with:
42+
push: true
43+
context: .
44+
file: ./checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile
45+
tags: gcr.io/orbax-checkpoint/orbax-benchmarks-it-runner:latest
46+
cache-from: type=gha
47+
outputs: type=image,compression=zstd,force-compression=true
48+
build-args: |
49+
DEVICE=tpu
50+
JAX_VERSION=newest
51+
BRANCH=main
52+
GITHUB_RUNNER=true
53+
- name: Add tags to Docker images
54+
shell: bash
55+
run: |
56+
SOURCE_IMAGE="gcr.io/orbax-checkpoint/orbax-benchmarks-it-runner"
57+
# Add Orbax-checkpoint tag
58+
orbax_hash=$(git rev-parse --short HEAD)
59+
gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:orbax_${orbax_hash}" --quiet

checkpoint/orbax/checkpoint/_src/testing/benchmarks/emergency_checkpoint_manager_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_fn(
196196

197197
with metrics.measure("create_directories"):
198198
if jax.process_index() == 0:
199-
persistent_directory.mkdir(parents=True)
199+
persistent_directory.mkdir(parents=True, exist_ok=True)
200200
local_directory.mkdir(parents=True, exist_ok=True)
201201
multihost.sync_global_processes("create directories")
202202

checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,32 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
1313
dnsutils \
1414
&& rm -rf /var/lib/apt/lists/*
1515

16+
RUN apt-get update && apt-get install -y --no-install-recommends \
17+
git \
18+
dnsutils \
19+
curl \
20+
ca-certificates \
21+
gnupg \
22+
apt-transport-https \
23+
gettext-base \
24+
gawk \
25+
&& curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \
26+
&& echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee /etc/apt/sources.list.d/google-cloud-sdk.list \
27+
&& apt-get update \
28+
&& apt-get install -y --no-install-recommends \
29+
google-cloud-cli \
30+
google-cloud-cli-gke-gcloud-auth-plugin \
31+
kubectl \
32+
&& apt-get clean \
33+
&& rm -rf /var/lib/apt/lists/*
34+
35+
ENV DOCKER_VERSION=24.0.7
36+
RUN curl -fsSL "https://download.docker.com/linux/static/stable/x86_64/docker-${DOCKER_VERSION}.tgz" -o docker.tgz \
37+
&& tar -xvzf docker.tgz \
38+
&& mv docker/docker /usr/local/bin/ \
39+
&& rm -rf docker docker.tgz \
40+
&& chmod +x /usr/local/bin/docker
41+
1642
# 2. Checkout Orbax (Optimized Shallow Fetch)
1743
ARG PR_NUMBER
1844
ARG BRANCH=main
@@ -50,6 +76,9 @@ ARG DEVICE=tpu
5076
# Install GCSFS and Portpicker
5177
RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow
5278

79+
# Create a fake docker binary that always returns success (exit 0)
80+
RUN echo '#!/bin/bash\nexit 0' > /usr/local/bin/docker && chmod +x /usr/local/bin/docker
81+
5382
# Install requirements from repo root if it exists
5483
RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi
5584

@@ -82,7 +111,7 @@ RUN if [ "$JAX_VERSION" = "newest" ]; then \
82111
# 4. Install Orbax from Source
83112
WORKDIR /app/orbax_repo/checkpoint
84113
RUN pip install --no-cache-dir .
85-
RUN pip install --no-cache-dir pathwaysutils
114+
RUN pip install xpk
86115

87116
# 5. Environment Setup
88117
# Set PYTHONPATH so 'import orbax' works from the correctly mapped directory
@@ -91,7 +120,30 @@ ENV PYTHONPATH=/app/orbax_repo/checkpoint
91120
# Verify installation
92121
RUN python3 -c "import orbax.checkpoint; print('Orbax installed:', orbax.checkpoint.__file__)"
93122

94-
# 6. Entrypoint
95-
# We point to the benchmark script relative to the repo root structure
96-
WORKDIR /app/orbax_repo/checkpoint
97-
ENTRYPOINT ["python3", "orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py"]
123+
# Create benchmark script
124+
RUN tee /app/run_benchmark.sh <<-'EOF'
125+
#!/bin/bash
126+
gcloud container clusters get-credentials orbax-cluster-test-mtc \
127+
--region us-west1 \
128+
--project orbax-checkpoint
129+
130+
python3 orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py \
131+
--cluster_name orbax-cluster-test-mtc \
132+
--tpu_type v5litepod-16 \
133+
--zone us-west1-c \
134+
--config_file orbax/checkpoint/_src/testing/benchmarks/configs/emergency_checkpoint_manager_benchmark.yaml \
135+
--docker_image gcr.io/orbax-checkpoint/orbax-benchmarks:latest \
136+
--output_directory gs://orbax-benchmarks-hns/runs/$(date +%Y%m%d) \
137+
--nodelete_cluster_on_completion \
138+
--ramdisk_directory=/local/test \
139+
--num_slices 2 \
140+
--test_restart_workflow \
141+
--verbose \
142+
--skip_validation
143+
144+
docker version
145+
EOF
146+
147+
RUN chmod +x /app/run_benchmark.sh
148+
149+
CMD ["/app/run_benchmark.sh"]
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Base image argument (defaulting to slim python image)
2+
ARG BASE_IMAGE=python:3.11-slim
3+
FROM $BASE_IMAGE
4+
5+
WORKDIR /app
6+
7+
# 1. Install System Dependencies
8+
# common utils + git (needed for checkout)
9+
# --no-install-recommends limits bloat
10+
# python3-pip is standard in python images, no need to install
11+
RUN apt-get update && apt-get install -y --no-install-recommends \
12+
git \
13+
dnsutils \
14+
&& rm -rf /var/lib/apt/lists/*
15+
16+
# 2. Checkout Orbax (Optimized Shallow Fetch)
17+
ARG PR_NUMBER
18+
ARG BRANCH=main
19+
ARG REPO_URL=https://github.com/google/orbax.git
20+
21+
# Logic:
22+
# 1. Init empty repo
23+
# 2. Add remote
24+
# 3. Shallow fetch ONLY the target (PR or Branch)
25+
# 4. Checkout
26+
# 5. DELETE .git history to save space
27+
RUN mkdir orbax_repo && cd orbax_repo && \
28+
git init && \
29+
git remote add origin $REPO_URL && \
30+
if [ -n "$PR_NUMBER" ]; then \
31+
echo "Fetching PR #${PR_NUMBER} (Shallow)..." && \
32+
git fetch --depth 1 origin pull/$PR_NUMBER/head:pr_branch && \
33+
git checkout pr_branch; \
34+
else \
35+
echo "Fetching branch: ${BRANCH} (Shallow)..." && \
36+
git fetch --depth 1 origin $BRANCH && \
37+
git checkout FETCH_HEAD; \
38+
fi && \
39+
rm -rf .git
40+
41+
WORKDIR /app/orbax_repo
42+
43+
# 3. Setup Python Environment & Dependencies
44+
# Uninstall pre-installed orbax if present in base image to avoid conflicts
45+
RUN pip uninstall -y orbax-checkpoint orbax || true
46+
47+
ARG JAX_VERSION=newest
48+
ARG DEVICE=tpu
49+
50+
# Install GCSFS and Portpicker
51+
RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow
52+
53+
# Install requirements from repo root if it exists
54+
RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi
55+
56+
# Install JAX (Flexible Versions)
57+
RUN if [ "$JAX_VERSION" = "newest" ]; then \
58+
if [ "$DEVICE" = "gpu" ]; then \
59+
pip install --no-cache-dir -U "jax[k8s,cuda12]" jaxlib; \
60+
elif [ "$DEVICE" = "tpu" ]; then \
61+
pip install --no-cache-dir -U "jax[k8s,tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
62+
else \
63+
pip install --no-cache-dir -U "jax[k8s]" jaxlib; \
64+
fi \
65+
elif [ "$JAX_VERSION" = "nightly" ]; then \
66+
if [ "$DEVICE" = "gpu" ]; then \
67+
pip install --no-cache-dir -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/; \
68+
elif [ "$DEVICE" = "tpu" ]; then \
69+
pip install --no-cache-dir -U --pre "jax[k8s,tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/; \
70+
fi \
71+
else \
72+
# Specific version
73+
if [ "$DEVICE" = "gpu" ]; then \
74+
pip install --no-cache-dir "jax[k8s,cuda12]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"; \
75+
elif [ "$DEVICE" = "tpu" ]; then \
76+
pip install --no-cache-dir "jax[k8s,tpu]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
77+
else \
78+
pip install --no-cache-dir "jax[k8s]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"; \
79+
fi \
80+
fi
81+
82+
# 4. Install Orbax from Source
83+
WORKDIR /app/orbax_repo/checkpoint
84+
RUN pip install --no-cache-dir .
85+
RUN pip install --no-cache-dir pathwaysutils
86+
87+
# 5. Environment Setup
88+
# Set PYTHONPATH so 'import orbax' works from the correctly mapped directory
89+
ENV PYTHONPATH=/app/orbax_repo/checkpoint
90+
91+
# Verify installation
92+
RUN python3 -c "import orbax.checkpoint; print('Orbax installed:', orbax.checkpoint.__file__)"
93+
94+
# 6. Entrypoint
95+
# We point to the benchmark script relative to the repo root structure
96+
WORKDIR /app/orbax_repo/checkpoint
97+
ENTRYPOINT ["python3", "orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py"]

checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@
274274
False,
275275
'If True, run workload creation and execution twice to test restart.',
276276
)
277+
_SKIP_VALIDATION = flags.DEFINE_boolean(
278+
'skip_validation',
279+
False,
280+
'If True, skip validation of the benchmark results.',
281+
)
277282

278283
# --- Pathways Flags ---
279284
# Pathways uses a "Sidecar" architecture on XPK:
@@ -646,6 +651,8 @@ def construct_xpk_command(
646651
base_cmd.append('--enable-ops-agent')
647652
if _RAMDISK_DIRECTORY.value is not None:
648653
base_cmd.append('--mtc-enabled')
654+
if _SKIP_VALIDATION.value:
655+
base_cmd.append('--skip-validation')
649656

650657
if _ENABLE_PATHWAYS.value:
651658
if not _PATHWAYS_SERVER_IMAGE.value:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Base image argument (defaulting to slim python image)
2+
ARG BASE_IMAGE=python:3.11-slim
3+
FROM $BASE_IMAGE
4+
5+
WORKDIR /app
6+
7+
# 1. Install System Dependencies
8+
# common utils + git (needed for checkout)
9+
# --no-install-recommends limits bloat
10+
# python3-pip is standard in python images, no need to install
11+
RUN apt-get update && apt-get install -y --no-install-recommends \
12+
git \
13+
dnsutils \
14+
&& rm -rf /var/lib/apt/lists/*
15+
16+
RUN apt-get update && apt-get install -y --no-install-recommends \
17+
git \
18+
dnsutils \
19+
curl \
20+
ca-certificates \
21+
gnupg \
22+
apt-transport-https \
23+
gettext-base \
24+
gawk \
25+
&& curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \
26+
&& echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee /etc/apt/sources.list.d/google-cloud-sdk.list \
27+
&& apt-get update \
28+
&& apt-get install -y --no-install-recommends \
29+
google-cloud-cli \
30+
google-cloud-cli-gke-gcloud-auth-plugin \
31+
kubectl \
32+
&& apt-get clean \
33+
&& rm -rf /var/lib/apt/lists/*
34+
35+
COPY ./checkpoint/orbax/checkpoint/_src/testin[g] /app/orbax_repo/checkpoint/orbax/checkpoint/_src/testing
36+
37+
# 3. Setup Python Environment & Dependencies
38+
# Uninstall pre-installed orbax if present in base image to avoid conflicts
39+
RUN pip uninstall -y orbax-checkpoint orbax || true
40+
41+
# # Create a fake docker binary that always returns success (exit 0)
42+
# RUN echo '#!/bin/bash\nexit 0' > /usr/local/bin/docker && chmod +x /usr/local/bin/docker
43+
44+
# Install requirements from repo root if it exists
45+
RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi
46+
47+
RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow pyyaml
48+
49+
# 4. Install Orbax from Source
50+
WORKDIR /app/orbax_repo/checkpoint
51+
RUN pip install xpk
52+
53+
# 5. Environment Setup
54+
# Set PYTHONPATH so 'import orbax' works from the correctly mapped directory
55+
ENV PYTHONPATH=/app/orbax_repo/checkpoint
56+
57+
58+
CMD ["python3", "orbax/checkpoint/_src/testing/oss/cloud_run_it.py"]

0 commit comments

Comments
 (0)