Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deps/ReactantExtra/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ build --announce_rc

# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260
common --noincompatible_enable_cc_toolchain_resolution
common --repo_env USE_HERMETIC_CC_TOOLCHAIN=0
common --experimental_repo_remote_exec
common --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
common --cxxopt=-w --host_cxxopt=-w
common --define=grpc_no_ares=true
common --noenable_bzlmod


build --repo_env=USE_PYWRAP_RULES=True
build --copt=-DGRPC_BAZEL_BUILD
build --host_copt=-DGRPC_BAZEL_BUILD
Expand All @@ -27,6 +29,7 @@ build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
Expand All @@ -35,6 +38,8 @@ build:cuda --@local_config_cuda//:enable_cuda
# Default hermetic CUDA and CUDNN versions.
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
# build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
# build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true

build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --define=using_rocm=true
Expand Down
16 changes: 4 additions & 12 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_toolchain_config")

# load("//toolchain:yggdrasil.bzl", "ygg_cc_toolchain")
Expand Down Expand Up @@ -793,12 +794,7 @@ cc_library(
"-Werror=return-type",
"-Werror=unused-result",
"-Wno-error=stringop-truncation",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss": [
"-DREACTANT_CUDA=1",
],
"//conditions:default": [],
}),
] + if_cuda(["-DREACTANT_CUDA=1"]),
linkopts = select({
"//conditions:default": [],
"@bazel_tools//src/conditions:darwin": [
Expand Down Expand Up @@ -1048,8 +1044,7 @@ cc_library(
"@xla//xla/tsl/platform:errors",
"@xla//xla/service:hlo_proto_cc_impl",
"@com_google_absl//absl/status:statusor",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss": [
] + if_cuda([
"@jax//jaxlib/cuda:cuda_gpu_kernels",
"@xla//xla/backends/profiler:profiler_backends",
"@xla//xla/backends/profiler/gpu:device_tracer",
Expand All @@ -1061,10 +1056,7 @@ cc_library(
"@xla//xla/stream_executor:cuda_platform",
"@xla//xla/stream_executor:kernel",
"@xla//xla/stream_executor/cuda:all_runtime",
],
"//conditions:default": [
],
}) + if_rocm([
]) + if_rocm([
"@xla//xla/stream_executor:rocm_platform",
"@xla//xla/service/gpu:amdgpu_compiler",
"@xla//xla/backends/profiler/gpu:device_tracer",
Expand Down
40 changes: 25 additions & 15 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "05b4c93ab2f581eae42f44599a1993311ad72c63"
ENZYMEXLA_COMMIT = "fec4f6a25c046ff6acc29161656745fefa2fccda"

ENZYMEXLA_SHA256 = ""

Expand Down Expand Up @@ -286,19 +286,36 @@ load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load(
"@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
"cc_toolchain_deps",
)

cc_toolchain_deps()

register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")

register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")

load(
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)

cuda_json_init_repository()

load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)

load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
Expand All @@ -312,28 +329,28 @@ cudnn_redist_init_repository(
)

load(
"@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)

cuda_configure(name = "local_config_cuda")

load(
"@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
"@xla//third_party/nccl/hermetic:nccl_configure.bzl",
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)

nccl_configure(name = "local_config_nccl")

load(
"@xla//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
"@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
"nvshmem_json_init_repository",
)

Expand All @@ -344,17 +361,10 @@ load(
"NVSHMEM_REDISTRIBUTIONS",
)
load(
"@xla//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
"@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
"nvshmem_redist_init_repository",
)

nvshmem_redist_init_repository(
nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
)

load(
"@xla//third_party/nvshmem/hermetic:nvshmem_configure.bzl",
"nvshmem_configure",
)

nvshmem_configure(name = "local_config_nvshmem")
Loading