From 1d7c68509462ebd9dc7d7fa150334a1bb0837377 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Mon, 16 Jun 2025 15:23:47 +0000 Subject: [PATCH 01/47] added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro --- .kokoro/github/ubuntu/tpu/build.sh | 36 +++++++++++++++++++ .../ubuntu/tpu/tensorflow/continuous.cfg | 16 +++++++++ .../ubuntu/tpu/tensorflow/presubmit.cfg | 16 +++++++++ requirements-tensorflow-tpu.txt | 14 ++++++++ 4 files changed, 82 insertions(+) create mode 100644 .kokoro/github/ubuntu/tpu/build.sh create mode 100644 .kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg create mode 100644 .kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg create mode 100644 requirements-tensorflow-tpu.txt diff --git a/.kokoro/github/ubuntu/tpu/build.sh b/.kokoro/github/ubuntu/tpu/build.sh new file mode 100644 index 000000000000..b3d4d0e9cd78 --- /dev/null +++ b/.kokoro/github/ubuntu/tpu/build.sh @@ -0,0 +1,36 @@ +set -e +set -x + +cd "${KOKORO_ROOT}/" + +sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 + +PYTHON_BINARY="/usr/bin/python3.10" + +"${PYTHON_BINARY}" -m venv venv +source venv/bin/activate +# Check the python version +python --version +python3 --version + +cd "src/github/keras" +pip install -U pip setuptools +# psutil is used by background log reader +pip install -U psutil + +if [ "$KERAS_BACKEND" == "tensorflow" ] +then + echo "TensorFlow backend detected." + pip install -r requirements-tensorflow-tpu.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + echo "Check that TensorFlow uses TPU" + python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("TPU"))' + # Raise error if GPU is not detected. 
+ python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("TPU")) > 0' + + # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted + pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --cov=keras \ + --cov-config=pyproject.toml +fi \ No newline at end of file diff --git a/.kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg b/.kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg new file mode 100644 index 000000000000..0da48805d3e4 --- /dev/null +++ b/.kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/tpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "tensorflow" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg b/.kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg new file mode 100644 index 000000000000..0da48805d3e4 --- /dev/null +++ b/.kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/tpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "tensorflow" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/requirements-tensorflow-tpu.txt b/requirements-tensorflow-tpu.txt new file mode 100644 index 000000000000..8cafa92379d7 --- /dev/null +++ b/requirements-tensorflow-tpu.txt @@ -0,0 +1,14 @@ +tensorflow==2.18.0 +--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html +tensorflow-tpu==2.18.0 + +tf2onnx + +# Torch cpu-only version (needed for testing). 
+--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax cpu-only version (needed for testing). +jax[cpu] + +-r requirements-common.txt From 19b5e6be6bab2f3635ed31fbbb2b0485ceb4b7ee Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Mon, 16 Jun 2025 15:57:47 +0000 Subject: [PATCH 02/47] updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend configs --- .kokoro/github/ubuntu/tpu/build.sh | 40 +++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/.kokoro/github/ubuntu/tpu/build.sh b/.kokoro/github/ubuntu/tpu/build.sh index b3d4d0e9cd78..f9bad68c19de 100644 --- a/.kokoro/github/ubuntu/tpu/build.sh +++ b/.kokoro/github/ubuntu/tpu/build.sh @@ -33,4 +33,42 @@ then --ignore keras/src/layers/merging/merging_test.py \ --cov=keras \ --cov-config=pyproject.toml -fi \ No newline at end of file +fi + +if [ "$KERAS_BACKEND" == "jax" ] +then + echo "JAX backend detected." + pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' + # Raise error if GPU is not detected. + python3 -c 'import jax;assert jax.default_backend().lower() == "gpu"' + + # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted + # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted + # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now. 
+ pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \ + --ignore keras/src/backend/jax/distribution_lib_test.py \ + --ignore keras/src/distribution/distribution_lib_test.py \ + --cov=keras \ + --cov-config=pyproject.toml + + pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml +fi + +if [ "$KERAS_BACKEND" == "torch" ] +then + echo "PyTorch backend detected." + pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())' + # Raise error if GPU is not detected. + python3 -c 'import torch;assert torch.cuda.is_available()' + + pytest keras --ignore keras/src/applications \ + --cov=keras \ + --cov-config=pyproject.toml + +fi From f45e5d0f5d43869882dfdaf89dbd21add9dd20d6 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 05:52:42 +0000 Subject: [PATCH 03/47] Changed the tpu CI config files path to .github from .kokoro --- {.kokoro/github/ubuntu => .github/workflows}/tpu/build.sh | 0 .../ubuntu => .github/workflows}/tpu/tensorflow/continuous.cfg | 0 .../ubuntu => .github/workflows}/tpu/tensorflow/presubmit.cfg | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {.kokoro/github/ubuntu => .github/workflows}/tpu/build.sh (100%) rename {.kokoro/github/ubuntu => .github/workflows}/tpu/tensorflow/continuous.cfg (100%) rename {.kokoro/github/ubuntu => .github/workflows}/tpu/tensorflow/presubmit.cfg (100%) diff --git a/.kokoro/github/ubuntu/tpu/build.sh b/.github/workflows/tpu/build.sh similarity index 100% rename from .kokoro/github/ubuntu/tpu/build.sh rename to .github/workflows/tpu/build.sh diff --git a/.kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg b/.github/workflows/tpu/tensorflow/continuous.cfg similarity index 100% rename from 
.kokoro/github/ubuntu/tpu/tensorflow/continuous.cfg rename to .github/workflows/tpu/tensorflow/continuous.cfg diff --git a/.kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg b/.github/workflows/tpu/tensorflow/presubmit.cfg similarity index 100% rename from .kokoro/github/ubuntu/tpu/tensorflow/presubmit.cfg rename to .github/workflows/tpu/tensorflow/presubmit.cfg From 6771cc0e02109fb5eed24cdd607490a8c2068c00 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 07:18:51 +0000 Subject: [PATCH 04/47] Added new job in .github/workflows/actions.yml to run TPU tests --- .github/workflows/actions.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index b9e785dfc949..14d7be428690 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -11,6 +11,27 @@ permissions: contents: read jobs: + tpu_build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10'] + backend: [tensorflow] + name: Run TPU tests + runs-on: ubuntu-latest + env: + PYTHON: ${{ matrix.python-version }} + KERAS_HOME: .github/workflows/config/${{ matrix.backend }} + KERAS_BACKEND: tensorflow + steps: + - uses: actions/checkout@v4 # Checks-out your repository under $GITHUB_WORKSPACE, so your workflow can access it. 
+ + - name: Make script executable + run: chmod +x .github/workflows/tpu/build.sh # Assuming your script is named my-script.sh + + - name: Run my shell script + run: .github/workflows/tpu/build.sh # Execute the script + build: strategy: fail-fast: false From 87d36e7f0b216c7852827e72ffe9f90f107da664 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 07:37:45 +0000 Subject: [PATCH 05/47] fixed runs-on option in acvtions.yml for tpu_build job to run on self hosted TPU based runner --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 14d7be428690..f4aeada63eb4 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -18,7 +18,7 @@ jobs: python-version: ['3.10'] backend: [tensorflow] name: Run TPU tests - runs-on: ubuntu-latest + runs-on: [self-hosted, linux-x86-ct5lp-112-4tpu] env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} From 99012985de520a4a6647954c83741015b0b1fea0 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 09:27:10 +0000 Subject: [PATCH 06/47] Added another runner in the actions TPU job --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index f4aeada63eb4..78dd485f5d8d 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -18,7 +18,7 @@ jobs: python-version: ['3.10'] backend: [tensorflow] name: Run TPU tests - runs-on: [self-hosted, linux-x86-ct5lp-112-4tpu] + runs-on: [self-hosted, linux-x86-ct5lp-112-4tpu, linux-x86-ct6e-44-1tpu] env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} From be97210d6ffe992ae8c717395703544bfdcb7d36 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 21:44:59 +0530 Subject: [PATCH 07/47] Update continuous.cfg 
updated build file path --- .github/workflows/tpu/tensorflow/continuous.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tpu/tensorflow/continuous.cfg b/.github/workflows/tpu/tensorflow/continuous.cfg index 0da48805d3e4..e535b0608202 100644 --- a/.github/workflows/tpu/tensorflow/continuous.cfg +++ b/.github/workflows/tpu/tensorflow/continuous.cfg @@ -1,4 +1,4 @@ -build_file: "keras/.kokoro/github/ubuntu/tpu/build.sh" +build_file: "keras/.github/workflows/tpu/build.sh" action { define_artifacts { @@ -13,4 +13,4 @@ env_vars: { } # Set timeout to 60 mins from default 180 mins -timeout_mins: 60 \ No newline at end of file +timeout_mins: 60 From a1cd5c3e8c81f98f24a035b5af477c944d07c65a Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Wed, 18 Jun 2025 21:45:43 +0530 Subject: [PATCH 08/47] Update presubmit.cfg updated build file path --- .github/workflows/tpu/tensorflow/presubmit.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tpu/tensorflow/presubmit.cfg b/.github/workflows/tpu/tensorflow/presubmit.cfg index 0da48805d3e4..e535b0608202 100644 --- a/.github/workflows/tpu/tensorflow/presubmit.cfg +++ b/.github/workflows/tpu/tensorflow/presubmit.cfg @@ -1,4 +1,4 @@ -build_file: "keras/.kokoro/github/ubuntu/tpu/build.sh" +build_file: "keras/.github/workflows/tpu/build.sh" action { define_artifacts { @@ -13,4 +13,4 @@ env_vars: { } # Set timeout to 60 mins from default 180 mins -timeout_mins: 60 \ No newline at end of file +timeout_mins: 60 From f0ab6762d8e3a784e262285c97ad53c50766f00f Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Mon, 23 Jun 2025 22:38:26 +0530 Subject: [PATCH 09/47] Update actions.yml Updated tpu_build job of actions.yml with specific runner label --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 78dd485f5d8d..84038adde637 100644 --- 
a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -18,7 +18,7 @@ jobs: python-version: ['3.10'] backend: [tensorflow] name: Run TPU tests - runs-on: [self-hosted, linux-x86-ct5lp-112-4tpu, linux-x86-ct6e-44-1tpu] + runs-on: linux-x86-ct5lp-112-4tpu env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} From 09161d716a3de5b6bfecb7a7da9633246dd6a99e Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 24 Jun 2025 08:51:42 +0000 Subject: [PATCH 10/47] Developed Dockerfile for TPU build job in actions.yml --- .github/workflows/actions.yml | 17 ++++++++++------- .github/workflows/tpu/Dockerfile | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/tpu/Dockerfile diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 84038adde637..84887bd4cd15 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -24,14 +24,17 @@ jobs: KERAS_HOME: .github/workflows/config/${{ matrix.backend }} KERAS_BACKEND: tensorflow steps: - - uses: actions/checkout@v4 # Checks-out your repository under $GITHUB_WORKSPACE, so your workflow can access it. + - uses: actions/checkout@v4 - - name: Make script executable - run: chmod +x .github/workflows/tpu/build.sh # Assuming your script is named my-script.sh + - name: Build and run Docker image for TPU tests + run: | + docker build -f .github/workflows/tpu/Dockerfile -t keras-tpu-test . + docker run --rm \ + -e PYTHON=${{ matrix.python-version }} \ + -e KERAS_HOME=.github/workflows/config/${{ matrix.backend }} \ + -e KERAS_BACKEND=tensorflow \ + keras-tpu-test - - name: Run my shell script - run: .github/workflows/tpu/build.sh # Execute the script - build: strategy: fail-fast: false @@ -148,4 +151,4 @@ jobs: pip uninstall -y keras keras-nightly pip install -e "." 
--progress-bar off --upgrade - name: Run pre-commit - run: pre-commit run --all-files --hook-stage manual + run: pre-commit run --all-files --hook-stage manual \ No newline at end of file diff --git a/.github/workflows/tpu/Dockerfile b/.github/workflows/tpu/Dockerfile new file mode 100644 index 000000000000..7d0eeb2280f1 --- /dev/null +++ b/.github/workflows/tpu/Dockerfile @@ -0,0 +1,28 @@ +FROM python:3.10-slim + +ENV KERAS_HOME=/github/workspace/.github/workflows/config/tensorflow \ + KERAS_BACKEND=tensorflow + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + sudo \ + && rm -rf /var/lib/apt/lists/* + +# Copy the entire codebase into the container +COPY . /github/workspace +WORKDIR /github/workspace + +# Create and activate venv, install pip/setuptools/psutil, then run tests +RUN cd src/github/keras && \ + pip install -U pip setuptools && \ + pip install -U psutil && \ + pip install -r requirements-tensorflow-tpu.txt && \ + pip uninstall -y keras keras-nightly && \ + python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("TPU"))' && \ + python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("TPU")) > 0' && \ + pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --cov=keras \ + --cov-config=pyproject.toml + +CMD ["bash"] \ No newline at end of file From 058fdff4b4e746a2e164e0fb1a3fd856287deadc Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 24 Jun 2025 14:27:59 +0530 Subject: [PATCH 11/47] Update actions.yml Added container section --- .github/workflows/actions.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 84887bd4cd15..ffdf82fb0ba7 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -19,6 +19,8 @@ jobs: backend: [tensorflow] name: Run TPU tests runs-on: linux-x86-ct5lp-112-4tpu + container: + 
image: docker:latest env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} @@ -151,4 +153,4 @@ jobs: pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade - name: Run pre-commit - run: pre-commit run --all-files --hook-stage manual \ No newline at end of file + run: pre-commit run --all-files --hook-stage manual From d47e39ed7e19e02557929692e5130dc279c5bb25 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Thu, 26 Jun 2025 06:48:59 +0000 Subject: [PATCH 12/47] Included few more runners in tpu_build job --- .github/workflows/actions.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index ffdf82fb0ba7..d180647a012a 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -18,7 +18,13 @@ jobs: python-version: ['3.10'] backend: [tensorflow] name: Run TPU tests - runs-on: linux-x86-ct5lp-112-4tpu + runs-on: + # - linux-x86-ct5lp-112-4tpu + - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n + # - linux-x86-ct6e-44-1tpu + # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 + # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc + container: image: docker:latest env: From ba4f6aef1ead757f6b2538fa92cb285774aa1377 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Thu, 26 Jun 2025 07:02:43 +0000 Subject: [PATCH 13/47] Using linux-x86-ct6e-44-1tpu --- .github/workflows/actions.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index d180647a012a..d0a6785af23b 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -20,8 +20,8 @@ jobs: name: Run TPU tests runs-on: # - linux-x86-ct5lp-112-4tpu - - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n - # - linux-x86-ct6e-44-1tpu + # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n + - linux-x86-ct6e-44-1tpu # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 # - 
linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc From a5a362442003f674b9c7a55c901e4e5e29df518d Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Mon, 30 Jun 2025 08:34:25 +0000 Subject: [PATCH 14/47] Modified requirement-commmon.txt and updated requirements-tensorflow-tpu.txt --- requirements-common-old.txt | 26 ++++++++++++++++++++++++++ requirements-common.txt | 20 ++++++++++---------- requirements-tensorflow-tpu.txt | 8 ++++---- 3 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 requirements-common-old.txt diff --git a/requirements-common-old.txt b/requirements-common-old.txt new file mode 100644 index 000000000000..7edc40c97a1a --- /dev/null +++ b/requirements-common-old.txt @@ -0,0 +1,26 @@ +pre-commit +namex>=0.0.8 +ruff +pytest +numpy +scipy +scikit-learn +pillow +pandas +absl-py +requests +h5py +ml-dtypes +protobuf +tensorboard-plugin-profile +rich +build +optree +pytest-cov +packaging +# for tree_test.py +dm_tree +coverage!=7.6.5 # 7.6.5 breaks CI +# for onnx_test.py +onnxruntime +openvino diff --git a/requirements-common.txt b/requirements-common.txt index 7edc40c97a1a..ad7f0caa46b4 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,23 +1,23 @@ pre-commit -namex>=0.0.8 +#namex>=0.0.8 ruff pytest -numpy +#numpy scipy scikit-learn pillow pandas -absl-py -requests -h5py -ml-dtypes -protobuf +#absl-py +#requests +#h5py +#ml-dtypes +#protobuf tensorboard-plugin-profile -rich +#rich build -optree +#optree pytest-cov -packaging +#packaging # for tree_test.py dm_tree coverage!=7.6.5 # 7.6.5 breaks CI diff --git a/requirements-tensorflow-tpu.txt b/requirements-tensorflow-tpu.txt index 8cafa92379d7..8f78f85d8882 100644 --- a/requirements-tensorflow-tpu.txt +++ b/requirements-tensorflow-tpu.txt @@ -1,8 +1,8 @@ -tensorflow==2.18.0 ---find-links https://storage.googleapis.com/libtpu-tf-releases/index.html -tensorflow-tpu==2.18.0 +#tensorflow==2.18.0 +#--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html 
+#tensorflow-tpu==2.18.0 -tf2onnx +#tf2onnx # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu From b9998af9c161659c6b10c6931b7a1896f1628eb4 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 22 Jul 2025 05:06:36 +0000 Subject: [PATCH 15/47] Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt --- conftest.py | 2 + keras/src/backend/common/dtypes_TPU_test.py | 135 +++++++++++++++++ keras/src/backend/common/dtypes_test.py | 158 +++++++++++++++++++- requirements-jax-tpu.txt | 14 ++ 4 files changed, 301 insertions(+), 8 deletions(-) create mode 100644 keras/src/backend/common/dtypes_TPU_test.py create mode 100644 requirements-jax-tpu.txt diff --git a/conftest.py b/conftest.py index 0ade560a1bdf..cf7ff5599db4 100644 --- a/conftest.py +++ b/conftest.py @@ -59,3 +59,5 @@ def pytest_collection_modifyitems(config, items): def skip_if_backend(given_backend, reason): return pytest.mark.skipif(backend() == given_backend, reason=reason) + + diff --git a/keras/src/backend/common/dtypes_TPU_test.py b/keras/src/backend/common/dtypes_TPU_test.py new file mode 100644 index 000000000000..a4e728e9d113 --- /dev/null +++ b/keras/src/backend/common/dtypes_TPU_test.py @@ -0,0 +1,135 @@ +import tensorflow as tf +from unittest.mock import patch +import os +import time + +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src.backend.common import dtypes +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + +# Ensure the backend is set to TensorFlow +os.environ["KERAS_BACKEND"] = "tensorflow" + +os.environ["TPU_NAME"] = "harshith-tf-4" +os.environ["JAX_PLATFORMS"] = "" + +# Define dtypes that are generally problematic or unsupported on TPUs for direct operations. 
+TPU_UNSUPPORTED_DTYPES = [ + "string", + "complex64", + "complex128", + "float8_e4m3fn", + "float8_e5m2", + "float64", + # Based on your latest failure logs involving bfloat16 and float16/float32, + # the 'bool' might not be the direct cause, but rather the promotion rules. + # We will keep it for now as it did appear in previous skips. + "bool" +] + +# Filter ALLOWED_DTYPES to create a list suitable for TPU tests +ALL_DTYPES_FOR_TPU_TESTS = [ + x for x in dtypes.ALLOWED_DTYPES if x not in TPU_UNSUPPORTED_DTYPES +] + [None] + + +class DtypesTPUTest(test_case.TestCase): + """Test the dtype to verify that the behavior matches JAX, with TPU support.""" + + TPU_MAX_RETRIES = 2 + TPU_BASE_DELAY = 1.0 + + if backend.backend() != "tensorflow": + raise RuntimeError("This test class is specifically designed for the TensorFlow backend with TPU.") + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.tpu_available = False + cls.tpu_strategy = None + + try: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + cls.tpu_strategy = tf.distribute.TPUStrategy(resolver) + cls.tpu_available = True + print("✓ TPU initialization successful!") + print(f"Number of TPU devices: {cls.tpu_strategy.num_replicas_in_sync}") + print(f"Logical TPU devices: {tf.config.list_logical_devices('TPU')}") + except Exception as e: + print(f"✗ TPU initialization failed: {e}") + print("Falling back to CPU/GPU testing") + cls.tpu_available = False + + def setUp(self): + tf.keras.backend.clear_session() + return super().setUp() + + def tearDown(self): + tf.keras.backend.clear_session() + return super().tearDown() + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES_FOR_TPU_TESTS, dtype2=ALL_DTYPES_FOR_TPU_TESTS) + ) + def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): + """Test dtype result_type behavior specifically on TPU with supported 
dtypes.""" + if not self.tpu_available: + self.skipTest("TPU not available") + + # import jax.numpy as jnp # JAX is not needed if we assert against Keras's own behavior + + with self.tpu_strategy.scope(): + try: + x1_on_tpu = ops.ones((1,), dtype=dtype1) + x2_on_tpu = ops.ones((1,), dtype=dtype2) + + print(f"Initial (Eager Context) X1 Device : {x1_on_tpu.device}") + print(f"Initial (Eager Context) X2 Device : {x2_on_tpu.device}") + + # This operation might run on CPU if not part of a tf.function for TPU + result_eager_attempt = ops.add(x1_on_tpu, x2_on_tpu) + print(f"Initial (Eager Context) Add Result Device : {result_eager_attempt.device}") + + @tf.function + def tpu_compute(a, b): + add_result = ops.add(a, b) + return add_result + + distributed_result = self.tpu_strategy.run(tpu_compute, args=(x1_on_tpu, x2_on_tpu)) + + actual_result_dtype = None + if isinstance(distributed_result, tf.distribute.DistributedValues): + replica_result = distributed_result.values[0] + print(f"Device of result from TPU replica 0: {replica_result.device}") + self.assertIn("TPU", replica_result.device) + actual_result_dtype = replica_result.dtype + else: + print(f"Device of direct distributed result: {distributed_result.device}") + self.assertIn("TPU", distributed_result.device) + actual_result_dtype = distributed_result.dtype + + # Get the expected result type according to Keras's backend + # This is the primary source of truth for the Keras backend's behavior + expected_keras_result_type = backend.result_type(x1_on_tpu.dtype, x2_on_tpu.dtype) + + print(f"Test case: dtype1={dtype1}, dtype2={dtype2}") + print(f"Keras backend.result_type: {expected_keras_result_type}") + print(f"Actual result dtype from TPU operation: {actual_result_dtype}") + + # Assert that the actual result's dtype matches Keras's backend's expected result type + self.assertEqual(actual_result_dtype, expected_keras_result_type) + + # Removed JAX comparison, as it's not strictly necessary for testing Keras's internal 
consistency + # If you need to verify JAX compatibility, consider a separate test or a more nuanced comparison. + + except Exception as e: + if "context_id" in str(e).lower() or "socket closed" in str(e).lower(): + self.skipTest(f"TPU context issue or socket closed: {e}") + else: + raise diff --git a/keras/src/backend/common/dtypes_test.py b/keras/src/backend/common/dtypes_test.py index 7750dcecdd11..0f1ca48b6413 100644 --- a/keras/src/backend/common/dtypes_test.py +++ b/keras/src/backend/common/dtypes_test.py @@ -1,3 +1,5 @@ +import tensorflow as tf # New import +import os from unittest.mock import patch from absl.testing import parameterized @@ -8,19 +10,50 @@ from keras.src.testing import test_case from keras.src.testing.test_utils import named_product +# Ensure the backend is set to TensorFlow if you intend to use TPU. +# This environment variable should ideally be set before Python starts for the process. + +os.environ["KERAS_BACKEND"] = "tensorflow" # Moved to test_case module in Keras + +# Set TPU_NAME if connecting to a specific TPU worker +os.environ["TPU_NAME"] = "harshith-tf-4" +# JAX_PLATFORMS is typically for JAX-specific environments, not directly for TF/Keras TPU. +os.environ["JAX_PLATFORMS"] = "" + + +# --- TPU-specific Dtype Definitions --- +# These must be defined at the module level for absl.testing.parameterized +# to find them when the class is being defined. 
+ +TPU_UNSUPPORTED_DTYPES = [ + "string", + "complex64", + "complex128", + "float8_e4m3fn", + "float8_e5m2", + "float64", # Often problematic for general ops on TPU, or leads to performance issues + "bool" # Can cause issues with type promotion/XLA on TPU in some contexts +] + +ALL_DTYPES_FOR_TPU_TESTS = [ + x for x in dtypes.ALLOWED_DTYPES if x not in TPU_UNSUPPORTED_DTYPES +] + [None] + +# --- End TPU-specific Dtype Definitions --- + class DtypesTest(test_case.TestCase): - """Test the dtype to verify that the behavior matches JAX.""" + """Test the dtype to verify that the behavior matches JAX, with optional TPU support.""" + # Original ALL_DTYPES logic (backend-dependent) remains for non-TPU tests if backend.backend() == "torch": from keras.src.backend.torch.core import to_torch_dtype - # TODO: torch doesn't support uint64. ALL_DTYPES = [] for x in dtypes.ALLOWED_DTYPES: if x not in ["string", "uint64"]: x = str(to_torch_dtype(x)).split(".")[-1] - if x not in ALL_DTYPES: # skip duplicates created by remapping + if x not in ALL_DTYPES: ALL_DTYPES.append(x) ALL_DTYPES += [None] elif backend.backend() == "openvino": @@ -29,28 +62,78 @@ class DtypesTest(test_case.TestCase): for x in dtypes.ALLOWED_DTYPES if x not in ["string", "complex64", "complex128"] ] + [None] - else: + else: # Default to TensorFlow or other backends ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [ None ] - # Remove float8 dtypes for the following tests + # Remove float8 dtypes for the following tests (original logic) ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES] + # --- NEW: setUpClass for TPU initialization (no fixtures/markers) --- + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.tpu_available = False + cls.tpu_strategy = None + + # Only attempt TPU initialization if the Keras backend is TensorFlow + if backend.backend() == "tensorflow": + print("\nAttempting TPU initialization from DtypesTest.setUpClass...") + try: + # Use empty 
string '' for auto-detection or 'grpc://:8470' + # or your specific TPU_NAME from env var + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + cls.tpu_strategy = tf.distribute.TPUStrategy(resolver) + cls.tpu_available = True + print("✓ TPU initialization successful from DtypesTest.setUpClass!") + print(f"Number of TPU devices: {cls.tpu_strategy.num_replicas_in_sync}") + print(f"Logical TPU devices: {tf.config.list_logical_devices('TPU')}") + except Exception as e: + print(f"✗ TPU initialization failed from DtypesTest.setUpClass: {e}") + print("Falling back to CPU/GPU testing for this class.") + cls.tpu_available = False + else: + print(f"Skipping TPU initialization for backend: {backend.backend()}") + + @classmethod + def tearDownClass(cls): + # Optional: Shut down TPU system if it was initialized + if cls.tpu_available: + try: + # This can sometimes cause issues if other processes are using it, + # or if the context was already lost. Use with caution. + # tf.tpu.experimental.shutdown_tpu_system() + print("TPU system teardown (if applicable) completed.") + except Exception as e: + print(f"Error during TPU system teardown: {e}") + super().tearDownClass() + # --- END setUpClass for TPU --- + def setUp(self): + # The JAX x64 setup is for JAX backend tests, keep it. 
from jax.experimental import enable_x64 - self.jax_enable_x64 = enable_x64() self.jax_enable_x64.__enter__() + # Clear Keras session for each test + if backend.backend() == "tensorflow": # Only clear if TF backend is active + tf.keras.backend.clear_session() return super().setUp() def tearDown(self): + # JAX x64 teardown self.jax_enable_x64.__exit__(None, None, None) + # Clear Keras session for each test + if backend.backend() == "tensorflow": # Only clear if TF backend is active + tf.keras.backend.clear_session() return super().tearDown() @parameterized.named_parameters( named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float]) ) def test_result_type_with_python_scalar_types(self, dtype1, dtype2): + """Test dtype result_type behavior with Python scalar types (non-TPU).""" import jax.numpy as jnp out = backend.result_type(dtype1, dtype2) @@ -61,6 +144,9 @@ def test_result_type_with_python_scalar_types(self, dtype1, dtype2): named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) ) def test_result_type_with_tensor(self, dtype1, dtype2): + """Test dtype result_type behavior with tensors (non-TPU).""" + # This test will run for all backends as per original logic, + # but will not explicitly use TPU. 
import jax.numpy as jnp x1 = ops.ones((1,), dtype=dtype1) @@ -72,9 +158,66 @@ def test_result_type_with_tensor(self, dtype1, dtype2): expected = jnp.result_type(x1_jax, x2_jax).name self.assertEqual(out, expected) + # --- NEW TPU-ENABLED TEST METHOD (no fixtures/markers) --- + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES_FOR_TPU_TESTS, dtype2=ALL_DTYPES_FOR_TPU_TESTS) + ) + def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): + """Test dtype result_type behavior specifically on TPU with supported dtypes.""" + # Check if backend is TensorFlow and TPU is available for this class + if backend.backend() != "tensorflow": + self.skipTest("TPU tests are only applicable for TensorFlow backend.") + if not self.tpu_available: + self.skipTest("TPU not available for this test class.") + + with self.tpu_strategy.scope(): # Use the class-level strategy object + try: + x1_on_tpu = ops.ones((1,), dtype=dtype1) + x2_on_tpu = ops.ones((1,), dtype=dtype2) + + print(f"Initial (Eager Context) X1 Device : {x1_on_tpu.device}") + print(f"Initial (Eager Context) X2 Device : {x2_on_tpu.device}") + + result_eager_attempt = ops.add(x1_on_tpu, x2_on_tpu) + print(f"Initial (Eager Context) Add Result Device : {result_eager_attempt.device}") + + @tf.function + def tpu_compute(a, b): + add_result = ops.add(a, b) + return add_result + + distributed_result = self.tpu_strategy.run(tpu_compute, args=(x1_on_tpu, x2_on_tpu)) + + actual_result_dtype = None + if isinstance(distributed_result, tf.distribute.DistributedValues): + replica_result = distributed_result.values[0] + print(f"Device of result from TPU replica 0: {replica_result.device}") + self.assertIn("TPU", replica_result.device) + actual_result_dtype = replica_result.dtype + else: + print(f"Device of direct distributed result: {distributed_result.device}") + self.assertIn("TPU", distributed_result.device) + actual_result_dtype = distributed_result.dtype + + expected_keras_result_type = 
backend.result_type(x1_on_tpu.dtype, x2_on_tpu.dtype) + + print(f"Test case: dtype1={dtype1}, dtype2={dtype2}") + print(f"Keras backend.result_type: {expected_keras_result_type}") + print(f"Actual result dtype from TPU operation: {actual_result_dtype}") + + self.assertEqual(actual_result_dtype, expected_keras_result_type) + + except Exception as e: + if "context_id" in str(e).lower() or "socket closed" in str(e).lower(): + self.skipTest(f"TPU context issue or socket closed: {e}") + else: + raise + # --- END NEW TPU-ENABLED TEST METHOD --- + + + # Original tests below remain unchanged def test_result_type_with_none(self): import jax.numpy as jnp - self.assertEqual(backend.result_type(None), jnp.result_type(None).name) def test_result_type_empty_list(self): @@ -210,7 +353,6 @@ def test_resolve_weak_type_float(self): ) def test_least_upper_bound_ensure_order_independence(self): - # Test to ensure _least_upper_bound is order-independent. result1 = dtypes._least_upper_bound("float32", "int32") result2 = dtypes._least_upper_bound("int32", "float32") self.assertEqual(result1, result2) diff --git a/requirements-jax-tpu.txt b/requirements-jax-tpu.txt new file mode 100644 index 000000000000..4febbe8e8aab --- /dev/null +++ b/requirements-jax-tpu.txt @@ -0,0 +1,14 @@ +# Tensorflow cpu-only version (needed for testing). +tensorflow-cpu~=2.18.1 +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax with TPU support. +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html +jax[tpu] +flax + +-r requirements-common.txt From f68be9775658df86f52c059beb16611f22c2418a Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Thu, 26 Jun 2025 09:51:18 -0700 Subject: [PATCH 16/47] Progress bar now handles `steps_per_execution`. (#21422) Progress bar would always report the starting batch + 1 at the end of the batch. 
Now it takes into account `steps_per_execution` for the last batch reported. Fixes https://github.com/keras-team/keras/issues/20861 --- keras/src/backend/jax/trainer.py | 22 +++++----- keras/src/backend/numpy/trainer.py | 14 +++--- keras/src/backend/openvino/trainer.py | 6 +-- .../src/backend/tensorflow/distribute_test.py | 2 +- keras/src/backend/tensorflow/trainer.py | 22 +++++----- keras/src/backend/torch/trainer.py | 18 ++++---- keras/src/trainers/epoch_iterator.py | 18 +++++--- keras/src/trainers/epoch_iterator_test.py | 11 ++--- keras/src/trainers/trainer.py | 2 +- keras/src/trainers/trainer_test.py | 43 ++++++++----------- 10 files changed, 80 insertions(+), 78 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 2577de297d78..199227b2e315 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -408,9 +408,9 @@ def fit( self._jax_state_synced = True with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: + for begin_step, end_step, iterator in epoch_iterator: # Callbacks - callbacks.on_train_batch_begin(step) + callbacks.on_train_batch_begin(begin_step) # Train step if self._jax_state_synced: @@ -441,7 +441,7 @@ def fit( "metrics_variables": metrics_variables, } # Dispatch callbacks. This takes care of async dispatch. - callbacks.on_train_batch_end(step, logs) + callbacks.on_train_batch_end(end_step, logs) if self.stop_training: # Stop training if a callback has set @@ -569,8 +569,8 @@ def evaluate( self._jax_state_synced = True with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: - callbacks.on_test_batch_begin(step) + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) if self._jax_state_synced: # The state may have been synced by a callback. @@ -600,7 +600,7 @@ def evaluate( } # Dispatch callbacks. This takes care of async dispatch. 
- callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(end_step, logs) if self.stop_evaluating: break @@ -633,7 +633,7 @@ def predict( if not all(layer.built for layer in self._flatten_layers()): # Build the model on one batch of data. - for _, iterator in epoch_iterator: + for _, _, iterator in epoch_iterator: # Build model x, _, _ = data_adapter_utils.unpack_x_y_sample_weight( next(iterator) @@ -677,8 +677,8 @@ def append_to_outputs(batch_outputs, outputs): outputs = None non_trainable_variables = None with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: - callbacks.on_predict_batch_begin(step) + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) if self._jax_state_synced: # The state may have been synced by a callback. state = self._get_jax_state( @@ -701,7 +701,9 @@ def append_to_outputs(batch_outputs, outputs): outputs = append_to_outputs(batch_outputs, outputs) # Dispatch callbacks. This takes care of async dispatch. 
- callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end( + end_step, {"outputs": batch_outputs} + ) if self.stop_predicting: break diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 80494a540be9..fd8c276a86d2 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -211,11 +211,11 @@ def append_to_outputs(batch_outputs, outputs): self.stop_predicting = False callbacks.on_predict_begin() outputs = None - for step, data in epoch_iterator: - callbacks.on_predict_batch_begin(step) + for begin_step, end_step, data in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) - callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() @@ -255,7 +255,7 @@ def evaluate( if not all(layer.built for layer in self._flatten_layers()): # Build the model on one batch of data. 
- for _, data in epoch_iterator: + for _, _, data in epoch_iterator: data_batch = data[0] self._symbolic_build(data_batch) break @@ -276,10 +276,10 @@ def evaluate( callbacks.on_test_begin() logs = {} self.reset_metrics() - for step, data in epoch_iterator: - callbacks.on_test_batch_begin(step) + for begin_step, end_step, data in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) logs = self.test_function(data) - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(end_step, logs) if self.stop_evaluating: break logs = self._get_metrics_result_or_logs(logs) diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index b95f635002aa..00921becafc7 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -213,11 +213,11 @@ def append_to_outputs(batch_outputs, outputs): self.stop_predicting = False callbacks.on_predict_begin() outputs = None - for step, data in epoch_iterator.enumerate_epoch(): - callbacks.on_predict_batch_begin(step) + for begin_step, end_step, data in epoch_iterator.enumerate_epoch(): + callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) - callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() diff --git a/keras/src/backend/tensorflow/distribute_test.py b/keras/src/backend/tensorflow/distribute_test.py index e034a65864bc..d2381bf64c14 100644 --- a/keras/src/backend/tensorflow/distribute_test.py +++ b/keras/src/backend/tensorflow/distribute_test.py @@ -104,7 +104,7 @@ def test_epoch_iterator(self): distribute_strategy=strategy, ) steps_seen = [] - for step, data_iterator in epoch_iterator: + for step, _, data_iterator in epoch_iterator: steps_seen.append(step) batch = next(data_iterator) self.assertEqual(len(batch), 3) diff --git 
a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index bc632e8dd589..fa2f5770098b 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -372,10 +372,10 @@ def fit( self.reset_metrics() callbacks.on_epoch_begin(epoch) with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: - callbacks.on_train_batch_begin(step) + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_train_batch_begin(begin_step) logs = self.train_function(iterator) - callbacks.on_train_batch_end(step, logs) + callbacks.on_train_batch_end(end_step, logs) if self.stop_training: break @@ -484,10 +484,10 @@ def evaluate( logs = {} self.reset_metrics() with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: - callbacks.on_test_batch_begin(step) + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) logs = self.test_function(iterator) - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(end_step, logs) if self.stop_evaluating: break logs = self._get_metrics_result_or_logs(logs) @@ -560,12 +560,14 @@ def get_data(iterator): callbacks.on_predict_begin() outputs = None with epoch_iterator.catch_stop_iteration(): - for step, iterator in epoch_iterator: - callbacks.on_predict_batch_begin(step) + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) data = get_data(iterator) batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) - callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end( + end_step, {"outputs": batch_outputs} + ) if self.stop_predicting: break callbacks.on_predict_end() @@ -696,7 +698,7 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None): # Unlike jax/torch iterator, tf iterator returns an iterator instead # of data batch in `iterator`. 
if iterator is not None: - for _, it in iterator: + for _, _, it in iterator: maybe_distributed_data_batch = next(it) has_distributed_values = tree.map_structure( lambda x: isinstance(x, tf.distribute.DistributedValues), diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py index 6469ae32ea42..b0a52e65cc6c 100644 --- a/keras/src/backend/torch/trainer.py +++ b/keras/src/backend/torch/trainer.py @@ -256,14 +256,14 @@ def fit( self.train() logs = {} - for step, data in epoch_iterator: + for begin_step, end_step, data in epoch_iterator: # Callbacks - callbacks.on_train_batch_begin(step) + callbacks.on_train_batch_begin(begin_step) logs = self.train_function(data) # Callbacks - callbacks.on_train_batch_end(step, logs) + callbacks.on_train_batch_end(end_step, logs) if self.stop_training: break @@ -374,10 +374,10 @@ def evaluate( callbacks.on_test_begin() logs = {} self.reset_metrics() - for step, data in epoch_iterator: - callbacks.on_test_batch_begin(step) + for begin_step, end_step, data in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) logs = self.test_function(data) - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(end_step, logs) if self.stop_evaluating: break logs = self._get_metrics_result_or_logs(logs) @@ -433,11 +433,11 @@ def append_to_outputs(batch_outputs, outputs): self.stop_predicting = False callbacks.on_predict_begin() outputs = None - for step, data in epoch_iterator: - callbacks.on_predict_batch_begin(step) + for begin_step, end_step, data in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) - callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() diff --git a/keras/src/trainers/epoch_iterator.py b/keras/src/trainers/epoch_iterator.py index 
9564611abaf1..67a603093d8e 100644 --- a/keras/src/trainers/epoch_iterator.py +++ b/keras/src/trainers/epoch_iterator.py @@ -116,7 +116,11 @@ def _enumerate_iterator(self): self._interrupted_warning() break self._steps_seen += self.steps_per_execution - yield step, self._current_iterator + yield ( + step, + step + self.steps_per_execution - 1, + self._current_iterator, + ) if self._num_batches and self._steps_seen >= self._num_batches: self._current_iterator = iter(self._get_iterator()) self._steps_seen = 0 @@ -126,7 +130,7 @@ def _enumerate_iterator(self): while True: step += self.steps_per_execution self._steps_seen = step + self.steps_per_execution - yield step, iterator + yield step, step + self.steps_per_execution - 1, iterator self.data_adapter.on_epoch_end() def __iter__(self): @@ -135,19 +139,19 @@ def __iter__(self): def __next__(self): buffer = [] - step, iterator = next(self._epoch_iterator) + begin_step, end_step, iterator = next(self._epoch_iterator) with self.catch_stop_iteration(): for _ in range(self.steps_per_execution): data = next(iterator) buffer.append(data) - return step, buffer + return begin_step, end_step, buffer if buffer: - return step, buffer + return begin_step, end_step, buffer raise StopIteration def enumerate_epoch(self): - for step, data in self: - yield step, data + for begin_step, end_step, data in self: + yield begin_step, end_step, data @contextlib.contextmanager def catch_stop_iteration(self): diff --git a/keras/src/trainers/epoch_iterator_test.py b/keras/src/trainers/epoch_iterator_test.py index 31d617c74aea..e674c3220a9b 100644 --- a/keras/src/trainers/epoch_iterator_test.py +++ b/keras/src/trainers/epoch_iterator_test.py @@ -31,9 +31,10 @@ def test_basic_flow(self, call_type): generator = iterator else: generator = iterator.enumerate_epoch() - for step, batch in generator: + for begin_step, end_step, batch in generator: batch = batch[0] - steps_seen.append(step) + steps_seen.append(begin_step) + self.assertEqual(begin_step, 
end_step) self.assertEqual(len(batch), 3) self.assertIsInstance(batch[0], np.ndarray) self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6]) @@ -52,7 +53,7 @@ def test_insufficient_data(self): ) steps_seen = [] with pytest.warns(match="Your input ran out of data"): - for step, _ in iterator: + for step, _, _ in iterator: steps_seen.append(step) self.assertLen(steps_seen, steps_per_epoch - 2) @@ -96,7 +97,7 @@ def __getitem__(self, idx): torch_dataset, batch_size=8, shuffle=True ) iterator = epoch_iterator.EpochIterator(torch_dataloader) - for _, batch in iterator: + for _, _, batch in iterator: batch = batch[0] self.assertEqual(batch[0].shape, (8, 2)) self.assertEqual(batch[1].shape, (8, 1)) @@ -226,7 +227,7 @@ def on_epoch_end(self): num_epochs = 5 for epoch in range(num_epochs): - for step, batch in epoch_iter: + for _, _, _ in epoch_iter: pass self.assertAllEqual(ds.tracker, [1, 2] * num_epochs) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 1a882d570da2..b6e3cde44560 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1072,7 +1072,7 @@ def to_symbolic_input(v): ) if data_batch is None: - for _, data_or_iterator in iterator: + for _, _, data_or_iterator in iterator: if isinstance(data_or_iterator, (list, tuple)): data_batch = data_or_iterator[0] else: diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 26f5a7ffad72..95f0cc69d150 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -284,14 +284,13 @@ def on_batch_end(self, batch, logs=None): class StepCount(Callback): - def __init__(self, batches_indices, batch_size): + def __init__(self, steps_per_execution=1): super().__init__() self.begin_count = 0 self.end_count = 0 self.epoch_begin_count = 0 self.epoch_end_count = 0 - self.batches = batches_indices - self.batch_size = batch_size + self.steps_per_execution = steps_per_execution def on_epoch_begin(self, epoch, logs=None): 
self.begin_count = 0 @@ -302,13 +301,12 @@ def on_epoch_end(self, epoch, logs=None): self.epoch_end_count += 1 def on_batch_begin(self, batch, logs=None): - if self.begin_count < len(self.batches): - assert batch == self.batches[self.begin_count] // self.batch_size + assert batch == self.begin_count * self.steps_per_execution self.begin_count += 1 def on_batch_end(self, batch, logs=None): - assert batch == self.batches[self.end_count] // self.batch_size self.end_count += 1 + assert batch == self.end_count * self.steps_per_execution - 1 class TestTrainer(testing.TestCase): @@ -976,10 +974,6 @@ def test_steps_per_execution_steps_count(self, steps_per_execution, mode): batch_size = 16 epochs = 2 - batches_indices = list( - range(0, data_size, steps_per_execution * batch_size) - ) - x = np.ones((data_size, 4)) y = np.ones((data_size, 1)) @@ -991,7 +985,7 @@ def test_steps_per_execution_steps_count(self, steps_per_execution, mode): run_eagerly=(mode == "eager"), jit_compile=(mode == "jit"), ) - step_count = StepCount(batches_indices, batch_size) + step_count = StepCount(steps_per_execution) history = model.fit( x=x, @@ -1002,7 +996,10 @@ def test_steps_per_execution_steps_count(self, steps_per_execution, mode): verbose=0, ) - self.assertEqual(step_count.begin_count, len(batches_indices)) + self.assertEqual( + step_count.begin_count, + 1 + (data_size - 1) // (steps_per_execution * batch_size), + ) self.assertEqual(step_count.end_count, step_count.begin_count) self.assertEqual(step_count.epoch_begin_count, epochs) self.assertEqual( @@ -1046,10 +1043,6 @@ def test_steps_per_execution_unrolled_steps_steps_count( epochs = 2 unrolled_steps_per_execution = 8 - batches_indices = list( - range(0, data_size, steps_per_execution * batch_size) - ) - x = np.ones((data_size, 4)) y = np.ones((data_size, 1)) @@ -1060,7 +1053,7 @@ def test_steps_per_execution_unrolled_steps_steps_count( steps_per_execution=steps_per_execution, jit_compile=True, ) - step_count = 
StepCount(batches_indices, batch_size) + step_count = StepCount(steps_per_execution) model.unrolled_steps_per_execution = unrolled_steps_per_execution history = model.fit( x=x, @@ -1071,7 +1064,10 @@ def test_steps_per_execution_unrolled_steps_steps_count( verbose=0, ) - self.assertEqual(step_count.begin_count, len(batches_indices)) + self.assertEqual( + step_count.begin_count, + 1 + (data_size - 1) // (steps_per_execution * batch_size), + ) self.assertEqual(step_count.end_count, step_count.begin_count) self.assertEqual(step_count.epoch_begin_count, epochs) self.assertEqual( @@ -1209,10 +1205,6 @@ def test_steps_per_execution_steps_count_unknown_dataset_size( batch_size = 16 epochs = 2 - batches_indices = list( - range(0, data_size, steps_per_execution * batch_size) - ) - def data_generator(): x = np.ones((data_size, 4), dtype=np.float32) y = np.ones((data_size, 1), dtype=np.float32) @@ -1238,7 +1230,7 @@ def data_generator(): run_eagerly=(mode == "eager"), jit_compile=(mode == "jit"), ) - step_count = StepCount(batches_indices, batch_size) + step_count = StepCount(steps_per_execution) history = model.fit( dataset, @@ -1247,8 +1239,9 @@ def data_generator(): verbose=0, ) - self.assertGreaterEqual(step_count.begin_count, len(batches_indices)) - self.assertEqual(step_count.end_count, len(batches_indices)) + batch_count = 1 + (data_size - 1) // (steps_per_execution * batch_size) + self.assertGreaterEqual(step_count.begin_count, batch_count) + self.assertEqual(step_count.end_count, batch_count) self.assertEqual(step_count.epoch_begin_count, epochs) self.assertEqual( step_count.epoch_end_count, step_count.epoch_begin_count From 1018abf74d80d14b25d8a7a11716040717216351 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Fri, 27 Jun 2025 12:11:36 -0700 Subject: [PATCH 17/47] Fix symbolic call of `logsumexp` with int axis. (#21428) Using `keras.ops.math.logsumexp` with an int for `axis` in a functional model would throw an error. 
--- keras/src/ops/math_test.py | 6 ++++-- keras/src/ops/operation_utils.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index b2cad1fb2cb1..748b62a3513f 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -179,8 +179,10 @@ def test_in_top_k(self): def test_logsumexp(self): x = KerasTensor((None, 2, 3), dtype="float32") - result = kmath.logsumexp(x) - self.assertEqual(result.shape, ()) + self.assertEqual(kmath.logsumexp(x).shape, ()) + self.assertEqual(kmath.logsumexp(x, axis=1).shape, (None, 3)) + self.assertEqual(kmath.logsumexp(x, axis=(1, 2)).shape, (None,)) + self.assertEqual(kmath.logsumexp(x, keepdims=True).shape, (1, 1, 1)) def test_extract_sequences(self): # Defined dimension diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index f5ca1857c039..5fcd57fb7817 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -375,6 +375,8 @@ def reduce_shape(shape, axis=None, keepdims=False): return tuple([1 for _ in shape]) else: return tuple([]) + elif isinstance(axis, int): + axis = (axis,) if keepdims: for ax in axis: From 0da77e4d72d85656462b04714d5db085cf962a48 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Sun, 29 Jun 2025 10:32:40 -0700 Subject: [PATCH 18/47] Only allow deserialization of `KerasSaveable`s by module and name. (#21429) Arbitrary functions and classes are not allowed. 
- Made `Operation` extend `KerasSaveable`, this required moving imports to avoid circular imports - `Layer` no longer needs to extend `KerasSaveable` directly - Made feature space `Cross` and `Feature` extend `KerasSaveable` - Also disallow public function `enable_unsafe_deserialization` --- keras/src/layers/layer.py | 3 +- .../src/layers/preprocessing/feature_space.py | 11 ++++-- keras/src/legacy/saving/legacy_h5_format.py | 5 ++- keras/src/legacy/saving/saving_utils.py | 8 +++-- keras/src/ops/operation.py | 6 +++- keras/src/saving/saving_lib.py | 35 +++++++++++++++---- keras/src/saving/serialization_lib.py | 9 ++++- 7 files changed, 61 insertions(+), 16 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index eaff1a8376a2..4ef338b668a1 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -44,7 +44,6 @@ from keras.src.metrics.metric import Metric from keras.src.ops.node import Node from keras.src.ops.operation import Operation -from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import python_utils from keras.src.utils import summary_utils from keras.src.utils import traceback_utils @@ -67,7 +66,7 @@ @keras_export(["keras.Layer", "keras.layers.Layer"]) -class Layer(BackendLayer, Operation, KerasSaveable): +class Layer(BackendLayer, Operation): """This is the class from which all layers inherit. 
A layer is a callable object that takes as input one or more tensors and diff --git a/keras/src/layers/preprocessing/feature_space.py b/keras/src/layers/preprocessing/feature_space.py index 5fc5e34afafa..5f219dc1cf1c 100644 --- a/keras/src/layers/preprocessing/feature_space.py +++ b/keras/src/layers/preprocessing/feature_space.py @@ -6,12 +6,13 @@ from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer from keras.src.saving import saving_lib from keras.src.saving import serialization_lib +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import backend_utils from keras.src.utils.module_utils import tensorflow as tf from keras.src.utils.naming import auto_name -class Cross: +class Cross(KerasSaveable): def __init__(self, feature_names, crossing_dim, output_mode="one_hot"): if output_mode not in {"int", "one_hot"}: raise ValueError( @@ -23,6 +24,9 @@ def __init__(self, feature_names, crossing_dim, output_mode="one_hot"): self.crossing_dim = crossing_dim self.output_mode = output_mode + def _obj_type(self): + return "Cross" + @property def name(self): return "_X_".join(self.feature_names) @@ -39,7 +43,7 @@ def from_config(cls, config): return cls(**config) -class Feature: +class Feature(KerasSaveable): def __init__(self, dtype, preprocessor, output_mode): if output_mode not in {"int", "one_hot", "float"}: raise ValueError( @@ -55,6 +59,9 @@ def __init__(self, dtype, preprocessor, output_mode): self.preprocessor = preprocessor self.output_mode = output_mode + def _obj_type(self): + return "Feature" + def get_config(self): return { "dtype": self.dtype, diff --git a/keras/src/legacy/saving/legacy_h5_format.py b/keras/src/legacy/saving/legacy_h5_format.py index d7f3c3eb7ded..5b919f80e7c6 100644 --- a/keras/src/legacy/saving/legacy_h5_format.py +++ b/keras/src/legacy/saving/legacy_h5_format.py @@ -6,7 +6,6 @@ from absl import logging from keras.src import backend -from keras.src import optimizers from keras.src.backend.common 
import global_state from keras.src.legacy.saving import json_utils from keras.src.legacy.saving import saving_options @@ -161,6 +160,8 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # Set optimizer weights. if "optimizer_weights" in f: try: + from keras.src import optimizers + if isinstance(model.optimizer, optimizers.Optimizer): model.optimizer.build(model._trainable_variables) else: @@ -249,6 +250,8 @@ def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer): hdf5_group: HDF5 group. optimizer: optimizer instance. """ + from keras.src import optimizers + if isinstance(optimizer, optimizers.Optimizer): symbolic_weights = optimizer.variables else: diff --git a/keras/src/legacy/saving/saving_utils.py b/keras/src/legacy/saving/saving_utils.py index 5780ad701163..1373ba11e785 100644 --- a/keras/src/legacy/saving/saving_utils.py +++ b/keras/src/legacy/saving/saving_utils.py @@ -4,11 +4,8 @@ from absl import logging from keras.src import backend -from keras.src import layers from keras.src import losses from keras.src import metrics as metrics_module -from keras.src import models -from keras.src import optimizers from keras.src import tree from keras.src.legacy.saving import serialization from keras.src.saving import object_registration @@ -49,6 +46,9 @@ def model_from_config(config, custom_objects=None): global MODULE_OBJECTS if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"): + from keras.src import layers + from keras.src import models + MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__ MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional @@ -132,6 +132,8 @@ def compile_args_from_training_config(training_config, custom_objects=None): custom_objects = {} with object_registration.CustomObjectScope(custom_objects): + from keras.src import optimizers + optimizer_config = training_config["optimizer_config"] optimizer = optimizers.deserialize(optimizer_config) # Ensure backwards 
compatibility for optimizers in legacy H5 files diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index 9529a8e689f1..5813593340e3 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -7,13 +7,14 @@ from keras.src.api_export import keras_export from keras.src.backend.common.keras_tensor import any_symbolic_tensors from keras.src.ops.node import Node +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import python_utils from keras.src.utils import traceback_utils from keras.src.utils.naming import auto_name @keras_export("keras.Operation") -class Operation: +class Operation(KerasSaveable): def __init__(self, name=None): if name is None: name = auto_name(self.__class__.__name__) @@ -311,6 +312,9 @@ def _get_node_attribute_at_index(self, node_index, attr, attr_name): else: return values + def _obj_type(self): + return "Operation" + # Hooks for backend layer classes def _post_build(self): """Can be overridden for per backend post build actions.""" diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py index 3d19e81ddec6..01d0b0bbb031 100644 --- a/keras/src/saving/saving_lib.py +++ b/keras/src/saving/saving_lib.py @@ -16,14 +16,9 @@ from keras.src import backend from keras.src.backend.common import global_state -from keras.src.layers.layer import Layer -from keras.src.losses.loss import Loss -from keras.src.metrics.metric import Metric -from keras.src.optimizers.optimizer import Optimizer from keras.src.saving.serialization_lib import ObjectSharingScope from keras.src.saving.serialization_lib import deserialize_keras_object from keras.src.saving.serialization_lib import serialize_keras_object -from keras.src.trainers.compile_utils import CompileMetrics from keras.src.utils import dtype_utils from keras.src.utils import file_utils from keras.src.utils import io_utils @@ -1584,32 +1579,60 @@ def get_attr_skipset(obj_type): "_self_unconditional_dependency_names", ] ) + if obj_type 
== "Operation": + from keras.src.ops.operation import Operation + + ref_obj = Operation() + skipset.update(dir(ref_obj)) if obj_type == "Layer": + from keras.src.layers.layer import Layer + ref_obj = Layer() skipset.update(dir(ref_obj)) elif obj_type == "Functional": + from keras.src.layers.layer import Layer + ref_obj = Layer() skipset.update(dir(ref_obj) + ["operations", "_operations"]) elif obj_type == "Sequential": + from keras.src.layers.layer import Layer + ref_obj = Layer() skipset.update(dir(ref_obj) + ["_functional"]) elif obj_type == "Metric": + from keras.src.metrics.metric import Metric + from keras.src.trainers.compile_utils import CompileMetrics + ref_obj_a = Metric() ref_obj_b = CompileMetrics([], []) skipset.update(dir(ref_obj_a) + dir(ref_obj_b)) elif obj_type == "Optimizer": + from keras.src.optimizers.optimizer import Optimizer + ref_obj = Optimizer(1.0) skipset.update(dir(ref_obj)) skipset.remove("variables") elif obj_type == "Loss": + from keras.src.losses.loss import Loss + ref_obj = Loss() skipset.update(dir(ref_obj)) + elif obj_type == "Cross": + from keras.src.layers.preprocessing.feature_space import Cross + + ref_obj = Cross((), 1) + skipset.update(dir(ref_obj)) + elif obj_type == "Feature": + from keras.src.layers.preprocessing.feature_space import Feature + + ref_obj = Feature("int32", lambda x: x, "int") + skipset.update(dir(ref_obj)) else: raise ValueError( f"get_attr_skipset got invalid {obj_type=}. 
" "Accepted values for `obj_type` are " "['Layer', 'Functional', 'Sequential', 'Metric', " - "'Optimizer', 'Loss']" + "'Optimizer', 'Loss', 'Cross', 'Feature']" ) global_state.set_global_attribute( diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index 53b88f389407..180176698d76 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -12,6 +12,7 @@ from keras.src.api_export import keras_export from keras.src.backend.common import global_state from keras.src.saving import object_registration +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import python_utils from keras.src.utils.module_utils import tensorflow as tf @@ -32,6 +33,7 @@ LOADING_APIS = frozenset( { + "keras.config.enable_unsafe_deserialization", "keras.models.load_model", "keras.preprocessing.image.load_img", "keras.saving.load_model", @@ -817,8 +819,13 @@ def _retrieve_class_or_fn( try: mod = importlib.import_module(module) obj = vars(mod).get(name, None) - if obj is not None: + if isinstance(obj, type) and issubclass(obj, KerasSaveable): return obj + else: + raise ValueError( + f"Could not deserialize '{module}.{name}' because " + "it is not a KerasSaveable subclass" + ) except ModuleNotFoundError: raise TypeError( f"Could not deserialize {obj_type} '{name}' because " From cb639c526a1c621271f3db93d402790d0748885d Mon Sep 17 00:00:00 2001 From: Harshith K Date: Wed, 2 Jul 2025 13:02:30 +0530 Subject: [PATCH 19/47] commented tensorflow deps --- requirements-tensorflow-tpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-tensorflow-tpu.txt b/requirements-tensorflow-tpu.txt index 8f78f85d8882..d2bec8f5c2e4 100644 --- a/requirements-tensorflow-tpu.txt +++ b/requirements-tensorflow-tpu.txt @@ -9,6 +9,6 @@ torch==2.6.0 # Jax cpu-only version (needed for testing). 
-jax[cpu] +jax -r requirements-common.txt From c0d1743e065dd03e5d03a65a41733a95ab88aa51 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Wed, 2 Jul 2025 19:17:38 +0530 Subject: [PATCH 20/47] Added log of dtypes_test_tpu.py and the test script for the same --- dtypes_test_tpu.log | 547 ++++++++++++++++++++ keras/src/backend/common/dtypes_test_TPU.py | 159 ++++++ 2 files changed, 706 insertions(+) create mode 100644 dtypes_test_tpu.log create mode 100644 keras/src/backend/common/dtypes_test_TPU.py diff --git a/dtypes_test_tpu.log b/dtypes_test_tpu.log new file mode 100644 index 000000000000..88073fc51de1 --- /dev/null +++ b/dtypes_test_tpu.log @@ -0,0 +1,547 @@ +(env) kharshith@t1v-n-e3deefd5-w-0:~/keras/keras/src/backend/common$ pytest -s dtypes_test_tpu.py +2025-07-02 13:24:39.095603: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:95] Opening library: /home/kharshith/.local/lib/python3.10/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2 +2025-07-02 13:24:39.095727: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:121] Libtpu path is: /home/kharshith/.local/lib/python3.10/site-packages/libtpu/libtpu.so +2025-07-02 13:24:39.135321: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:138] FindAndLoadTpuLibrary failed with ABORTED: The TPU is already in use by another process probably owned by another user. Run "$ sudo lsof -w /dev/vfio/0" to figure out which process is using the TPU. If you still get this message, run "$ sudo rm /tmp/libtpu_lockfile".. This is expected if TPU is not used. +2025-07-02 13:24:39.139591: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 
+================================================================================= test session starts ================================================================================== +platform linux -- Python 3.10.12, pytest-8.4.1, pluggy-1.6.0 -- /home/kharshith/env/bin/python +cachedir: .pytest_cache +rootdir: /home/kharshith/keras +configfile: pyproject.toml +plugins: cov-6.2.1 +collected 512 items + +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_bfloat16 TPU initialization attempt 1/3 +2025-07-02 13:24:42.738075: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:47] Platform manager status: NOT_FOUND: Could not find registered platform with name: "TPU". Available platform names are: Host +2025-07-02 13:24:42.738114: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:82] No TPU platform registered. Waiting 1 second and trying again... (4 tries left) Platform manager status: OK +2025-07-02 13:24:43.738193: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:47] Platform manager status: NOT_FOUND: Could not find registered platform with name: "TPU". Available platform names are: Host +2025-07-02 13:24:43.738209: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:82] No TPU platform registered. Waiting 1 second and trying again... (3 tries left) Platform manager status: OK +2025-07-02 13:24:44.738285: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:47] Platform manager status: NOT_FOUND: Could not find registered platform with name: "TPU". Available platform names are: Host +2025-07-02 13:24:44.738300: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:82] No TPU platform registered. Waiting 1 second and trying again... 
(2 tries left) Platform manager status: OK +2025-07-02 13:24:45.738351: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:47] Platform manager status: NOT_FOUND: Could not find registered platform with name: "TPU". Available platform names are: Host +2025-07-02 13:24:45.738365: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:82] No TPU platform registered. Waiting 1 second and trying again... (1 tries left) Platform manager status: OK +2025-07-02 13:24:46.738440: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:47] Platform manager status: NOT_FOUND: Could not find registered platform with name: "TPU". Available platform names are: Host +2025-07-02 13:24:46.738456: I external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:78] No TPU platform found. Platform manager status: OK +WARNING: All log messages before absl::InitializeLog() is called are written to STDERR +I0000 00:00:1751462687.288045 16166 grpc_server_lib.cc:463] Started server with target: grpc://localhost:48934 +✓ TPU initialization successful! +TPU devices found: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU_SYSTEM:0', device_type='TPU_SYSTEM')] +Number of TPU cores: 4 +WARNING:2025-07-02 13:25:03,730:jax._src.xla_bridge:798: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu. 
+PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bfloat16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_float64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_bool_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_uint16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex128_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_complex64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_complex128 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_int64 
PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_float64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_bfloat16 
PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_int16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_uint64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_int8_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_float32 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_none_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_none PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bfloat16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_bool W0000 00:00:1751462704.562413 16546 remote_tensor_handle_data.cc:80] Unable to destroy remote tensor handles. If you are running a tf.function, it usually indicates some op in the graph gets an error: Value for attr 'T' of bool is not in the list of allowed values: bfloat16, half, float, double, uint8, uint16, uint32, uint64, int8, int16, int32, int64, complex64, complex128 + ; NodeDef: {{node AddV2}}; Op z:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128]; is_commutative=true; is_aggregate=true> +PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_int64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_bool_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_uint64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex128_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_complex64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_complex128 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_int16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_uint16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_float64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_bool PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_float64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_uint16 
PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_int8_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_none_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_bool PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_float64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_none PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_on_tpu_uint8_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_bfloat16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint16_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_int16 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint32_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_uint64 PASSED 
+dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint64_uint8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_bfloat16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_bool PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_complex128 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_complex64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_float16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_float32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_float64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_int16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_int32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_int64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_int8 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_none PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_uint16 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_uint32 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_uint64 PASSED +dtypes_test_tpu.py::DtypesTPUTest::test_result_type_with_tensor_uint8_uint8 PASSED + +================================================================================= 512 passed in 23.21s ================================================================================= \ No newline at end of file diff --git a/keras/src/backend/common/dtypes_test_TPU.py b/keras/src/backend/common/dtypes_test_TPU.py new file mode 100644 index 000000000000..16a5091fe127 --- /dev/null +++ b/keras/src/backend/common/dtypes_test_TPU.py @@ -0,0 +1,159 @@ +import tensorflow as tf +from unittest.mock import patch +import os + +from absl.testing import 
parameterized + +from keras.src import backend +from keras.src import ops +from keras.src.backend.common import dtypes +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + + +os.environ['TPU_NAME'] = 'harshith-tf-4' +os.environ['JAX_PLATFORMS'] = '' + +class DtypesTPUTest(test_case.TestCase): + """Test the dtype to verify that the behavior matches JAX, with TPU support.""" + + # Configuration for TPU retry logic + TPU_MAX_RETRIES = int(os.environ.get('TPU_MAX_RETRIES', '3')) + TPU_BASE_DELAY = float(os.environ.get('TPU_BASE_DELAY', '2.0')) + + if backend.backend() == "torch": + from keras.src.backend.torch.core import to_torch_dtype + + # TODO: torch doesn't support uint64. + ALL_DTYPES = [] + for x in dtypes.ALLOWED_DTYPES: + if x not in ["string", "uint64"]: + x = str(to_torch_dtype(x)).split(".")[-1] + if x not in ALL_DTYPES: # skip duplicates created by remapping + ALL_DTYPES.append(x) + ALL_DTYPES += [None] + elif backend.backend() == "openvino": + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x not in ["string", "complex64", "complex128"] + ] + [None] + else: + ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [ + None + ] + # Remove float8 dtypes for the following tests + ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES] + + @classmethod + def _cleanup_tpu_state(cls): + """Clean up any partial TPU initialization state.""" + try: + tf.config.experimental_disconnect_from_cluster() + except: + pass + + try: + tf.config.experimental_reset_memory_stats('TPU_SYSTEM') + except: + pass + + @classmethod + def setUpClass(cls): + """Initialize TPU if available, with retry logic.""" + import time + + super().setUpClass() + cls.tpu_available = False + cls.tpu_strategy = None + + max_retries = cls.TPU_MAX_RETRIES + base_delay = cls.TPU_BASE_DELAY + + for attempt in range(max_retries): + try: + print(f"TPU initialization attempt {attempt + 1}/{max_retries}") + + 
cls._cleanup_tpu_state() + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + + tpu_devices = tf.config.list_logical_devices('TPU_SYSTEM') + if not tpu_devices: + raise RuntimeError("No TPU devices found after initialization") + + cls.tpu_strategy = tf.distribute.TPUStrategy(resolver) + cls.tpu_available = True + + print("✓ TPU initialization successful!") + print("TPU devices found: ", tpu_devices) + print(f"Number of TPU cores: {cls.tpu_strategy.num_replicas_in_sync}") + break + + except (ValueError, RuntimeError, Exception) as e: + print(f"✗ TPU initialization attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + delay = base_delay * (2 ** attempt) + (attempt * 0.5) + print(f"Retrying in {delay:.1f} seconds...") + time.sleep(delay) + cls._cleanup_tpu_state() + else: + print("All TPU initialization attempts failed. Falling back to CPU/GPU testing") + cls.tpu_available = False + + def setUp(self): + from jax.experimental import enable_x64 + + self.jax_enable_x64 = enable_x64() + self.jax_enable_x64.__enter__() + return super().setUp() + + def tearDown(self): + self.jax_enable_x64.__exit__(None, None, None) + return super().tearDown() + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) + ) + def test_result_type_with_tensor(self, dtype1, dtype2): + import jax.numpy as jnp + + x1 = ops.ones((1,), dtype=dtype1) + x2 = ops.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + + out = backend.result_type(x1.dtype, x2.dtype) + expected = jnp.result_type(x1_jax, x2_jax).name + self.assertEqual(out, expected) + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) + ) + def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): + """Test dtype result_type behavior specifically on TPU.""" + if not 
self.tpu_available: + self.skipTest("TPU not available") + + import jax.numpy as jnp + + def _test_on_tpu(): + x1 = ops.ones((1,), dtype=dtype1) + x2 = ops.ones((1,), dtype=dtype2) + + result = ops.add(x1, x2) + + out = backend.result_type(x1.dtype, x2.dtype) + return out, result.dtype + + with self.tpu_strategy.scope(): + out, result_dtype = _test_on_tpu() + + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected = jnp.result_type(x1_jax, x2_jax).name + + self.assertEqual(out, expected) + self.assertEqual(result_dtype, expected) \ No newline at end of file From 306e6e76b981126abcab7617d93f6db1aa8c1e4b Mon Sep 17 00:00:00 2001 From: Harshith K Date: Wed, 2 Jul 2025 19:27:32 +0530 Subject: [PATCH 21/47] modified dtypes_test_tpu.py as per pre-commit standards --- keras/src/backend/common/dtypes_test_TPU.py | 36 +++++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/keras/src/backend/common/dtypes_test_TPU.py b/keras/src/backend/common/dtypes_test_TPU.py index 16a5091fe127..6944eb434104 100644 --- a/keras/src/backend/common/dtypes_test_TPU.py +++ b/keras/src/backend/common/dtypes_test_TPU.py @@ -1,7 +1,6 @@ -import tensorflow as tf -from unittest.mock import patch import os +import tensorflow as tf from absl.testing import parameterized from keras.src import backend @@ -10,16 +9,17 @@ from keras.src.testing import test_case from keras.src.testing.test_utils import named_product +os.environ["TPU_NAME"] = "harshith-tf-4" +os.environ["JAX_PLATFORMS"] = "" -os.environ['TPU_NAME'] = 'harshith-tf-4' -os.environ['JAX_PLATFORMS'] = '' class DtypesTPUTest(test_case.TestCase): - """Test the dtype to verify that the behavior matches JAX, with TPU support.""" + """Test the dtype to verify that the behavior matches + JAX, with TPU support.""" # Configuration for TPU retry logic - TPU_MAX_RETRIES = int(os.environ.get('TPU_MAX_RETRIES', '3')) - TPU_BASE_DELAY = float(os.environ.get('TPU_BASE_DELAY', '2.0')) + 
TPU_MAX_RETRIES = int(os.environ.get("TPU_MAX_RETRIES", "3")) + TPU_BASE_DELAY = float(os.environ.get("TPU_BASE_DELAY", "2.0")) if backend.backend() == "torch": from keras.src.backend.torch.core import to_torch_dtype @@ -54,7 +54,7 @@ def _cleanup_tpu_state(cls): pass try: - tf.config.experimental_reset_memory_stats('TPU_SYSTEM') + tf.config.experimental_reset_memory_stats("TPU_SYSTEM") except: pass @@ -80,27 +80,35 @@ def setUpClass(cls): tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) - tpu_devices = tf.config.list_logical_devices('TPU_SYSTEM') + tpu_devices = tf.config.list_logical_devices("TPU_SYSTEM") if not tpu_devices: - raise RuntimeError("No TPU devices found after initialization") + raise RuntimeError( + "No TPU devices found after initialization" + ) cls.tpu_strategy = tf.distribute.TPUStrategy(resolver) cls.tpu_available = True print("✓ TPU initialization successful!") print("TPU devices found: ", tpu_devices) - print(f"Number of TPU cores: {cls.tpu_strategy.num_replicas_in_sync}") + print( + f"Number of TPU cores: \ + {cls.tpu_strategy.num_replicas_in_sync}" + ) break except (ValueError, RuntimeError, Exception) as e: print(f"✗ TPU initialization attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: - delay = base_delay * (2 ** attempt) + (attempt * 0.5) + delay = base_delay * (2**attempt) + (attempt * 0.5) print(f"Retrying in {delay:.1f} seconds...") time.sleep(delay) cls._cleanup_tpu_state() else: - print("All TPU initialization attempts failed. Falling back to CPU/GPU testing") + print( + "All TPU initialization attempts failed. 
\ + Falling back to CPU/GPU testing" + ) cls.tpu_available = False def setUp(self): @@ -156,4 +164,4 @@ def _test_on_tpu(): expected = jnp.result_type(x1_jax, x2_jax).name self.assertEqual(out, expected) - self.assertEqual(result_dtype, expected) \ No newline at end of file + self.assertEqual(result_dtype, expected) From 4e584fc8639daa714f1cd0d0f7c1edfd519e73f4 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Thu, 3 Jul 2025 14:55:15 +0530 Subject: [PATCH 22/47] Added TPU initiaization and teardown functionalities in conftest.py, developed dtypes_new_test.py to use requires_tpu marker --- conftest.py | 67 +++++ keras/src/backend/common/dtypes_new_test.py | 277 ++++++++++++++++++++ keras/src/backend/common/dtypes_test.py | 2 +- keras/src/backend/common/dtypes_test_TPU.py | 78 +----- test_files.txt | 244 +++++++++++++++++ 5 files changed, 592 insertions(+), 76 deletions(-) create mode 100644 keras/src/backend/common/dtypes_new_test.py create mode 100644 test_files.txt diff --git a/conftest.py b/conftest.py index cf7ff5599db4..83231954c58d 100644 --- a/conftest.py +++ b/conftest.py @@ -22,6 +22,9 @@ def pytest_configure(config): "markers", "requires_trainable_backend: mark test for trainable backend only", ) + config.addinivalue_line( + "markers", "requires_tpu: mark test to run only on TPU" + ) def pytest_collection_modifyitems(config, items): @@ -61,3 +64,67 @@ def skip_if_backend(given_backend, reason): return pytest.mark.skipif(backend() == given_backend, reason=reason) + + +def _cleanup_tpu_state(): + import tensorflow as tf + + try: + tf.config.experimental_disconnect_from_cluster() + except: + pass + + try: + tf.config.experimental_reset_memory_stats("TPU_SYSTEM") + except: + pass + + +@pytest.fixture(scope="session") +def tpu_strategy_fixture(): + import tensorflow as tf + import time + + os.environ["TPU_NAME"] = "harshith-tf-4" + os.environ["JAX_PLATFORMS"] = "" + max_retries = int(os.environ.get("TPU_MAX_RETRIES", "3")) + base_delay = 
float(os.environ.get("TPU_BASE_DELAY", "2.0")) + tpu_available = False + strategy = None + + for attempt in range(max_retries): + try: + print(f"TPU initialization attempt {attempt + 1}/{max_retries}") + _cleanup_tpu_state() + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + tpu_available = True + print("✓ TPU initialization successful!") + break + except (ValueError, RuntimeError, Exception) as e: + print(f"✗ TPU initialization attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + delay = base_delay * (2**attempt) + (attempt * 0.5) + print(f"Retrying in {delay:.1f} seconds...") + time.sleep(delay) + _cleanup_tpu_state() + else: + print("All TPU initialization attempts failed.") + + if not tpu_available: + pytest.skip("TPU not available") + + yield strategy + + # Teardown + _cleanup_tpu_state() + + +@pytest.fixture(autouse=True) +def tpu(request): + marker = request.node.get_closest_marker("requires_tpu") + if marker: + strategy = request.getfixturevalue("tpu_strategy_fixture") + request.node.cls.tpu_strategy = strategy \ No newline at end of file diff --git a/keras/src/backend/common/dtypes_new_test.py b/keras/src/backend/common/dtypes_new_test.py new file mode 100644 index 000000000000..00acfc59df3e --- /dev/null +++ b/keras/src/backend/common/dtypes_new_test.py @@ -0,0 +1,277 @@ +from unittest.mock import patch + +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src.backend.common import dtypes +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + +os.environ["TPU_NAME"] = "harshith-tf-4" +os.environ["JAX_PLATFORMS"] = "" + +@pytest.mark.requires_tpu +class DtypesTest(test_case.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + if 
backend.backend() == "torch": + from keras.src.backend.torch.core import to_torch_dtype + + # TODO: torch doesn't support uint64. + ALL_DTYPES = [] + for x in dtypes.ALLOWED_DTYPES: + if x not in ["string", "uint64"]: + x = str(to_torch_dtype(x)).split(".")[-1] + if x not in ALL_DTYPES: # skip duplicates created by remapping + ALL_DTYPES.append(x) + ALL_DTYPES += [None] + elif backend.backend() == "openvino": + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x not in ["string", "complex64", "complex128"] + ] + [None] + else: + ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [ + None + ] + # Remove float8 dtypes for the following tests + ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES] + + def setUp(self): + from jax.experimental import enable_x64 + + self.jax_enable_x64 = enable_x64() + self.jax_enable_x64.__enter__() + return super().setUp() + + def tearDown(self): + self.jax_enable_x64.__exit__(None, None, None) + return super().tearDown() + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float]) + ) + def test_result_type_with_python_scalar_types(self, dtype1, dtype2): + import jax.numpy as jnp + + out = backend.result_type(dtype1, dtype2) + expected = jnp.result_type(dtype1, dtype2).name + self.assertEqual(out, expected) + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) + ) + def test_result_type_with_tensor(self, dtype1, dtype2): + import jax.numpy as jnp + + x1 = ops.ones((1,), dtype=dtype1) + x2 = ops.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + + out = backend.result_type(x1.dtype, x2.dtype) + expected = jnp.result_type(x1_jax, x2_jax).name + self.assertEqual(out, expected) + + def test_result_type_with_none(self): + import jax.numpy as jnp + + self.assertEqual(backend.result_type(None), jnp.result_type(None).name) + + def test_result_type_empty_list(self): + 
self.assertEqual(backend.result_type(), "float32") + + def test_respect_weak_type_for_bool(self): + self.assertEqual(dtypes._respect_weak_type("bool", True), "bool") + + def test_respect_weak_type_for_int(self): + self.assertEqual(dtypes._respect_weak_type("int32", True), "int") + + def test_respect_weak_type_for_float(self): + self.assertEqual(dtypes._respect_weak_type("float32", True), "float") + + def test_resolve_weak_type_for_bfloat16(self): + self.assertEqual(dtypes._resolve_weak_type("bfloat16"), "float32") + + def test_resolve_weak_type_for_bfloat16_with_precision(self): + self.assertEqual( + dtypes._resolve_weak_type("bfloat16", precision="64"), "float64" + ) + + def test_respect_weak_type_for_complex64(self): + self.assertAllEqual( + dtypes._respect_weak_type("complex64", True), "complex" + ) + + def test_respect_weak_type_for_complex128(self): + self.assertAllEqual( + dtypes._respect_weak_type("complex128", True), "complex" + ) + + def test_invalid_dtype_for_keras_promotion(self): + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion." + ): + dtypes._least_upper_bound("invalid_dtype") + + def test_resolve_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._resolve_weak_type("invalid_dtype") + + def test_resolve_weak_type_for_invalid_precision(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `precision`. 
Expected one of", + ): + dtypes._resolve_weak_type("int32", precision="invalid_precision") + + def test_cycle_detection_in_make_lattice_upper_bounds(self): + original_lattice_function = dtypes._type_promotion_lattice + + def mock_lattice(): + lattice = original_lattice_function() + lattice["int32"].append("float32") + lattice["float32"].append("int32") + return lattice + + dtypes._type_promotion_lattice = mock_lattice + + with self.assertRaisesRegex( + ValueError, "cycle detected in type promotion lattice for node" + ): + dtypes._make_lattice_upper_bounds() + + dtypes._type_promotion_lattice = original_lattice_function + + def test_respect_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._respect_weak_type("invalid_dtype", True) + + def test_invalid_dtype_in_least_upper_bound(self): + invalid_dtype = "non_existent_dtype" + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion" + ): + dtypes._least_upper_bound(invalid_dtype) + + def test_empty_lub_in_least_upper_bound(self): + dtype1 = "float32" + dtype2 = "int32" + with patch.dict( + dtypes.LATTICE_UPPER_BOUNDS, + {"float32": set(), "int32": set()}, + clear=True, + ): + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound(dtype1, dtype2) + + def test_valid_dtype_leading_to_single_lub_element(self): + self.assertEqual( + dtypes._least_upper_bound("float32", "int32"), "float32" + ) + + def test_valid_dtype_leading_to_keyerror_and_valueerror(self): + invalid_dtype = "non_existent_dtype" + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion" + ): + dtypes._least_upper_bound(invalid_dtype) + + def test_resolve_weak_type_bool(self): + self.assertEqual(dtypes._resolve_weak_type("bool"), "bool") + + def test_resolve_weak_type_int(self): + self.assertEqual( + 
dtypes._resolve_weak_type("int32", precision="32"), "int32" + ) + self.assertEqual( + dtypes._resolve_weak_type("int64", precision="64"), "int64" + ) + + def test_resolve_weak_type_uint(self): + self.assertEqual( + dtypes._resolve_weak_type("uint32", precision="32"), "uint32" + ) + self.assertEqual( + dtypes._resolve_weak_type("uint64", precision="64"), "uint64" + ) + + def test_resolve_weak_type_float(self): + self.assertEqual( + dtypes._resolve_weak_type("float32", precision="32"), "float32" + ) + self.assertEqual( + dtypes._resolve_weak_type("float64", precision="64"), "float64" + ) + + def test_least_upper_bound_ensure_order_independence(self): + # Test to ensure _least_upper_bound is order-independent. + result1 = dtypes._least_upper_bound("float32", "int32") + result2 = dtypes._least_upper_bound("int32", "float32") + self.assertEqual(result1, result2) + + def test_least_upper_bound_single_element(self): + dtypes.LATTICE_UPPER_BOUNDS["test_dtype"] = {"test_dtype"} + self.assertEqual(dtypes._least_upper_bound("test_dtype"), "test_dtype") + + def test_least_upper_bound_no_element(self): + dtypes.LATTICE_UPPER_BOUNDS["test_dtype"] = set() + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound("test_dtype") + + def test_least_upper_bound_with_no_common_upper_bound(self): + with patch.dict( + dtypes.LATTICE_UPPER_BOUNDS, + {"test_dtype1": set(), "test_dtype2": set()}, + clear=True, + ): + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound("test_dtype1", "test_dtype2") + + def test_invalid_float8_dtype(self): + with self.assertRaisesRegex( + ValueError, "There is no implicit conversions from float8 dtypes" + ): + dtypes.result_type("float8_e4m3fn", "bfloat16") + with self.assertRaisesRegex( + ValueError, "There is no implicit conversions from float8 dtypes" + ): + dtypes.result_type("float8_e5m2", "bfloat16") + + 
@parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) + ) + def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): + """Test dtype result_type behavior specifically on TPU.""" + import jax.numpy as jnp + + def _test_on_tpu(): + x1 = ops.ones((1,), dtype=dtype1) + x2 = ops.ones((1,), dtype=dtype2) + + result = ops.add(x1, x2) + + out = backend.result_type(x1.dtype, x2.dtype) + return out, result.dtype + + with self.tpu_strategy.scope(): + out, result_dtype = _test_on_tpu() + + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected = jnp.result_type(x1_jax, x2_jax).name + + self.assertEqual(out, expected) + self.assertEqual(result_dtype, expected) \ No newline at end of file diff --git a/keras/src/backend/common/dtypes_test.py b/keras/src/backend/common/dtypes_test.py index 0f1ca48b6413..8e34e7b03f79 100644 --- a/keras/src/backend/common/dtypes_test.py +++ b/keras/src/backend/common/dtypes_test.py @@ -387,4 +387,4 @@ def test_invalid_float8_dtype(self): with self.assertRaisesRegex( ValueError, "There is no implicit conversions from float8 dtypes" ): - dtypes.result_type("float8_e5m2", "bfloat16") + dtypes.result_type("float8_e5m2", "bfloat16") \ No newline at end of file diff --git a/keras/src/backend/common/dtypes_test_TPU.py b/keras/src/backend/common/dtypes_test_TPU.py index 6944eb434104..6a89462a8621 100644 --- a/keras/src/backend/common/dtypes_test_TPU.py +++ b/keras/src/backend/common/dtypes_test_TPU.py @@ -1,6 +1,6 @@ import os -import tensorflow as tf +import pytest from absl.testing import parameterized from keras.src import backend @@ -13,14 +13,11 @@ os.environ["JAX_PLATFORMS"] = "" +@pytest.mark.requires_tpu class DtypesTPUTest(test_case.TestCase): """Test the dtype to verify that the behavior matches JAX, with TPU support.""" - # Configuration for TPU retry logic - TPU_MAX_RETRIES = int(os.environ.get("TPU_MAX_RETRIES", "3")) - TPU_BASE_DELAY = 
float(os.environ.get("TPU_BASE_DELAY", "2.0")) - if backend.backend() == "torch": from keras.src.backend.torch.core import to_torch_dtype @@ -45,72 +42,6 @@ class DtypesTPUTest(test_case.TestCase): # Remove float8 dtypes for the following tests ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES] - @classmethod - def _cleanup_tpu_state(cls): - """Clean up any partial TPU initialization state.""" - try: - tf.config.experimental_disconnect_from_cluster() - except: - pass - - try: - tf.config.experimental_reset_memory_stats("TPU_SYSTEM") - except: - pass - - @classmethod - def setUpClass(cls): - """Initialize TPU if available, with retry logic.""" - import time - - super().setUpClass() - cls.tpu_available = False - cls.tpu_strategy = None - - max_retries = cls.TPU_MAX_RETRIES - base_delay = cls.TPU_BASE_DELAY - - for attempt in range(max_retries): - try: - print(f"TPU initialization attempt {attempt + 1}/{max_retries}") - - cls._cleanup_tpu_state() - - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - tf.tpu.experimental.initialize_tpu_system(resolver) - - tpu_devices = tf.config.list_logical_devices("TPU_SYSTEM") - if not tpu_devices: - raise RuntimeError( - "No TPU devices found after initialization" - ) - - cls.tpu_strategy = tf.distribute.TPUStrategy(resolver) - cls.tpu_available = True - - print("✓ TPU initialization successful!") - print("TPU devices found: ", tpu_devices) - print( - f"Number of TPU cores: \ - {cls.tpu_strategy.num_replicas_in_sync}" - ) - break - - except (ValueError, RuntimeError, Exception) as e: - print(f"✗ TPU initialization attempt {attempt + 1} failed: {e}") - if attempt < max_retries - 1: - delay = base_delay * (2**attempt) + (attempt * 0.5) - print(f"Retrying in {delay:.1f} seconds...") - time.sleep(delay) - cls._cleanup_tpu_state() - else: - print( - "All TPU initialization attempts failed. 
\ - Falling back to CPU/GPU testing" - ) - cls.tpu_available = False - def setUp(self): from jax.experimental import enable_x64 @@ -142,9 +73,6 @@ def test_result_type_with_tensor(self, dtype1, dtype2): ) def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): """Test dtype result_type behavior specifically on TPU.""" - if not self.tpu_available: - self.skipTest("TPU not available") - import jax.numpy as jnp def _test_on_tpu(): @@ -164,4 +92,4 @@ def _test_on_tpu(): expected = jnp.result_type(x1_jax, x2_jax).name self.assertEqual(out, expected) - self.assertEqual(result_dtype, expected) + self.assertEqual(result_dtype, expected) \ No newline at end of file diff --git a/test_files.txt b/test_files.txt new file mode 100644 index 000000000000..cdeb78958b56 --- /dev/null +++ b/test_files.txt @@ -0,0 +1,244 @@ +/Users/kharshith/keras/keras/src/backend/common/dtypes_test.py +/Users/kharshith/keras/keras/src/ops/math_test.py +/Users/kharshith/keras/keras/src/trainers/trainer_test.py +/Users/kharshith/keras/keras/src/trainers/epoch_iterator_test.py +/Users/kharshith/keras/keras/src/backend/tensorflow/distribute_test.py +/Users/kharshith/keras/keras/src/wrappers/sklearn_test.py +/Users/kharshith/keras/keras/src/utils/tracking_test.py +/Users/kharshith/keras/keras/src/utils/torch_utils_test.py +/Users/kharshith/keras/keras/src/utils/timeseries_dataset_utils_test.py +/Users/kharshith/keras/keras/src/utils/text_dataset_utils_test.py +/Users/kharshith/keras/keras/src/utils/summary_utils_test.py +/Users/kharshith/keras/keras/src/utils/sequence_utils_test.py +/Users/kharshith/keras/keras/src/utils/rng_utils_test.py +/Users/kharshith/keras/keras/src/utils/python_utils_test.py +/Users/kharshith/keras/keras/src/utils/numerical_utils_test.py +/Users/kharshith/keras/keras/src/utils/naming_test.py +/Users/kharshith/keras/keras/src/utils/jax_layer_test.py +/Users/kharshith/keras/keras/src/utils/io_utils_test.py +/Users/kharshith/keras/keras/src/utils/image_dataset_utils_test.py 
+/Users/kharshith/keras/keras/src/utils/file_utils_test.py +/Users/kharshith/keras/keras/src/utils/dtype_utils_test.py +/Users/kharshith/keras/keras/src/utils/dataset_utils_test.py +/Users/kharshith/keras/keras/src/utils/code_stats_test.py +/Users/kharshith/keras/keras/src/utils/backend_utils_test.py +/Users/kharshith/keras/keras/src/utils/audio_dataset_utils_test.py +/Users/kharshith/keras/keras/src/tree/tree_test.py +/Users/kharshith/keras/keras/src/layers/convolutional/conv_transpose_test.py +/Users/kharshith/keras/keras/src/layers/convolutional/conv_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/py_dataset_adapter_test.py +/Users/kharshith/keras/keras/src/layers/attention/multi_head_attention_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/generator_data_adapter_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/data_adapter_utils_test.py +/Users/kharshith/keras/keras/src/layers/attention/grouped_query_attention_test.py +/Users/kharshith/keras/keras/src/layers/attention/attention_test.py +/Users/kharshith/keras/keras/src/trainers/data_adapters/array_data_adapter_test.py +/Users/kharshith/keras/keras/src/layers/attention/additive_attention_test.py +/Users/kharshith/keras/keras/src/layers/activations/softmax_test.py +/Users/kharshith/keras/keras/src/trainers/compile_utils_test.py +/Users/kharshith/keras/keras/src/layers/activations/relu_test.py +/Users/kharshith/keras/keras/src/layers/activations/prelu_test.py +/Users/kharshith/keras/keras/src/testing/test_utils_test.py +/Users/kharshith/keras/keras/src/layers/activations/leaky_relu_test.py +/Users/kharshith/keras/keras/src/layers/activations/elu_test.py +/Users/kharshith/keras/keras/src/saving/serialization_lib_test.py +/Users/kharshith/keras/keras/src/layers/activations/activation_test.py 
+/Users/kharshith/keras/keras/src/saving/saving_lib_test.py +/Users/kharshith/keras/keras/src/initializers/random_initializers_test.py +/Users/kharshith/keras/keras/src/saving/saving_api_test.py +/Users/kharshith/keras/keras/src/initializers/constant_initializers_test.py +/Users/kharshith/keras/keras/src/saving/object_registration_test.py +/Users/kharshith/keras/keras/src/saving/file_editor_test.py +/Users/kharshith/keras/keras/src/export/tfsm_layer_test.py +/Users/kharshith/keras/keras/src/regularizers/regularizers_test.py +/Users/kharshith/keras/keras/src/export/saved_model_test.py +/Users/kharshith/keras/keras/src/random/seed_generator_test.py +/Users/kharshith/keras/keras/src/export/onnx_test.py +/Users/kharshith/keras/keras/src/random/random_test.py +/Users/kharshith/keras/keras/src/dtype_policies/dtype_policy_test.py +/Users/kharshith/keras/keras/src/quantizers/quantizers_test.py +/Users/kharshith/keras/keras/src/dtype_policies/dtype_policy_map_test.py +/Users/kharshith/keras/keras/src/optimizers/sgd_test.py +/Users/kharshith/keras/keras/src/distribution/distribution_lib_test.py +/Users/kharshith/keras/keras/src/optimizers/schedules/learning_rate_schedule_test.py +/Users/kharshith/keras/keras/src/optimizers/rmsprop_test.py +/Users/kharshith/keras/keras/src/optimizers/optimizer_test.py +/Users/kharshith/keras/keras/src/optimizers/optimizer_sparse_test.py +/Users/kharshith/keras/keras/src/optimizers/nadam_test.py +/Users/kharshith/keras/keras/src/optimizers/muon_test.py +/Users/kharshith/keras/keras/src/constraints/constraints_test.py +/Users/kharshith/keras/keras/src/optimizers/loss_scale_optimizer_test.py +/Users/kharshith/keras/keras/src/optimizers/lion_test.py +/Users/kharshith/keras/keras/src/callbacks/terminate_on_nan_test.py +/Users/kharshith/keras/keras/src/optimizers/lamb_test.py +/Users/kharshith/keras/keras/src/callbacks/tensorboard_test.py +/Users/kharshith/keras/keras/src/optimizers/ftrl_test.py 
+/Users/kharshith/keras/keras/src/callbacks/swap_ema_weights_test.py +/Users/kharshith/keras/keras/src/callbacks/remote_monitor_test.py +/Users/kharshith/keras/keras/src/optimizers/adamw_test.py +/Users/kharshith/keras/keras/src/callbacks/reduce_lr_on_plateau_test.py +/Users/kharshith/keras/keras/src/optimizers/adamax_test.py +/Users/kharshith/keras/keras/src/callbacks/monitor_callback_test.py +/Users/kharshith/keras/keras/src/optimizers/adam_test.py +/Users/kharshith/keras/keras/src/callbacks/model_checkpoint_test.py +/Users/kharshith/keras/keras/src/optimizers/adagrad_test.py +/Users/kharshith/keras/keras/src/callbacks/learning_rate_scheduler_test.py +/Users/kharshith/keras/keras/src/optimizers/adafactor_test.py +/Users/kharshith/keras/keras/src/callbacks/lambda_callback_test.py +/Users/kharshith/keras/keras/src/optimizers/adadelta_test.py +/Users/kharshith/keras/keras/src/callbacks/early_stopping_test.py +/Users/kharshith/keras/keras/src/ops/symbolic_arguments_test.py +/Users/kharshith/keras/keras/src/ops/ops_test.py +/Users/kharshith/keras/keras/src/callbacks/csv_logger_test.py +/Users/kharshith/keras/keras/src/ops/operation_utils_test.py +/Users/kharshith/keras/keras/src/callbacks/callback_test.py +/Users/kharshith/keras/keras/src/ops/operation_test.py +/Users/kharshith/keras/keras/src/ops/numpy_test.py +/Users/kharshith/keras/keras/src/callbacks/backup_and_restore_test.py +/Users/kharshith/keras/keras/src/ops/node_test.py +/Users/kharshith/keras/keras/src/ops/nn_test.py +/Users/kharshith/keras/keras/src/ops/linalg_test.py +/Users/kharshith/keras/keras/src/ops/image_test.py +/Users/kharshith/keras/keras/src/ops/function_test.py +/Users/kharshith/keras/keras/src/ops/einops_test.py +/Users/kharshith/keras/keras/src/ops/core_test.py +/Users/kharshith/keras/keras/src/backend/tests/device_scope_test.py +/Users/kharshith/keras/keras/src/backend/tests/compute_output_spec_test.py +/Users/kharshith/keras/keras/src/models/variable_mapping_test.py 
+/Users/kharshith/keras/keras/src/models/sequential_test.py +/Users/kharshith/keras/keras/src/backend/tensorflow/saved_model_test.py +/Users/kharshith/keras/keras/src/models/model_test.py +/Users/kharshith/keras/keras/src/models/functional_test.py +/Users/kharshith/keras/keras/src/backend/tensorflow/optimizer_distribute_test.py +/Users/kharshith/keras/keras/src/models/cloning_test.py +/Users/kharshith/keras/keras/src/backend/tensorflow/name_scope_test.py +/Users/kharshith/keras/keras/src/metrics/regression_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/reduction_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/probabilistic_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/metric_test.py +/Users/kharshith/keras/keras/src/metrics/iou_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/hinge_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/f_score_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/correlation_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/confusion_metrics_test.py +/Users/kharshith/keras/keras/src/metrics/accuracy_metrics_test.py +/Users/kharshith/keras/keras/src/losses/losses_test.py +/Users/kharshith/keras/keras/src/losses/loss_test.py +/Users/kharshith/keras/keras/src/legacy/saving/legacy_h5_format_test.py +/Users/kharshith/keras/keras/src/legacy/saving/json_utils_test.py +/Users/kharshith/keras/keras/src/backend/jax/distribution_lib_test.py +/Users/kharshith/keras/keras/src/layers/rnn/time_distributed_test.py +/Users/kharshith/keras/keras/src/layers/rnn/stacked_rnn_cells_test.py +/Users/kharshith/keras/keras/src/backend/common/variables_test.py +/Users/kharshith/keras/keras/src/layers/rnn/simple_rnn_test.py +/Users/kharshith/keras/keras/src/backend/common/thread_safe_test.py +/Users/kharshith/keras/keras/src/layers/rnn/rnn_test.py +/Users/kharshith/keras/keras/src/backend/common/symbolic_scope_test.py +/Users/kharshith/keras/keras/src/backend/common/stateless_scope_test.py 
+/Users/kharshith/keras/keras/src/layers/rnn/lstm_test.py +/Users/kharshith/keras/keras/src/backend/common/remat_test.py +/Users/kharshith/keras/keras/src/layers/rnn/gru_test.py +/Users/kharshith/keras/keras/src/backend/common/name_scope_test.py +/Users/kharshith/keras/keras/src/layers/rnn/dropout_rnn_cell_test.py +/Users/kharshith/keras/keras/src/backend/common/masking_test.py +/Users/kharshith/keras/keras/src/backend/common/keras_tensor_test.py +/Users/kharshith/keras/keras/src/layers/rnn/conv_lstm_test.py +/Users/kharshith/keras/keras/src/layers/rnn/conv_lstm3d_test.py +/Users/kharshith/keras/keras/src/backend/common/global_state_test.py +/Users/kharshith/keras/keras/src/layers/rnn/conv_lstm2d_test.py +/Users/kharshith/keras/keras/src/layers/rnn/conv_lstm1d_test.py +/Users/kharshith/keras/keras/src/backend/common/compute_output_spec_test.py +/Users/kharshith/keras/keras/src/backend/common/backend_utils_test.py +/Users/kharshith/keras/keras/src/layers/rnn/bidirectional_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/zero_padding3d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/zero_padding2d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/zero_padding1d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/up_sampling3d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/up_sampling2d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/up_sampling1d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/reshape_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/repeat_vector_test.py +/Users/kharshith/keras/keras/src/applications/imagenet_utils_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/permute_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/flatten_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/cropping3d_test.py +/Users/kharshith/keras/keras/src/layers/reshaping/cropping2d_test.py +/Users/kharshith/keras/keras/src/applications/applications_test.py 
+/Users/kharshith/keras/keras/src/layers/reshaping/cropping1d_test.py +/Users/kharshith/keras/keras/src/activations/activations_test.py +/Users/kharshith/keras/keras/src/layers/regularization/spatial_dropout_test.py +/Users/kharshith/keras/keras/src/layers/regularization/gaussian_noise_test.py +/Users/kharshith/keras/keras/src/layers/regularization/gaussian_dropout_test.py +/Users/kharshith/keras/keras/src/layers/regularization/dropout_test.py +/Users/kharshith/keras/keras/src/layers/regularization/alpha_dropout_test.py +/Users/kharshith/keras/keras/src/layers/regularization/activity_regularization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/text_vectorization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/string_lookup_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/stft_spectrogram_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/rescaling_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/pipeline_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/normalization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/mel_spectrogram_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/integer_lookup_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/index_lookup_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/solarization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py 
+/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/equalization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py 
+/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/hashing_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/hashed_crossing_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/feature_space_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/discretization_test.py +/Users/kharshith/keras/keras/src/layers/preprocessing/category_encoding_test.py +/Users/kharshith/keras/keras/src/layers/pooling/max_pooling_test.py +/Users/kharshith/keras/keras/src/layers/pooling/global_max_pooling_test.py +/Users/kharshith/keras/keras/src/layers/pooling/global_average_pooling_test.py +/Users/kharshith/keras/keras/src/layers/pooling/average_pooling_test.py +/Users/kharshith/keras/keras/src/layers/normalization/unit_normalization_test.py +/Users/kharshith/keras/keras/src/layers/normalization/spectral_normalization_test.py +/Users/kharshith/keras/keras/src/layers/normalization/rms_normalization_test.py +/Users/kharshith/keras/keras/src/layers/normalization/layer_normalization_test.py +/Users/kharshith/keras/keras/src/layers/normalization/group_normalization_test.py +/Users/kharshith/keras/keras/src/layers/normalization/batch_normalization_test.py +/Users/kharshith/keras/keras/src/layers/merging/merging_test.py +/Users/kharshith/keras/keras/src/layers/layer_test.py +/Users/kharshith/keras/keras/src/layers/core/wrapper_test.py +/Users/kharshith/keras/keras/src/layers/core/masking_test.py +/Users/kharshith/keras/keras/src/layers/core/lambda_layer_test.py 
+/Users/kharshith/keras/keras/src/layers/core/input_layer_test.py +/Users/kharshith/keras/keras/src/layers/core/identity_test.py +/Users/kharshith/keras/keras/src/layers/core/embedding_test.py +/Users/kharshith/keras/keras/src/layers/core/einsum_dense_test.py +/Users/kharshith/keras/keras/src/layers/core/dense_test.py +/Users/kharshith/keras/keras/src/layers/convolutional/separable_conv_test.py +/Users/kharshith/keras/keras/src/layers/convolutional/depthwise_conv_test.py \ No newline at end of file From bb09e955804816ee26e5cfa9d6965201e5a0b7aa Mon Sep 17 00:00:00 2001 From: Harshith K Date: Wed, 9 Jul 2025 20:59:45 +0530 Subject: [PATCH 23/47] Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py --- conftest.py | 5 +++-- keras/src/backend/common/dtypes_new_test.py | 19 +++++++------------ keras/src/backend/common/dtypes_test.py | 2 +- keras/src/backend/common/dtypes_test_TPU.py | 3 ++- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/conftest.py b/conftest.py index 83231954c58d..43682b9eed77 100644 --- a/conftest.py +++ b/conftest.py @@ -82,9 +82,10 @@ def _cleanup_tpu_state(): @pytest.fixture(scope="session") def tpu_strategy_fixture(): - import tensorflow as tf import time + import tensorflow as tf + os.environ["TPU_NAME"] = "harshith-tf-4" os.environ["JAX_PLATFORMS"] = "" max_retries = int(os.environ.get("TPU_MAX_RETRIES", "3")) @@ -127,4 +128,4 @@ def tpu(request): marker = request.node.get_closest_marker("requires_tpu") if marker: strategy = request.getfixturevalue("tpu_strategy_fixture") - request.node.cls.tpu_strategy = strategy \ No newline at end of file + request.node.cls.tpu_strategy = strategy diff --git a/keras/src/backend/common/dtypes_new_test.py b/keras/src/backend/common/dtypes_new_test.py index 00acfc59df3e..afd88c9a4a97 100644 --- a/keras/src/backend/common/dtypes_new_test.py +++ b/keras/src/backend/common/dtypes_new_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import pytest from absl.testing import 
parameterized from keras.src import backend @@ -8,8 +9,6 @@ from keras.src.testing import test_case from keras.src.testing.test_utils import named_product -os.environ["TPU_NAME"] = "harshith-tf-4" -os.environ["JAX_PLATFORMS"] = "" @pytest.mark.requires_tpu class DtypesTest(test_case.TestCase): @@ -40,15 +39,15 @@ class DtypesTest(test_case.TestCase): ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES] def setUp(self): + super().setUp() from jax.experimental import enable_x64 self.jax_enable_x64 = enable_x64() self.jax_enable_x64.__enter__() - return super().setUp() def tearDown(self): self.jax_enable_x64.__exit__(None, None, None) - return super().tearDown() + super().tearDown() @parameterized.named_parameters( named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float]) @@ -257,21 +256,17 @@ def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): """Test dtype result_type behavior specifically on TPU.""" import jax.numpy as jnp - def _test_on_tpu(): + with self.tpu_strategy.scope(): x1 = ops.ones((1,), dtype=dtype1) x2 = ops.ones((1,), dtype=dtype2) - result = ops.add(x1, x2) - out = backend.result_type(x1.dtype, x2.dtype) - return out, result.dtype - - with self.tpu_strategy.scope(): - out, result_dtype = _test_on_tpu() + result_dtype = result.dtype + self.assertIn("TPU", x1.device) x1_jax = jnp.ones((1,), dtype=dtype1) x2_jax = jnp.ones((1,), dtype=dtype2) expected = jnp.result_type(x1_jax, x2_jax).name self.assertEqual(out, expected) - self.assertEqual(result_dtype, expected) \ No newline at end of file + self.assertEqual(result_dtype, expected) diff --git a/keras/src/backend/common/dtypes_test.py b/keras/src/backend/common/dtypes_test.py index 8e34e7b03f79..0f1ca48b6413 100644 --- a/keras/src/backend/common/dtypes_test.py +++ b/keras/src/backend/common/dtypes_test.py @@ -387,4 +387,4 @@ def test_invalid_float8_dtype(self): with self.assertRaisesRegex( ValueError, "There is no implicit conversions from float8 dtypes" ): - 
dtypes.result_type("float8_e5m2", "bfloat16") \ No newline at end of file + dtypes.result_type("float8_e5m2", "bfloat16") diff --git a/keras/src/backend/common/dtypes_test_TPU.py b/keras/src/backend/common/dtypes_test_TPU.py index 6a89462a8621..c41e6bb84b36 100644 --- a/keras/src/backend/common/dtypes_test_TPU.py +++ b/keras/src/backend/common/dtypes_test_TPU.py @@ -78,6 +78,7 @@ def test_result_type_with_tensor_on_tpu(self, dtype1, dtype2): def _test_on_tpu(): x1 = ops.ones((1,), dtype=dtype1) x2 = ops.ones((1,), dtype=dtype2) + self.assertIn("TPU", x1.device) result = ops.add(x1, x2) @@ -92,4 +93,4 @@ def _test_on_tpu(): expected = jnp.result_type(x1_jax, x2_jax).name self.assertEqual(out, expected) - self.assertEqual(result_dtype, expected) \ No newline at end of file + self.assertEqual(result_dtype, expected) From 8a63d09a6d0f478e4e328104ced90560d6e42fa7 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Wed, 23 Jul 2025 14:50:19 +0530 Subject: [PATCH 24/47] Added Dockerfile and tests list command --- .../workflows/tpu/Dockerfile => Dockerfile | 0 list_tests.sh | 3 + tests_list.txt | 259 ++++++++++++++++++ 3 files changed, 262 insertions(+) rename .github/workflows/tpu/Dockerfile => Dockerfile (100%) create mode 100644 list_tests.sh create mode 100644 tests_list.txt diff --git a/.github/workflows/tpu/Dockerfile b/Dockerfile similarity index 100% rename from .github/workflows/tpu/Dockerfile rename to Dockerfile diff --git a/list_tests.sh b/list_tests.sh new file mode 100644 index 000000000000..e8093647b8c4 --- /dev/null +++ b/list_tests.sh @@ -0,0 +1,3 @@ +#! /bin/bash + +pytest --collect-only -q | sed 's/::.*//' | sort -u | xargs -I {} find . 
-name "{}" -print diff --git a/tests_list.txt b/tests_list.txt new file mode 100644 index 000000000000..3dc80dea9928 --- /dev/null +++ b/tests_list.txt @@ -0,0 +1,259 @@ +./keras/src/metrics/accuracy_metrics_test.py +./keras/src/layers/activations/activation_test.py +./keras/src/activations/activations_test.py +./keras/src/layers/regularization/activity_regularization_test.py +./keras/src/optimizers/adadelta_test.py +./keras/src/optimizers/adafactor_test.py +./keras/src/optimizers/adagrad_test.py +./keras/src/optimizers/adam_test.py +./keras/src/optimizers/adamax_test.py +./keras/src/optimizers/adamw_test.py +./keras/src/layers/attention/additive_attention_test.py +./keras/src/layers/regularization/alpha_dropout_test.py +./keras/src/applications/applications_test.py +./keras/src/trainers/data_adapters/array_data_adapter_test.py +./keras/src/layers/attention/attention_test.py +./keras/src/utils/audio_dataset_utils_test.py +./keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py +./keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py +./keras/src/layers/pooling/average_pooling_test.py +./keras/src/utils/backend_utils_test.py +./keras/src/backend/common/backend_utils_test.py +./keras/src/callbacks/backup_and_restore_test.py +./keras/src/layers/normalization/batch_normalization_test.py +./keras/src/layers/rnn/bidirectional_test.py +./integration_tests/dataset_tests/boston_housing_test.py +./integration_tests/dataset_tests/california_housing_test.py +./keras/src/callbacks/callback_test.py +./keras/src/layers/preprocessing/category_encoding_test.py +./keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py +./integration_tests/dataset_tests/cifar10_test.py +./integration_tests/dataset_tests/cifar100_test.py +./keras/src/models/cloning_test.py +./keras/src/utils/code_stats_test.py +./keras/src/trainers/compile_utils_test.py +./keras/src/backend/tests/compute_output_spec_test.py 
+./keras/src/backend/common/compute_output_spec_test.py +./keras/src/metrics/confusion_metrics_test.py +./keras/src/initializers/constant_initializers_test.py +./keras/src/constraints/constraints_test.py +./keras/src/layers/rnn/conv_lstm_test.py +./keras/src/layers/rnn/conv_lstm1d_test.py +./keras/src/layers/rnn/conv_lstm2d_test.py +./keras/src/layers/rnn/conv_lstm3d_test.py +./keras/src/layers/convolutional/conv_test.py +./keras/src/layers/convolutional/conv_transpose_test.py +./keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py +./keras/src/ops/core_test.py +./keras/src/metrics/correlation_metrics_test.py +./keras/src/layers/reshaping/cropping1d_test.py +./keras/src/layers/reshaping/cropping2d_test.py +./keras/src/layers/reshaping/cropping3d_test.py +./keras/src/callbacks/csv_logger_test.py +./keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py +./keras/src/trainers/data_adapters/data_adapter_utils_test.py +./keras/src/utils/dataset_utils_test.py +./keras/src/layers/core/dense_test.py +./keras/src/layers/convolutional/depthwise_conv_test.py +./keras/src/backend/tests/device_scope_test.py +./keras/src/layers/preprocessing/discretization_test.py +./keras/src/backend/tensorflow/distribute_test.py +./keras/src/distribution/distribution_lib_test.py +./keras/src/backend/jax/distribution_lib_test.py +./keras/src/layers/rnn/dropout_rnn_cell_test.py +./keras/src/layers/regularization/dropout_test.py +./keras/src/dtype_policies/dtype_policy_map_test.py +./keras/src/dtype_policies/dtype_policy_test.py +./keras/src/utils/dtype_utils_test.py +./keras/src/backend/common/dtypes_test.py +./keras/src/callbacks/early_stopping_test.py +./keras/src/ops/einops_test.py +./keras/src/layers/core/einsum_dense_test.py +./keras/src/layers/activations/elu_test.py +./keras/src/layers/core/embedding_test.py +./keras/src/trainers/epoch_iterator_test.py +./keras/src/layers/preprocessing/image_preprocessing/equalization_test.py 
+./keras/src/metrics/f_score_metrics_test.py +./integration_tests/dataset_tests/fashion_mnist_test.py +./keras/src/layers/preprocessing/feature_space_test.py +./keras/src/saving/file_editor_test.py +./keras/src/utils/file_utils_test.py +./keras/src/layers/reshaping/flatten_test.py +./keras/src/optimizers/ftrl_test.py +./keras/src/ops/function_test.py +./keras/src/models/functional_test.py +./keras/src/layers/regularization/gaussian_dropout_test.py +./keras/src/layers/regularization/gaussian_noise_test.py +./keras/src/trainers/data_adapters/generator_data_adapter_test.py +./keras/src/layers/pooling/global_average_pooling_test.py +./keras/src/layers/pooling/global_max_pooling_test.py +./keras/src/backend/common/global_state_test.py +./keras/src/layers/normalization/group_normalization_test.py +./keras/src/layers/attention/grouped_query_attention_test.py +./keras/src/layers/rnn/gru_test.py +./keras/src/layers/preprocessing/hashed_crossing_test.py +./keras/src/layers/preprocessing/hashing_test.py +./keras/src/metrics/hinge_metrics_test.py +./keras/src/layers/core/identity_test.py +./keras/src/utils/image_dataset_utils_test.py +./keras/src/ops/image_test.py +./keras/src/applications/imagenet_utils_test.py +./integration_tests/dataset_tests/imdb_test.py +./integration_tests/import_test.py +./keras/src/layers/preprocessing/index_lookup_test.py +./keras/src/layers/core/input_layer_test.py +./keras/src/layers/preprocessing/integer_lookup_test.py +./keras/src/utils/io_utils_test.py +./keras/src/metrics/iou_metrics_test.py +./keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py +./integration_tests/jax_custom_fit_test.py +./keras/src/utils/jax_layer_test.py +./keras/src/legacy/saving/json_utils_test.py +./keras/src/backend/common/keras_tensor_test.py +./keras/src/optimizers/lamb_test.py +./keras/src/callbacks/lambda_callback_test.py +./keras/src/layers/core/lambda_layer_test.py +./keras/src/layers/normalization/layer_normalization_test.py 
+./keras/src/layers/layer_test.py +./keras/src/layers/activations/leaky_relu_test.py +./keras/src/optimizers/schedules/learning_rate_schedule_test.py +./keras/src/callbacks/learning_rate_scheduler_test.py +./keras/src/legacy/saving/legacy_h5_format_test.py +./keras/src/ops/linalg_test.py +./keras/src/optimizers/lion_test.py +./keras/src/optimizers/loss_scale_optimizer_test.py +./keras/src/losses/loss_test.py +./keras/src/losses/losses_test.py +./keras/src/layers/rnn/lstm_test.py +./keras/src/layers/core/masking_test.py +./keras/src/backend/common/masking_test.py +./keras/src/ops/math_test.py +./keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py +./keras/src/layers/pooling/max_pooling_test.py +./keras/src/layers/preprocessing/mel_spectrogram_test.py +./keras/src/layers/merging/merging_test.py +./keras/src/metrics/metric_test.py +./keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py +./integration_tests/dataset_tests/mnist_test.py +./keras/src/callbacks/model_checkpoint_test.py +./keras/src/models/model_test.py +./integration_tests/model_visualization_test.py +./keras/src/callbacks/monitor_callback_test.py +./keras/src/layers/attention/multi_head_attention_test.py +./keras/src/optimizers/muon_test.py +./keras/src/optimizers/nadam_test.py +./keras/src/backend/common/name_scope_test.py +./keras/src/backend/tensorflow/name_scope_test.py +./keras/src/utils/naming_test.py +./keras/src/ops/nn_test.py +./keras/src/ops/node_test.py +./keras/src/layers/preprocessing/normalization_test.py +./keras/src/utils/numerical_utils_test.py +./keras/src/ops/numpy_test.py +./keras/src/saving/object_registration_test.py +./keras/src/export/onnx_test.py +./keras/src/ops/operation_test.py +./keras/src/ops/operation_utils_test.py +./keras/src/ops/ops_test.py +./keras/src/backend/tensorflow/optimizer_distribute_test.py +./keras/src/optimizers/optimizer_sparse_test.py +./keras/src/optimizers/optimizer_test.py 
+./keras/src/layers/reshaping/permute_test.py +./keras/src/layers/preprocessing/pipeline_test.py +./keras/src/layers/activations/prelu_test.py +./keras/src/metrics/probabilistic_metrics_test.py +./keras/src/trainers/data_adapters/py_dataset_adapter_test.py +./keras/src/utils/python_utils_test.py +./keras/src/quantizers/quantizers_test.py +./keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py +./keras/src/initializers/random_initializers_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py +./keras/src/random/random_test.py 
+./keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py +./keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py +./keras/src/callbacks/reduce_lr_on_plateau_test.py +./keras/src/metrics/reduction_metrics_test.py +./keras/src/metrics/regression_metrics_test.py +./keras/src/regularizers/regularizers_test.py +./keras/src/layers/activations/relu_test.py +./keras/src/backend/common/remat_test.py +./keras/src/callbacks/remote_monitor_test.py +./keras/src/layers/reshaping/repeat_vector_test.py +./keras/src/layers/preprocessing/rescaling_test.py +./keras/src/layers/reshaping/reshape_test.py +./keras/src/layers/preprocessing/image_preprocessing/resizing_test.py +./integration_tests/dataset_tests/reuters_test.py +./keras/src/layers/normalization/rms_normalization_test.py +./keras/src/optimizers/rmsprop_test.py +./keras/src/utils/rng_utils_test.py +./keras/src/layers/rnn/rnn_test.py +./keras/src/backend/tensorflow/saved_model_test.py +./keras/src/export/saved_model_test.py +./keras/src/saving/saving_api_test.py +./keras/src/saving/saving_lib_test.py +./keras/src/random/seed_generator_test.py +./keras/src/layers/convolutional/separable_conv_test.py +./keras/src/utils/sequence_utils_test.py +./keras/src/models/sequential_test.py +./keras/src/saving/serialization_lib_test.py +./keras/src/optimizers/sgd_test.py +./keras/src/layers/rnn/simple_rnn_test.py +./keras/src/wrappers/sklearn_test.py +./keras/src/layers/activations/softmax_test.py +./keras/src/layers/preprocessing/image_preprocessing/solarization_test.py +./keras/src/layers/regularization/spatial_dropout_test.py +./keras/src/layers/normalization/spectral_normalization_test.py +./keras/src/layers/rnn/stacked_rnn_cells_test.py +./keras/src/backend/common/stateless_scope_test.py +./keras/src/layers/preprocessing/stft_spectrogram_test.py +./keras/src/layers/preprocessing/string_lookup_test.py +./keras/src/utils/summary_utils_test.py +./keras/src/callbacks/swap_ema_weights_test.py 
+./keras/src/ops/symbolic_arguments_test.py +./keras/src/backend/common/symbolic_scope_test.py +./keras/src/callbacks/tensorboard_test.py +./keras/src/callbacks/terminate_on_nan_test.py +./keras/src/testing/test_utils_test.py +./keras/src/utils/text_dataset_utils_test.py +./keras/src/layers/preprocessing/text_vectorization_test.py +./integration_tests/tf_custom_fit_test.py +./keras/src/trainers/data_adapters/tf_dataset_adapter_test.py +./integration_tests/tf_distribute_training_test.py +./keras/src/export/tfsm_layer_test.py +./keras/src/backend/common/thread_safe_test.py +./keras/src/layers/rnn/time_distributed_test.py +./keras/src/utils/timeseries_dataset_utils_test.py +./integration_tests/torch_custom_fit_test.py +./keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py +./keras/src/utils/torch_utils_test.py +./integration_tests/torch_workflow_test.py +./keras/src/utils/tracking_test.py +./keras/src/trainers/trainer_test.py +./keras/src/tree/tree_test.py +./keras/src/layers/normalization/unit_normalization_test.py +./keras/src/layers/reshaping/up_sampling1d_test.py +./keras/src/layers/reshaping/up_sampling2d_test.py +./keras/src/layers/reshaping/up_sampling3d_test.py +./keras/src/models/variable_mapping_test.py +./keras/src/backend/common/variables_test.py +./keras/src/layers/core/wrapper_test.py +./keras/src/layers/reshaping/zero_padding1d_test.py +./keras/src/layers/reshaping/zero_padding2d_test.py +./keras/src/layers/reshaping/zero_padding3d_test.py From 4651454661de269dad8c15d22fea361c254e2a26 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 12:26:05 +0530 Subject: [PATCH 25/47] Updated Dockerfile --- Dockerfile | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7d0eeb2280f1..bdd81ab8f228 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ -FROM python:3.10-slim +FROM --platform=linux/amd64 python:3.10-slim -ENV 
KERAS_HOME=/github/workspace/.github/workflows/config/tensorflow \ - KERAS_BACKEND=tensorflow +ENV KERAS_HOME=/github/workspace/.github/workflows/config/jax \ + KERAS_BACKEND=jax RUN apt-get update && apt-get install -y --no-install-recommends \ git \ @@ -13,16 +13,16 @@ COPY . /github/workspace WORKDIR /github/workspace # Create and activate venv, install pip/setuptools/psutil, then run tests -RUN cd src/github/keras && \ - pip install -U pip setuptools && \ - pip install -U psutil && \ - pip install -r requirements-tensorflow-tpu.txt && \ +# RUN cd ./keras/src/github/keras && \ +RUN pip install --no-cache-dir -U pip setuptools && \ + pip install --no-cache-dir -U psutil && \ + pip install --no-cache-dir -r requirements-jax-tpu.txt && \ pip uninstall -y keras keras-nightly && \ - python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("TPU"))' && \ - python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("TPU")) > 0' && \ + python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \ + python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \ pytest keras --ignore keras/src/applications \ --ignore keras/src/layers/merging/merging_test.py \ --cov=keras \ --cov-config=pyproject.toml -CMD ["bash"] \ No newline at end of file +CMD ["/bin/bash"] From 40af241265439b873fa3482aaeb653678da533f8 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 12:53:02 +0530 Subject: [PATCH 26/47] Restored Dockerfile to previous changes --- Dockerfile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index bdd81ab8f228..324f3798742f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,12 +17,12 @@ WORKDIR /github/workspace RUN pip install --no-cache-dir -U pip setuptools && \ pip install --no-cache-dir -U psutil && \ pip install --no-cache-dir -r requirements-jax-tpu.txt && \ - pip uninstall -y keras keras-nightly && \ - 
python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \ - python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \ - pytest keras --ignore keras/src/applications \ - --ignore keras/src/layers/merging/merging_test.py \ - --cov=keras \ - --cov-config=pyproject.toml + pip uninstall -y keras keras-nightly + # python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \ + # python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \ + # pytest keras --ignore keras/src/applications \ + # --ignore keras/src/layers/merging/merging_test.py \ + # --cov=keras \ + # --cov-config=pyproject.toml CMD ["/bin/bash"] From 64420d5e0d4819e7d67eab899273d58fdb98ccd1 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 13:19:53 +0530 Subject: [PATCH 27/47] updated actions.yml file to install and configure docker engine on self hosted runner, build the image and check TPU support on jax backend --- .github/workflows/actions.yml | 82 ++++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 11 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index d0a6785af23b..49ad4a60f937 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -11,12 +11,12 @@ permissions: contents: read jobs: - tpu_build: + build-and-test-on-tpu: strategy: fail-fast: false matrix: python-version: ['3.10'] - backend: [tensorflow] + backend: [jax] name: Run TPU tests runs-on: # - linux-x86-ct5lp-112-4tpu @@ -25,23 +25,83 @@ jobs: # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc - container: - image: docker:latest env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} - KERAS_BACKEND: tensorflow + KERAS_BACKEND: jax + PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID + GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) 
+ IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image + TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name + TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone + steps: - - uses: actions/checkout@v4 - - name: Build and run Docker image for TPU tests + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Install Docker (if not present) run: | + # Check if docker is already installed + if ! command -v docker &> /dev/null + then + echo "Docker not found. Installing Docker..." + # Update apt package index + sudo apt-get update + # Install packages to allow apt to use a repository over HTTPS + sudo apt-get install -y ca-certificates curl gnupg + # Add Docker's official GPG key + sudo install -m 0755 -d /etc/apt/keyrings + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg + sudo chmod a+r /etc/apt/keyrings/docker.gpg + # Add the repository to Apt sources + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + # Install Docker Engine, containerd, and Docker Compose + sudo apt-get update + sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin + # Add the current user (runner user) to the docker group to run docker without sudo + sudo usermod -aG docker $USER + # You might need to log out and log back in for group changes to take effect, + # or restart the Docker daemon and the runner agent. + # For a CI environment, `newgrp docker` might work temporarily or a restart is implied. + sudo systemctl start docker + sudo systemctl enable docker + echo "Docker installed." + else + echo "Docker is already installed." 
+ fi + + - name: Set up Docker BuildX + uses: docker/setup-buildx-action@v3 + + + - name: Build Docker image for TPU tests + run: | + echo "Building Docker image using Dockerfile at .github/workflows/tpu/Dockerfile..." + # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet. docker build -f .github/workflows/tpu/Dockerfile -t keras-tpu-test . + echo "Docker image built successfully." + + - name: Run Docker container and execute tests on TPU + run: | + echo "Running Docker container with TPU access and executing tests..." + # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet. docker run --rm \ - -e PYTHON=${{ matrix.python-version }} \ - -e KERAS_HOME=.github/workflows/config/${{ matrix.backend }} \ - -e KERAS_BACKEND=tensorflow \ - keras-tpu-test + --privileged \ + --network host \ + -e PYTHON=${{ env.PYTHON }} \ + -e KERAS_HOME=${{ env.KERAS_HOME }} \ + -e KERAS_BACKEND=${{ env.KERAS_BACKEND }} \ + keras-tpu-test \ + /bin/bash -c "\ + echo 'Verifying JAX TPU backend inside container...' && \ + python3 -c 'import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")' \ + # Add your actual pytest command here. Ensure pytest is installed inside your Docker image. + " + echo "Docker container finished running tests." 
build: strategy: From d69277d1ef72eb5cccd883aa6702bdfed16a581c Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 13:26:32 +0530 Subject: [PATCH 28/47] updated actions.yml file to include container option --- .github/workflows/actions.yml | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 19c8bec983de..384acfa6e837 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -27,7 +27,23 @@ jobs: - linux-x86-ct6e-44-1tpu # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc - + + + container: + # Use an official Docker image that includes the Docker CLI. + # This allows you to run 'docker' commands from within this job's container. + # 'docker:latest' is a good choice. You could also specify a version like 'docker:24.0.5'. + image: docker:latest + # Mount the host's Docker socket into this container. + # This is CRUCIAL: It allows 'docker' commands executed *inside* this container + # to control the *host's* Docker daemon. + volumes: + - /var/run/docker.sock:/var/run/docker.sock + # Running this "controlling" container in privileged mode is often necessary + # when it needs to manage other containers and access host resources like TPUs + # through the host's Docker daemon. 
+ options: --privileged + env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} From 1c307fcfe1b117a1c8aebd1133e1d1e9f995e523 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 13:29:49 +0530 Subject: [PATCH 29/47] updated actions.yml file to include container option without volume binding --- .github/workflows/actions.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 384acfa6e837..c0f2057cc608 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -37,8 +37,8 @@ jobs: # Mount the host's Docker socket into this container. # This is CRUCIAL: It allows 'docker' commands executed *inside* this container # to control the *host's* Docker daemon. - volumes: - - /var/run/docker.sock:/var/run/docker.sock + # volumes: + # - /var/run/docker.sock:/var/run/docker.sock # Running this "controlling" container in privileged mode is often necessary # when it needs to manage other containers and access host resources like TPUs # through the host's Docker daemon. 
From 693886b09627bd0331778ff1f7b5988da0fe1319 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Mon, 28 Jul 2025 13:31:32 +0530 Subject: [PATCH 30/47] updated actions.yml file to change TPU --- .github/workflows/actions.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index c0f2057cc608..90694901fcea 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -22,9 +22,9 @@ jobs: backend: [jax] name: Run TPU tests runs-on: - # - linux-x86-ct5lp-112-4tpu + - linux-x86-ct5lp-112-4tpu # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n - - linux-x86-ct6e-44-1tpu + # - linux-x86-ct6e-44-1tpu # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc From e74b851f9422bc685f819f24c7fc18ebe6d388e5 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Tue, 29 Jul 2025 10:04:16 +0530 Subject: [PATCH 31/47] Updated container path in build-and-test-on-tpu job --- .github/workflows/actions.yml | 173 ++++++++++++++++------------------ 1 file changed, 79 insertions(+), 94 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 90694901fcea..f44d949c4375 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -12,6 +12,18 @@ on: permissions: contents: read + id-token: write + +env: + PYTHON: ${{ matrix.python-version }} + KERAS_HOME: .github/workflows/config/${{ matrix.backend }} + KERAS_BACKEND: jax + PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID + GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) + IMAGE_REPO: keras-docker-images + IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image + TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name + TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone jobs: build-and-test-on-tpu: @@ -22,105 +34,78 @@ jobs: backend: [jax] name: Run TPU tests runs-on: - 
- linux-x86-ct5lp-112-4tpu + # - keras-jax-tpu-runner + # - linux-x86-ct5lp-112-4tpu # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n - # - linux-x86-ct6e-44-1tpu + - linux-x86-ct6e-44-1tpu # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc + container: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + + # container: + # image: docker:latest # Provides the Docker CLI within the job container + # volumes: + # - /var/run/docker.sock:/var/run/docker.sock # Mounts host's Docker socket for control + # options: --privileged + + # steps: + # - name: Checkout Repository + # uses: actions/checkout@v4 + + # - name: Set up Docker BuildX + # uses: docker/setup-buildx-action@v3 + + # - name: Authenticate to Google Cloud (Workload Identity Federation) + # id: 'auth' + # uses: 'google-github-actions/auth@v2' + # with: + # # Replace with your Workload Identity Federation provider details. + # # This service account needs 'Artifact Registry Writer' role. + # workload_identity_provider: 'projects/YOUR_PROJECT_NUMBER/locations/global/workloadIdentityPools/YOUR_POOL_ID/providers/YOUR_PROVIDER_ID' + # service_account: 'your-github-actions-sa@${{ env.PROJECT_ID }}.iam.gserviceaccount.com' + + # - name: Configure Docker to use Google Artifact Registry + # run: gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev + + # - name: Build Docker Image + # run: | + # IMAGE_TAG="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:${{ github.sha }}" + # echo "Building Docker image: $IMAGE_TAG" + # docker build \ + # --platform=linux/amd64 \ + # -f .github/workflows/tpu/Dockerfile \ + # -t "$IMAGE_TAG" \ + # -t "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" \ + # . 
+ # echo "Built Docker image: $IMAGE_TAG" + # echo "LOCAL_TEST_IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV # Store for immediate use in run step + + # - name: Push Docker Image to Artifact Registry + # run: | + # echo "Pushing Docker image to Artifact Registry: ${{ env.LOCAL_TEST_IMAGE_TAG }}" + # docker push "${{ env.LOCAL_TEST_IMAGE_TAG }}" + # docker push "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" + # echo "Pushed Docker image." + + # - name: Run Docker container and execute tests on TPU + # run: | + # echo "Running Docker container with TPU access and executing tests..." + # docker run --rm \ + # --privileged \ + # --network host \ + # -e PYTHON=3.10 \ # Use a specific version or derive from matrix + # -e KERAS_HOME=.github/workflows/config/jax \ + # -e KERAS_BACKEND=jax \ + # ${{ env.LOCAL_TEST_IMAGE_TAG }} \ + # /bin/bash -c ' \ + # echo "Verifying JAX TPU backend inside container..." && \ + # python3 -c "import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")" \ + # # Add your actual pytest command here. Ensure pytest is installed inside your Docker image. + # # && pytest keras --ignore keras/src/applications --ignore keras/src/layers/merging/merging_test.py --cov=keras --cov-config=pyproject.toml + # ' + # echo "Docker container finished running tests." - container: - # Use an official Docker image that includes the Docker CLI. - # This allows you to run 'docker' commands from within this job's container. - # 'docker:latest' is a good choice. You could also specify a version like 'docker:24.0.5'. - image: docker:latest - # Mount the host's Docker socket into this container. - # This is CRUCIAL: It allows 'docker' commands executed *inside* this container - # to control the *host's* Docker daemon. 
- # volumes: - # - /var/run/docker.sock:/var/run/docker.sock - # Running this "controlling" container in privileged mode is often necessary - # when it needs to manage other containers and access host resources like TPUs - # through the host's Docker daemon. - options: --privileged - - env: - PYTHON: ${{ matrix.python-version }} - KERAS_HOME: .github/workflows/config/${{ matrix.backend }} - KERAS_BACKEND: jax - PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID - GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) - IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image - TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name - TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone - - steps: - - - name: Checkout Repository - uses: actions/checkout@v4 - - - name: Install Docker (if not present) - run: | - # Check if docker is already installed - if ! command -v docker &> /dev/null - then - echo "Docker not found. Installing Docker..." - # Update apt package index - sudo apt-get update - # Install packages to allow apt to use a repository over HTTPS - sudo apt-get install -y ca-certificates curl gnupg - # Add Docker's official GPG key - sudo install -m 0755 -d /etc/apt/keyrings - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg - sudo chmod a+r /etc/apt/keyrings/docker.gpg - # Add the repository to Apt sources - echo \ - "deb [arch=\"$(dpkg --print-architecture)\" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - \"$(. 
/etc/os-release && echo \"$VERSION_CODENAME\")\" stable" | \ - sudo tee /etc/apt/sources.list.d/docker.list > /dev/null - # Install Docker Engine, containerd, and Docker Compose - sudo apt-get update - sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin - # Add the current user (runner user) to the docker group to run docker without sudo - sudo usermod -aG docker $USER - # You might need to log out and log back in for group changes to take effect, - # or restart the Docker daemon and the runner agent. - # For a CI environment, `newgrp docker` might work temporarily or a restart is implied. - sudo systemctl start docker - sudo systemctl enable docker - echo "Docker installed." - else - echo "Docker is already installed." - fi - - - name: Set up Docker BuildX - uses: docker/setup-buildx-action@v3 - - - - name: Build Docker image for TPU tests - run: | - echo "Building Docker image using Dockerfile at .github/workflows/tpu/Dockerfile..." - # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet. - docker build -f .github/workflows/tpu/Dockerfile -t keras-tpu-test . - echo "Docker image built successfully." - - - name: Run Docker container and execute tests on TPU - run: | - echo "Running Docker container with TPU access and executing tests..." - # Use 'sudo docker' if the 'docker' group membership hasn't fully applied yet. - docker run --rm \ - --privileged \ - --network host \ - -e PYTHON=${{ env.PYTHON }} \ - -e KERAS_HOME=${{ env.KERAS_HOME }} \ - -e KERAS_BACKEND=${{ env.KERAS_BACKEND }} \ - keras-tpu-test \ - /bin/bash -c "\ - echo 'Verifying JAX TPU backend inside container...' && \ - python3 -c 'import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")' \ - # Add your actual pytest command here. 
Ensure pytest is installed inside your Docker image. - " - echo "Docker container finished running tests." build: strategy: From d31b3c4c2c917e895c3b2bc0263222a498a1dec0 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Tue, 29 Jul 2025 12:16:03 +0530 Subject: [PATCH 32/47] seperated TPU workflow from actions.yml --- .github/workflows/actions.yml | 182 ++++++++++++++++---------------- .github/workflows/tpu-tests.yml | 72 +++++++++++++ 2 files changed, 163 insertions(+), 91 deletions(-) create mode 100644 .github/workflows/tpu-tests.yml diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index f44d949c4375..44d5b69dd2a3 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -12,99 +12,99 @@ on: permissions: contents: read - id-token: write - -env: - PYTHON: ${{ matrix.python-version }} - KERAS_HOME: .github/workflows/config/${{ matrix.backend }} - KERAS_BACKEND: jax - PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID - GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) - IMAGE_REPO: keras-docker-images - IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image - TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name - TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone + # id-token: write + +# env: +# PYTHON: ${{ matrix.python-version }} +# KERAS_HOME: .github/workflows/config/${{ matrix.backend }} +# KERAS_BACKEND: jax +# PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID +# GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) +# IMAGE_REPO: keras-docker-images +# IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image +# TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name +# TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone jobs: - build-and-test-on-tpu: - strategy: - fail-fast: false - matrix: - python-version: ['3.10'] - backend: [jax] - 
name: Run TPU tests - runs-on: - # - keras-jax-tpu-runner - # - linux-x86-ct5lp-112-4tpu - # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n - - linux-x86-ct6e-44-1tpu - # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 - # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc - - container: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest - - # container: - # image: docker:latest # Provides the Docker CLI within the job container - # volumes: - # - /var/run/docker.sock:/var/run/docker.sock # Mounts host's Docker socket for control - # options: --privileged - - # steps: - # - name: Checkout Repository - # uses: actions/checkout@v4 - - # - name: Set up Docker BuildX - # uses: docker/setup-buildx-action@v3 - - # - name: Authenticate to Google Cloud (Workload Identity Federation) - # id: 'auth' - # uses: 'google-github-actions/auth@v2' - # with: - # # Replace with your Workload Identity Federation provider details. - # # This service account needs 'Artifact Registry Writer' role. - # workload_identity_provider: 'projects/YOUR_PROJECT_NUMBER/locations/global/workloadIdentityPools/YOUR_POOL_ID/providers/YOUR_PROVIDER_ID' - # service_account: 'your-github-actions-sa@${{ env.PROJECT_ID }}.iam.gserviceaccount.com' - - # - name: Configure Docker to use Google Artifact Registry - # run: gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev - - # - name: Build Docker Image - # run: | - # IMAGE_TAG="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:${{ github.sha }}" - # echo "Building Docker image: $IMAGE_TAG" - # docker build \ - # --platform=linux/amd64 \ - # -f .github/workflows/tpu/Dockerfile \ - # -t "$IMAGE_TAG" \ - # -t "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" \ - # . 
- # echo "Built Docker image: $IMAGE_TAG" - # echo "LOCAL_TEST_IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV # Store for immediate use in run step - - # - name: Push Docker Image to Artifact Registry - # run: | - # echo "Pushing Docker image to Artifact Registry: ${{ env.LOCAL_TEST_IMAGE_TAG }}" - # docker push "${{ env.LOCAL_TEST_IMAGE_TAG }}" - # docker push "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" - # echo "Pushed Docker image." - - # - name: Run Docker container and execute tests on TPU - # run: | - # echo "Running Docker container with TPU access and executing tests..." - # docker run --rm \ - # --privileged \ - # --network host \ - # -e PYTHON=3.10 \ # Use a specific version or derive from matrix - # -e KERAS_HOME=.github/workflows/config/jax \ - # -e KERAS_BACKEND=jax \ - # ${{ env.LOCAL_TEST_IMAGE_TAG }} \ - # /bin/bash -c ' \ - # echo "Verifying JAX TPU backend inside container..." && \ - # python3 -c "import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")" \ - # # Add your actual pytest command here. Ensure pytest is installed inside your Docker image. - # # && pytest keras --ignore keras/src/applications --ignore keras/src/layers/merging/merging_test.py --cov=keras --cov-config=pyproject.toml - # ' - # echo "Docker container finished running tests." 
+ # build-and-test-on-tpu: + # strategy: + # fail-fast: false + # matrix: + # python-version: ['3.10'] + # backend: [jax] + # name: Run TPU tests + # runs-on: + # # - keras-jax-tpu-runner + # # - linux-x86-ct5lp-112-4tpu + # # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n + # - linux-x86-ct6e-44-1tpu + # # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4 + # # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc + + # container: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + + # # container: + # # image: docker:latest # Provides the Docker CLI within the job container + # # volumes: + # # - /var/run/docker.sock:/var/run/docker.sock # Mounts host's Docker socket for control + # # options: --privileged + + # steps: + # - name: Checkout Repository + # uses: actions/checkout@v4 + + # - name: Set up Docker BuildX + # uses: docker/setup-buildx-action@v3 + + # - name: Authenticate to Google Cloud (Workload Identity Federation) + # id: 'auth' + # uses: 'google-github-actions/auth@v2' + # with: + # # Replace with your Workload Identity Federation provider details. + # # This service account needs 'Artifact Registry Writer' role. 
+ # workload_identity_provider: 'projects/YOUR_PROJECT_NUMBER/locations/global/workloadIdentityPools/YOUR_POOL_ID/providers/YOUR_PROVIDER_ID' + # service_account: 'your-github-actions-sa@${{ env.PROJECT_ID }}.iam.gserviceaccount.com' + + # - name: Configure Docker to use Google Artifact Registry + # run: gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev + + # - name: Build Docker Image + # run: | + # IMAGE_TAG="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:${{ github.sha }}" + # echo "Building Docker image: $IMAGE_TAG" + # docker build \ + # --platform=linux/amd64 \ + # -f .github/workflows/tpu/Dockerfile \ + # -t "$IMAGE_TAG" \ + # -t "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" \ + # . + # echo "Built Docker image: $IMAGE_TAG" + # echo "LOCAL_TEST_IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV # Store for immediate use in run step + + # - name: Push Docker Image to Artifact Registry + # run: | + # echo "Pushing Docker image to Artifact Registry: ${{ env.LOCAL_TEST_IMAGE_TAG }}" + # docker push "${{ env.LOCAL_TEST_IMAGE_TAG }}" + # docker push "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" + # echo "Pushed Docker image." + + # - name: Run Docker container and execute tests on TPU + # run: | + # echo "Running Docker container with TPU access and executing tests..." + # docker run --rm \ + # --privileged \ + # --network host \ + # -e PYTHON=3.10 \ # Use a specific version or derive from matrix + # -e KERAS_HOME=.github/workflows/config/jax \ + # -e KERAS_BACKEND=jax \ + # ${{ env.LOCAL_TEST_IMAGE_TAG }} \ + # /bin/bash -c ' \ + # echo "Verifying JAX TPU backend inside container..." 
&& \ + # python3 -c "import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")" \ + # # Add your actual pytest command here. Ensure pytest is installed inside your Docker image. + # # && pytest keras --ignore keras/src/applications --ignore keras/src/layers/merging/merging_test.py --cov=keras --cov-config=pyproject.toml + # ' + # echo "Docker container finished running tests." build: diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml new file mode 100644 index 000000000000..3d6e50a08c04 --- /dev/null +++ b/.github/workflows/tpu-tests.yml @@ -0,0 +1,72 @@ +name: TPU Tests + +on: + workflow_dispatch: # Allows you to manually trigger this workflow + +permissions: + contents: read # Only read permission is needed for checkout + +env: + PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID + GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) + GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name + IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image + IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) + +jobs: + pull-and-use-image: + name: Pull & Use Image from GAR + runs-on: + # - linux-x86-ct5lp-112-4tpu + - linux-x86-ct6e-44-1tpu + # - keras-jax-tpu-runner + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + # - name: Authenticate to Google Cloud with Service Account Key + # id: 'auth' + # uses: 'google-github-actions/auth@v2' + # with: + # # Pass the content of your GitHub Secret directly here. + # credentials_json: '${{ secrets.GCP_SA_KEY }}' + + - name: Configure Docker to use Google Artifact Registry + run: | + echo "Configuring Docker to authenticate with Google Artifact Registry..." 
+ # This command uses the credentials set by the 'auth' step to configure Docker. + gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev + echo "Docker configured." + + - name: Pull Docker Image from Artifact Registry + run: | + FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" + echo "Attempting to pull image: $FULL_IMAGE_PATH" + docker pull "$FULL_IMAGE_PATH" + echo "Successfully pulled image: $FULL_IMAGE_PATH" + + - name: Verify Pulled Image (Optional) + run: | + echo "Listing local Docker images..." + docker images | grep "${{ env.IMAGE_NAME }}" + echo "Image verification complete." + + - name: Run Docker Container (with TPU options if on TPU VM) + run: | + FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" + echo "Running Docker container: $FULL_IMAGE_PATH" + # IMPORTANT: Add --privileged and --network host ONLY IF this job is running + # on your TPU VM self-hosted runner AND it has the necessary permissions. + # If running on 'ubuntu-latest', these flags are not meaningful for TPU access. 
+ docker run --rm \ + --privileged \ + --network host \ + "$FULL_IMAGE_PATH" \ + /bin/bash -c " \ + echo 'Container is running...'; \ + # Add your test or verification commands here, e.g.: + # python3 -c 'import jax; print(jax.default_backend())'; \ + # pytest your_tests.py; \ + echo 'Container execution finished.'; \ + " \ No newline at end of file From a70d19e6494117e550f7a6dd6f32c737a9968aa6 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Tue, 29 Jul 2025 12:22:46 +0530 Subject: [PATCH 33/47] updated trigger condition for TPU tests workflow --- .github/workflows/tpu-tests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 3d6e50a08c04..a3111f0ff92c 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -1,7 +1,11 @@ name: TPU Tests on: - workflow_dispatch: # Allows you to manually trigger this workflow + push: + branches: [ master ] + pull_request: + release: + types: [created] permissions: contents: read # Only read permission is needed for checkout From 5f5b609cf9bc008f1f0dcff5601353bd03dd5e53 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Tue, 29 Jul 2025 12:31:35 +0530 Subject: [PATCH 34/47] updated container usage configuration for TPU tests workflow --- .github/workflows/tpu-tests.yml | 83 +++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 35 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index a3111f0ff92c..2778e2e5a463 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -25,10 +25,23 @@ jobs: - linux-x86-ct6e-44-1tpu # - keras-jax-tpu-runner + container: + image: ${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} + options: --privileged --network host # Add any necessary options for the job container + steps: - name: Checkout Repository uses: actions/checkout@v4 
+ - name: Verify Environment inside Container + run: | + echo "Current working directory: $(pwd)" + echo "Contents of current directory:" + ls -la + # Verify Python, JAX, etc., inside the container + python3 -c "import jax; print(jax.default_backend()); print(jax.devices())" + # Add any other verification steps here + # - name: Authenticate to Google Cloud with Service Account Key # id: 'auth' # uses: 'google-github-actions/auth@v2' @@ -36,41 +49,41 @@ jobs: # # Pass the content of your GitHub Secret directly here. # credentials_json: '${{ secrets.GCP_SA_KEY }}' - - name: Configure Docker to use Google Artifact Registry - run: | - echo "Configuring Docker to authenticate with Google Artifact Registry..." - # This command uses the credentials set by the 'auth' step to configure Docker. - gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev - echo "Docker configured." + # - name: Configure Docker to use Google Artifact Registry + # run: | + # echo "Configuring Docker to authenticate with Google Artifact Registry..." + # # This command uses the credentials set by the 'auth' step to configure Docker. + # gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev + # echo "Docker configured." 
- - name: Pull Docker Image from Artifact Registry - run: | - FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" - echo "Attempting to pull image: $FULL_IMAGE_PATH" - docker pull "$FULL_IMAGE_PATH" - echo "Successfully pulled image: $FULL_IMAGE_PATH" + # - name: Pull Docker Image from Artifact Registry + # run: | + # FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" + # echo "Attempting to pull image: $FULL_IMAGE_PATH" + # docker pull "$FULL_IMAGE_PATH" + # echo "Successfully pulled image: $FULL_IMAGE_PATH" - - name: Verify Pulled Image (Optional) - run: | - echo "Listing local Docker images..." - docker images | grep "${{ env.IMAGE_NAME }}" - echo "Image verification complete." + # - name: Verify Pulled Image (Optional) + # run: | + # echo "Listing local Docker images..." + # docker images | grep "${{ env.IMAGE_NAME }}" + # echo "Image verification complete." - - name: Run Docker Container (with TPU options if on TPU VM) - run: | - FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" - echo "Running Docker container: $FULL_IMAGE_PATH" - # IMPORTANT: Add --privileged and --network host ONLY IF this job is running - # on your TPU VM self-hosted runner AND it has the necessary permissions. - # If running on 'ubuntu-latest', these flags are not meaningful for TPU access. 
- docker run --rm \ - --privileged \ - --network host \ - "$FULL_IMAGE_PATH" \ - /bin/bash -c " \ - echo 'Container is running...'; \ - # Add your test or verification commands here, e.g.: - # python3 -c 'import jax; print(jax.default_backend())'; \ - # pytest your_tests.py; \ - echo 'Container execution finished.'; \ - " \ No newline at end of file + # - name: Run Docker Container (with TPU options if on TPU VM) + # run: | + # FULL_IMAGE_PATH="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }}" + # echo "Running Docker container: $FULL_IMAGE_PATH" + # # IMPORTANT: Add --privileged and --network host ONLY IF this job is running + # # on your TPU VM self-hosted runner AND it has the necessary permissions. + # # If running on 'ubuntu-latest', these flags are not meaningful for TPU access. + # docker run --rm \ + # --privileged \ + # --network host \ + # "$FULL_IMAGE_PATH" \ + # /bin/bash -c " \ + # echo 'Container is running...'; \ + # # Add your test or verification commands here, e.g.: + # # python3 -c 'import jax; print(jax.default_backend())'; \ + # # pytest your_tests.py; \ + # echo 'Container execution finished.'; \ + # " \ No newline at end of file From 72e729f455d1e5a06e4aa651f940877963c3d1a4 Mon Sep 17 00:00:00 2001 From: Harshith K Date: Tue, 29 Jul 2025 12:33:31 +0530 Subject: [PATCH 35/47] updated env vars for TPU tests workflow --- .github/workflows/tpu-tests.yml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 2778e2e5a463..dd1f99132a39 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -10,16 +10,17 @@ on: permissions: contents: read # Only read permission is needed for checkout -env: - PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID - GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) - 
GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name - IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image - IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) - jobs: pull-and-use-image: name: Pull & Use Image from GAR + + env: + PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID + GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) + GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name + IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image + IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) + runs-on: # - linux-x86-ct5lp-112-4tpu - linux-x86-ct6e-44-1tpu From e12929960cc5bc0893067b0ab6cbf90d2181ba42 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 12:38:43 +0530 Subject: [PATCH 36/47] updated env vars parsing syntax in TPU tests workflow --- .github/workflows/tpu-tests.yml | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index dd1f99132a39..f78dfb1f6f99 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -10,23 +10,26 @@ on: permissions: contents: read # Only read permission is needed for checkout +# --- Move env variables to the workflow level --- +env: + PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID + GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) + GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name + IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image + IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) +# ------------------------------------------------ + jobs: pull-and-use-image: name: 
Pull & Use Image from GAR - env: - PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID - GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) - GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name - IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image - IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) - - runs-on: + runs-on: # - linux-x86-ct5lp-112-4tpu - linux-x86-ct6e-44-1tpu # - keras-jax-tpu-runner container: + # Now env variables are recognized here image: ${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} options: --privileged --network host # Add any necessary options for the job container From 3fe5b57d14f82a58afe3a2e3fec8b28baaa806c6 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 12:41:51 +0530 Subject: [PATCH 37/47] updated env vars syntax in TPU tests workflow --- .github/workflows/tpu-tests.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index f78dfb1f6f99..13429e5f6231 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -10,14 +10,13 @@ on: permissions: contents: read # Only read permission is needed for checkout -# --- Move env variables to the workflow level --- env: PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) -# ------------------------------------------------ + FULL_TPU_IMAGE_PATH: 
us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest jobs: pull-and-use-image: @@ -30,7 +29,7 @@ jobs: container: # Now env variables are recognized here - image: ${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} + image: ${{ env.FULL_TPU_IMAGE_PATH }} options: --privileged --network host # Add any necessary options for the job container steps: From 10df307bd6e074fce4849f2d60c26f9a42213d19 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 12:44:37 +0530 Subject: [PATCH 38/47] updated env vars syntax in TPU tests workflow --- .github/workflows/tpu-tests.yml | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 13429e5f6231..c8bbde5067d0 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -10,26 +10,43 @@ on: permissions: contents: read # Only read permission is needed for checkout -env: - PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID - GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1) - GAR_REPO: keras-docker-images # Replace with your Artifact Registry repository name - IMAGE_NAME: keras-jax-tpu-amd64 # Replace with the name of your Docker image - IMAGE_TAG: latest # Replace with the specific tag you want to pull (e.g., latest or a specific SHA) +# Remove these from 'env' +# env: +# PROJECT_ID: gtech-rmi-dev +# GAR_LOCATION: us-central1 +# GAR_REPO: keras-docker-images +# IMAGE_NAME: keras-jax-tpu-amd64 +# IMAGE_TAG: latest +# FULL_TPU_IMAGE_PATH: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + +# --- Define the image path using 'vars' --- +vars: + PROJECT_ID: gtech-rmi-dev # Define these if you want to use them for other purposes + GAR_LOCATION: us-central1 + GAR_REPO: keras-docker-images + 
IMAGE_NAME: keras-jax-tpu-amd64 + IMAGE_TAG: latest FULL_TPU_IMAGE_PATH: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest +# ------------------------------------------ jobs: pull-and-use-image: name: Pull & Use Image from GAR + # You can still have an 'env' block here or at the workflow level for *other* variables + # that steps might need. Just not for the 'container.image' definition. + env: + # Example: an env variable that steps within the container might use + TF_CPP_MIN_LOG_LEVEL: 2 + runs-on: # - linux-x86-ct5lp-112-4tpu - linux-x86-ct6e-44-1tpu # - keras-jax-tpu-runner container: - # Now env variables are recognized here - image: ${{ env.FULL_TPU_IMAGE_PATH }} + # *** THIS IS THE CRITICAL CHANGE: Use 'vars' instead of 'env' *** + image: ${{ vars.FULL_TPU_IMAGE_PATH }} options: --privileged --network host # Add any necessary options for the job container steps: From dd21e09589d014b199b4edb04d6229b30b8be8e8 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 12:46:37 +0530 Subject: [PATCH 39/47] updated env vars syntax in TPU tests workflow --- .github/workflows/tpu-tests.yml | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index c8bbde5067d0..3a6505fdc096 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -10,44 +10,28 @@ on: permissions: contents: read # Only read permission is needed for checkout -# Remove these from 'env' -# env: -# PROJECT_ID: gtech-rmi-dev -# GAR_LOCATION: us-central1 -# GAR_REPO: keras-docker-images -# IMAGE_NAME: keras-jax-tpu-amd64 -# IMAGE_TAG: latest -# FULL_TPU_IMAGE_PATH: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest - -# --- Define the image path using 'vars' --- -vars: - PROJECT_ID: gtech-rmi-dev # Define these if you want to use them for other purposes +# Define base environment 
variables at the workflow level +env: + PROJECT_ID: gtech-rmi-dev GAR_LOCATION: us-central1 GAR_REPO: keras-docker-images IMAGE_NAME: keras-jax-tpu-amd64 IMAGE_TAG: latest - FULL_TPU_IMAGE_PATH: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest -# ------------------------------------------ jobs: pull-and-use-image: name: Pull & Use Image from GAR - # You can still have an 'env' block here or at the workflow level for *other* variables - # that steps might need. Just not for the 'container.image' definition. - env: - # Example: an env variable that steps within the container might use - TF_CPP_MIN_LOG_LEVEL: 2 - runs-on: # - linux-x86-ct5lp-112-4tpu - linux-x86-ct6e-44-1tpu # - keras-jax-tpu-runner + # Construct the image string directly here using the env variables + # This is the most reliable way to get env variables into the container image definition. container: - # *** THIS IS THE CRITICAL CHANGE: Use 'vars' instead of 'env' *** - image: ${{ vars.FULL_TPU_IMAGE_PATH }} - options: --privileged --network host # Add any necessary options for the job container + image: ${{ env.GAR_LOCATION }}/docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} + options: --privileged --network host steps: - name: Checkout Repository @@ -62,6 +46,9 @@ jobs: python3 -c "import jax; print(jax.default_backend()); print(jax.devices())" # Add any other verification steps here + + + # - name: Authenticate to Google Cloud with Service Account Key # id: 'auth' # uses: 'google-github-actions/auth@v2' From 328628fc886a8e1d6d628b43b064771055df312f Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 12:51:32 +0530 Subject: [PATCH 40/47] updated env vars syntax in TPU tests workflow --- .github/workflows/tpu-tests.yml | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 
3a6505fdc096..12538ef3f91c 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -11,6 +11,7 @@ permissions: contents: read # Only read permission is needed for checkout # Define base environment variables at the workflow level +# These can still be used inside steps, just not for the container image definition env: PROJECT_ID: gtech-rmi-dev GAR_LOCATION: us-central1 @@ -22,15 +23,13 @@ jobs: pull-and-use-image: name: Pull & Use Image from GAR - runs-on: - # - linux-x86-ct5lp-112-4tpu - - linux-x86-ct6e-44-1tpu - # - keras-jax-tpu-runner + runs-on: linux-x86-ct6e-44-1tpu - # Construct the image string directly here using the env variables - # This is the most reliable way to get env variables into the container image definition. + # The container image path must be a complete string. + # The 'env' context is not available here. + # Note the corrected GAR path format: 'us-central1-docker.pkg.dev' container: - image: ${{ env.GAR_LOCATION }}/docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.GAR_REPO }}/${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} + image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest options: --privileged --network host steps: @@ -43,8 +42,8 @@ jobs: echo "Contents of current directory:" ls -la # Verify Python, JAX, etc., inside the container - python3 -c "import jax; print(jax.default_backend()); print(jax.devices())" - # Add any other verification steps here + echo "Verifying JAX installation..." 
+ python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" From 01f0c17a411d8b4334dc7c3e4a11b62cd9591ada Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 13:05:02 +0530 Subject: [PATCH 41/47] updated image name in TPU tests workflow --- .github/workflows/tpu-tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 12538ef3f91c..5ee3c46d5ded 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -29,7 +29,8 @@ jobs: # The 'env' context is not available here. # Note the corrected GAR path format: 'us-central1-docker.pkg.dev' container: - image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + # image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + image: docker:latest options: --privileged --network host steps: From 3e41c37899d7c7e8f42c4309e2f2357796481d5c Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 13:09:04 +0530 Subject: [PATCH 42/47] updated image name with generic ubuntu image --- .github/workflows/tpu-tests.yml | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 5ee3c46d5ded..fb96679a3474 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -30,21 +30,28 @@ jobs: # Note the corrected GAR path format: 'us-central1-docker.pkg.dev' container: # image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest - image: docker:latest - options: --privileged --network host + image: ubuntu:22.04 + # options: --privileged --network host steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - name: Verify Environment inside Container + - name: Verify Container Environment run: | - echo 
"Current working directory: $(pwd)" - echo "Contents of current directory:" - ls -la - # Verify Python, JAX, etc., inside the container - echo "Verifying JAX installation..." - python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" + echo "Successfully started the public container! ✅" + echo "OS Details:" + cat /etc/os-release + + # - name: Checkout Repository + # uses: actions/checkout@v4 + + # - name: Verify Environment inside Container + # run: | + # echo "Current working directory: $(pwd)" + # echo "Contents of current directory:" + # ls -la + # # Verify Python, JAX, etc., inside the container + # echo "Verifying JAX installation..." + # python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" From 5e55c2c04fc0002b8f489c21273795238f14c1bf Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 13:19:07 +0530 Subject: [PATCH 43/47] updated tpu-tests to use ghcr --- .github/workflows/tpu-tests.yml | 88 ++++++++++++++++++++++++++++----- 1 file changed, 76 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index fb96679a3474..a7988c042fee 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -9,6 +9,7 @@ on: permissions: contents: read # Only read permission is needed for checkout + packages: write # Define base environment variables at the workflow level # These can still be used inside steps, just not for the container image definition @@ -20,26 +21,89 @@ env: IMAGE_TAG: latest jobs: - pull-and-use-image: - name: Pull & Use Image from GAR + build-and-push: + name: Build and Push to GHCR + runs-on: ubuntu-latest # This job doesn't need the special TPU runner + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + # GITHUB_TOKEN is 
automatically created by Actions and has permissions to push to your repo's package registry. + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and Push Docker Image + uses: docker/build-push-action@v6 + with: + context: . + # Push the image to ghcr.io + push: true + # Create a unique tag using the commit SHA for this specific build + tags: ghcr.io/${{ github.repository }}:${{ github.sha }} + + + test-in-container: + name: Test in Custom Container + # This job must run after the build-and-push job is complete + needs: build-and-push + + # Use the required TPU runner runs-on: linux-x86-ct6e-44-1tpu - # The container image path must be a complete string. - # The 'env' context is not available here. - # Note the corrected GAR path format: 'us-central1-docker.pkg.dev' + # CRITICAL: Use the container image we just pushed in the previous job. + # This satisfies the runner's requirement for a container to be specified. container: - # image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest - image: ubuntu:22.04 - # options: --privileged --network host + image: ghcr.io/${{ github.repository }}:${{ github.sha }} + options: --privileged --network host steps: + - name: Checkout Repository + uses: actions/checkout@v4 + # We need the code available inside the container's workspace to run tests. - - name: Verify Container Environment + - name: Run Verification and Tests run: | - echo "Successfully started the public container! ✅" - echo "OS Details:" - cat /etc/os-release + echo "Successfully running inside the custom container from GHCR!" + echo "Current working directory:" + pwd + echo "Contents of current directory:" + ls -la + echo "Verifying JAX installation..." 
+ python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" + + + + + # pull-and-use-image: + # name: Pull & Use Image from GAR + + # runs-on: linux-x86-ct6e-44-1tpu + + # # The container image path must be a complete string. + # # The 'env' context is not available here. + # # Note the corrected GAR path format: 'us-central1-docker.pkg.dev' + # container: + # # image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest + # image: ubuntu:22.04 + # # options: --privileged --network host + + # steps: + + # - name: Verify Container Environment + # run: | + # echo "Successfully started the public container! ✅" + # echo "OS Details:" + # cat /etc/os-release + + + + # - name: Checkout Repository # uses: actions/checkout@v4 From ea9ff88b8212a162b6675f7718ea78ee652130b0 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 14:11:31 +0530 Subject: [PATCH 44/47] updated tpu-tests to store built image as local tar --- .github/workflows/tpu-tests.yml | 109 ++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index a7988c042fee..d12fcaa31c85 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -11,71 +11,88 @@ permissions: contents: read # Only read permission is needed for checkout packages: write -# Define base environment variables at the workflow level -# These can still be used inside steps, just not for the container image definition -env: - PROJECT_ID: gtech-rmi-dev - GAR_LOCATION: us-central1 - GAR_REPO: keras-docker-images - IMAGE_NAME: keras-jax-tpu-amd64 - IMAGE_TAG: latest jobs: - - build-and-push: - name: Build and Push to GHCR - runs-on: ubuntu-latest # This job doesn't need the special TPU runner + # JOB 1: Build the image and save it as an artifact + build-and-save-image: + name: Build Image and Save as Artifact + runs-on: 
ubuntu-latest # A standard runner is fine for building steps: - name: Checkout Repository uses: actions/checkout@v4 - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - # GITHUB_TOKEN is automatically created by Actions and has permissions to push to your repo's package registry. - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - - name: Build and Push Docker Image + - name: Build Docker image + id: docker_build uses: docker/build-push-action@v6 with: context: . - # Push the image to ghcr.io - push: true - # Create a unique tag using the commit SHA for this specific build - tags: ghcr.io/${{ github.repository }}:${{ github.sha }} - - - test-in-container: - name: Test in Custom Container - # This job must run after the build-and-push job is complete - needs: build-and-push - - # Use the required TPU runner + # Do not push to a registry + push: false + # Load the image into the local Docker daemon so we can save it + load: true + # Tag the image for local use + tags: local-tpu-image:latest + + - name: Save Docker image as a tar file + run: docker save --output tpu-image.tar local-tpu-image:latest + + - name: Upload Docker image artifact + uses: actions/upload-artifact@v4 + with: + name: tpu-image-artifact + path: tpu-image.tar + retention-days: 1 # Keep the artifact for only one day + + # JOB 2: Load the image from the artifact and run tests + load-and-test: + name: Load and Test Image + needs: build-and-save-image # Must wait for Job 1 to finish runs-on: linux-x86-ct6e-44-1tpu - # CRITICAL: Use the container image we just pushed in the previous job. - # This satisfies the runner's requirement for a container to be specified. + # WORKAROUND: Use a minimal, public image simply to satisfy the runner's + # requirement that a container MUST be specified. The actual tests will + # not run in this container. 
container: - image: ghcr.io/${{ github.repository }}:${{ github.sha }} - options: --privileged --network host + image: ubuntu:22.04 steps: - - name: Checkout Repository + - name: Download Docker image artifact + uses: actions/download-artifact@v4 + with: + name: tpu-image-artifact + + - name: Load Docker image from tar file + run: docker load --input tpu-image.tar + + - name: Checkout repository code uses: actions/checkout@v4 - # We need the code available inside the container's workspace to run tests. + # We need the source code available to mount into the container for testing. - - name: Run Verification and Tests + - name: Run Tests Inside the Loaded Container run: | - echo "Successfully running inside the custom container from GHCR!" - echo "Current working directory:" - pwd - echo "Contents of current directory:" - ls -la - echo "Verifying JAX installation..." - python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" + echo "Running tests inside the custom-loaded container..." + + # Manually run the container we loaded from the artifact. + docker run --rm \ + --workdir /app \ + -v "$(pwd)":/app \ + --privileged --network host \ + local-tpu-image:latest \ + bash -c " + echo 'Successfully running inside the on-the-fly container!' + echo 'Current working directory inside container:' + pwd + echo 'Contents of current directory:' + ls -la + echo 'Verifying JAX installation...' 
+ python3 -c \"import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')\" + " + + From 6d92aa910cc6f5c841638e86b70459bc41f7a498 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 14:23:15 +0530 Subject: [PATCH 45/47] updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests workflow --- .github/workflows/tpu-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index d12fcaa31c85..fc5d0337a8d9 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -57,7 +57,7 @@ jobs: # requirement that a container MUST be specified. The actual tests will # not run in this container. container: - image: ubuntu:22.04 + image: docker:24.0-cli steps: - name: Download Docker image artifact From 3c75bf8d8d9508b4b5b3214d87de8f87d291666c Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 14:38:05 +0530 Subject: [PATCH 46/47] updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests workflow and added a step to install docker client --- .github/workflows/tpu-tests.yml | 84 +++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 35 deletions(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index fc5d0337a8d9..e6bf9ca0afe1 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -1,63 +1,71 @@ -name: TPU Tests +name: TPU Tests (Build with Caching) on: push: branches: [ master ] pull_request: + workflow_dispatch: release: types: [created] -permissions: - contents: read # Only read permission is needed for checkout - packages: write - jobs: - # JOB 1: Build the image and save it as an artifact + # This job now intelligently builds OR restores the image from a cache build-and-save-image: - name: Build Image and Save as Artifact - runs-on: ubuntu-latest # A standard runner is fine for building + name: Build or 
Restore Cached Image + runs-on: ubuntu-latest steps: - name: Checkout Repository uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + - name: Generate cache key from Dockerfile + id: hash_dockerfile + run: echo "hash=$(sha256sum Dockerfile | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT - - name: Build Docker image - id: docker_build + - name: Cache Docker image + id: cache-docker-image + uses: actions/cache@v4 + with: + # The path to the file we want to cache + path: tpu-image.tar + # The key to identify the cache. If the Dockerfile's content changes, + # the hash changes, and a new cache is created. + key: ${{ runner.os }}-docker-${{ steps.hash_dockerfile.outputs.hash }} + # A fallback key in case of a partial match + restore-keys: | + ${{ runner.os }}-docker- + + - name: Build Docker image if cache not found + # This step only runs if the cache step above did not find a perfect match. + if: steps.cache-docker-image.outputs.cache-hit != 'true' uses: docker/build-push-action@v6 with: context: . - # Do not push to a registry push: false - # Load the image into the local Docker daemon so we can save it load: true - # Tag the image for local use tags: local-tpu-image:latest - - name: Save Docker image as a tar file + - name: Save Docker image to tar if cache not found + if: steps.cache-docker-image.outputs.cache-hit != 'true' run: docker save --output tpu-image.tar local-tpu-image:latest - name: Upload Docker image artifact + # This step always runs, uploading either the restored cache or the newly built image uses: actions/upload-artifact@v4 with: name: tpu-image-artifact path: tpu-image.tar - retention-days: 1 # Keep the artifact for only one day + retention-days: 1 - # JOB 2: Load the image from the artifact and run tests + # This job does not need to change. It always consumes the artifact. 
load-and-test: name: Load and Test Image - needs: build-and-save-image # Must wait for Job 1 to finish + needs: build-and-save-image runs-on: linux-x86-ct6e-44-1tpu - - # WORKAROUND: Use a minimal, public image simply to satisfy the runner's - # requirement that a container MUST be specified. The actual tests will - # not run in this container. container: - image: docker:24.0-cli + # Use the Ubuntu image that we know works on your runner + image: ubuntu:22.04 steps: - name: Download Docker image artifact @@ -65,18 +73,31 @@ jobs: with: name: tpu-image-artifact + # --- NEW STEP --- + # Install the Docker CLI inside the bootstrap container at runtime. + - name: Install Docker Client + run: | + apt-get update + apt-get install -y --no-install-recommends ca-certificates curl gnupg + install -m 0755 -d /etc/apt/keyrings + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg + chmod a+r /etc/apt/keyrings/docker.gpg + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \ + tee /etc/apt/sources.list.d/docker.list > /dev/null + apt-get update + apt-get install -y docker-ce-cli + - name: Load Docker image from tar file + # This will now succeed because we just installed the docker client run: docker load --input tpu-image.tar - name: Checkout repository code uses: actions/checkout@v4 - # We need the source code available to mount into the container for testing. - name: Run Tests Inside the Loaded Container run: | - echo "Running tests inside the custom-loaded container..." - - # Manually run the container we loaded from the artifact. docker run --rm \ --workdir /app \ -v "$(pwd)":/app \ @@ -84,19 +105,12 @@ jobs: local-tpu-image:latest \ bash -c " echo 'Successfully running inside the on-the-fly container!' 
- echo 'Current working directory inside container:' - pwd - echo 'Contents of current directory:' - ls -la - echo 'Verifying JAX installation...' python3 -c \"import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')\" " - - # pull-and-use-image: # name: Pull & Use Image from GAR From 1589a759bab52b7e6003ab628ae7ade687b336e0 Mon Sep 17 00:00:00 2001 From: kharshith-k Date: Tue, 29 Jul 2025 14:50:21 +0530 Subject: [PATCH 47/47] added volume mount from host in load-and-test-job --- .github/workflows/tpu-tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index e6bf9ca0afe1..c43b3c1b1a1b 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -66,7 +66,8 @@ jobs: container: # Use the Ubuntu image that we know works on your runner image: ubuntu:22.04 - + volumes: + - /var/run/docker.sock:/var/run/docker.sock steps: - name: Download Docker image artifact uses: actions/download-artifact@v4