diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index f7c6ed10a..48bda9a36 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -26,7 +26,6 @@ cc_library( srcs = ["gpu.cc"], copts = ["-Wno-vla-cxx-extension"], deps = [ - "@llvm-project//mlir:CAPIIR", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", ], @@ -36,7 +35,6 @@ cc_library( name = "cpu", srcs = ["cpu.cc"], deps = [ - "@llvm-project//mlir:CAPIIR", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", ], @@ -603,7 +601,6 @@ cc_library( "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorToSCF", "@llvm-project//mlir:ViewLikeInterface", - "@llvm-project//mlir:CAPIIR", "@shardy//shardy/dialect/sdy/ir:dialect", "@shardy//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_builder", "@stablehlo//:base", diff --git a/src/enzyme_ad/jax/Passes/LowerJIT.cpp b/src/enzyme_ad/jax/Passes/LowerJIT.cpp index 018376114..9df0cda90 100644 --- a/src/enzyme_ad/jax/Passes/LowerJIT.cpp +++ b/src/enzyme_ad/jax/Passes/LowerJIT.cpp @@ -80,7 +80,21 @@ #include "mlir/Target/LLVMIR/Export.h" -#include "mlir-c/Support.h" +#if (defined(_WIN32) || defined(__CYGWIN__)) && \ + !defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC) +// Visibility annotations disabled. +#define MLIR_CAPI_EXPORTED +#elif defined(_WIN32) || defined(__CYGWIN__) +// Windows visibility declarations. +#if MLIR_CAPI_BUILDING_LIBRARY +#define MLIR_CAPI_EXPORTED __declspec(dllexport) +#else +#define MLIR_CAPI_EXPORTED __declspec(dllimport) +#endif +#else +// Non-windows: use visibility attributes. +#define MLIR_CAPI_EXPORTED __attribute__((visibility("default"))) +#endif #define DEBUG_TYPE "lower-jit" diff --git a/src/enzyme_ad/jax/cpu.cc b/src/enzyme_ad/jax/cpu.cc index 78d66f0a6..48451da3e 100644 --- a/src/enzyme_ad/jax/cpu.cc +++ b/src/enzyme_ad/jax/cpu.cc @@ -2,7 +2,21 @@ #include "xla/service/custom_call_target_registry.h" #include -#include "mlir-c/Support.h" +#if (defined(_WIN32) || defined(__CYGWIN__)) && \ + !defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC) +// Visibility annotations disabled. +#define MLIR_CAPI_EXPORTED +#elif defined(_WIN32) || defined(__CYGWIN__) +// Windows visibility declarations. +#if MLIR_CAPI_BUILDING_LIBRARY +#define MLIR_CAPI_EXPORTED __declspec(dllexport) +#else +#define MLIR_CAPI_EXPORTED __declspec(dllimport) +#endif +#else +// Non-windows: use visibility attributes. +#define MLIR_CAPI_EXPORTED __attribute__((visibility("default"))) +#endif template struct CallInfo; diff --git a/src/enzyme_ad/jax/gpu.cc b/src/enzyme_ad/jax/gpu.cc index 5a84d6a72..63b7dfb25 100644 --- a/src/enzyme_ad/jax/gpu.cc +++ b/src/enzyme_ad/jax/gpu.cc @@ -1,7 +1,21 @@ #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" -#include "mlir-c/Support.h" +#if (defined(_WIN32) || defined(__CYGWIN__)) && \ + !defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC) +// Visibility annotations disabled. +#define MLIR_CAPI_EXPORTED +#elif defined(_WIN32) || defined(__CYGWIN__) +// Windows visibility declarations. +#if MLIR_CAPI_BUILDING_LIBRARY +#define MLIR_CAPI_EXPORTED __declspec(dllexport) +#else +#define MLIR_CAPI_EXPORTED __declspec(dllimport) +#endif +#else +// Non-windows: use visibility attributes. +#define MLIR_CAPI_EXPORTED __attribute__((visibility("default"))) +#endif template struct CallInfo;