Patch summary (free text before the first "diff --git" header; patch tools skip it).

  * .bazelrc:  sets USE_HERMETIC_CC_TOOLCHAIN=0 and HERMETIC_NVSHMEM_VERSION, and
               adds two commented-out nvshmem flags.
  * BUILD:     replaces the select() on @xla//xla/tsl:is_cuda_enabled_and_oss with
               if_cuda() loaded from @local_config_cuda//cuda:build_defs.bzl.
  * WORKSPACE: bumps ENZYMEXLA_COMMIT, loads the hermetic CUDA / NCCL / NVSHMEM
               repository rules from @rules_ml_toolchain instead of @xla, registers
               the lx64_lx64 and lx64_lx64_cuda toolchains, and drops the
               nvshmem_configure(name = "local_config_nvshmem") call.

NOTE(review): this patch was restored from a whitespace-collapsed copy. Blank
context lines and the indentation of context lines were reconstructed from the
hunk line counts and conventional Bazel formatting; verify against the target
tree (or apply with whitespace-tolerant options) before relying on it.

diff --git a/deps/ReactantExtra/.bazelrc b/deps/ReactantExtra/.bazelrc
index 48c3908d12..e8163cd0af 100644
--- a/deps/ReactantExtra/.bazelrc
+++ b/deps/ReactantExtra/.bazelrc
@@ -2,12 +2,14 @@ build --announce_rc
 
 # TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260
 common --noincompatible_enable_cc_toolchain_resolution
+common --repo_env USE_HERMETIC_CC_TOOLCHAIN=0
 common --experimental_repo_remote_exec
 common --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
 common --cxxopt=-w --host_cxxopt=-w
 common --define=grpc_no_ares=true
 common --noenable_bzlmod
 
+
 build --repo_env=USE_PYWRAP_RULES=True
 build --copt=-DGRPC_BAZEL_BUILD
 build --host_copt=-DGRPC_BAZEL_BUILD
@@ -27,6 +29,7 @@ build:cuda --repo_env TF_NVCC_CLANG=1
 build:cuda --repo_env TF_NCCL_USE_STUB=1
 build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
 build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
+build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
 # "sm" means we emit only cubin, which is forward compatible within a GPU generation.
 # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
 build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
@@ -35,6 +38,8 @@ build:cuda --@local_config_cuda//:enable_cuda
 # Default hermetic CUDA and CUDNN versions.
 build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
 build:cuda --@local_config_cuda//:cuda_compiler=nvcc
+# build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
+# build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true
 
 build:rocm --repo_env TF_NEED_ROCM=1
 build:rocm --define=using_rocm=true
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index c2f8ede46f..79225856aa 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -1,5 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_toolchain_config")
 
 # load("//toolchain:yggdrasil.bzl", "ygg_cc_toolchain")
@@ -793,12 +794,7 @@ cc_library(
         "-Werror=return-type",
         "-Werror=unused-result",
         "-Wno-error=stringop-truncation",
-    ] + select({
-        "@xla//xla/tsl:is_cuda_enabled_and_oss": [
-            "-DREACTANT_CUDA=1",
-        ],
-        "//conditions:default": [],
-    }),
+    ] + if_cuda(["-DREACTANT_CUDA=1"]),
     linkopts = select({
         "//conditions:default": [],
         "@bazel_tools//src/conditions:darwin": [
@@ -1048,8 +1044,7 @@ cc_library(
         "@xla//xla/tsl/platform:errors",
         "@xla//xla/service:hlo_proto_cc_impl",
         "@com_google_absl//absl/status:statusor",
-    ] + select({
-        "@xla//xla/tsl:is_cuda_enabled_and_oss": [
+    ] + if_cuda([
             "@jax//jaxlib/cuda:cuda_gpu_kernels",
             "@xla//xla/backends/profiler:profiler_backends",
             "@xla//xla/backends/profiler/gpu:device_tracer",
@@ -1061,10 +1056,7 @@ cc_library(
             "@xla//xla/stream_executor:cuda_platform",
             "@xla//xla/stream_executor:kernel",
             "@xla//xla/stream_executor/cuda:all_runtime",
-        ],
-        "//conditions:default": [
-        ],
-    }) + if_rocm([
+    ]) + if_rocm([
         "@xla//xla/stream_executor:rocm_platform",
         "@xla//xla/service/gpu:amdgpu_compiler",
         "@xla//xla/backends/profiler/gpu:device_tracer",
diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index e554dc5ebd..0f35b5f770 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -11,7 +11,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "05b4c93ab2f581eae42f44599a1993311ad72c63"
+ENZYMEXLA_COMMIT = "fec4f6a25c046ff6acc29161656745fefa2fccda"
 
 ENZYMEXLA_SHA256 = ""
 
@@ -286,19 +286,30 @@ load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
 flatbuffers()
 
 load(
-    "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
+    "@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
+    "cc_toolchain_deps",
+)
+
+cc_toolchain_deps()
+
+register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")
+
+register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")
+
+load(
+    "@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
     "cuda_json_init_repository",
 )
 
 cuda_json_init_repository()
 
 load(
     "@cuda_redist_json//:distributions.bzl",
     "CUDA_REDISTRIBUTIONS",
     "CUDNN_REDISTRIBUTIONS",
 )
 load(
-    "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
+    "@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
     "cuda_redist_init_repositories",
     "cudnn_redist_init_repository",
 )
@@ -312,28 +323,28 @@ cudnn_redist_init_repository(
 )
 
 load(
-    "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
+    "@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
     "cuda_configure",
 )
 
 cuda_configure(name = "local_config_cuda")
 
 load(
-    "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
+    "@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
     "nccl_redist_init_repository",
 )
 
 nccl_redist_init_repository()
 
 load(
-    "@xla//third_party/nccl/hermetic:nccl_configure.bzl",
+    "@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
     "nccl_configure",
 )
 
 nccl_configure(name = "local_config_nccl")
 
 load(
-    "@xla//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
+    "@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
     "nvshmem_json_init_repository",
 )
 
@@ -344,17 +355,10 @@ load(
     "NVSHMEM_REDISTRIBUTIONS",
 )
 load(
-    "@xla//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
+    "@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
     "nvshmem_redist_init_repository",
 )
 
 nvshmem_redist_init_repository(
     nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
 )
-
-load(
-    "@xla//third_party/nvshmem/hermetic:nvshmem_configure.bzl",
-    "nvshmem_configure",
-)
-
-nvshmem_configure(name = "local_config_nvshmem")