Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.15"
Reactant_jll = "0.0.230"
Reactant_jll = "0.0.232"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export default defineConfig({
{text: "Local build", link: "/tutorials/local-build"},
{text: "Control Flow", link: "/tutorials/control-flow"},
{text: "Sharding", link: "/tutorials/sharding"},
{text: "Persistent Compilation Cache", link: "/tutorials/persistent_compile_cache"},
],
},
{
Expand Down Expand Up @@ -160,6 +161,7 @@ export default defineConfig({
{ text: "Local build", link: "/tutorials/local-build" },
{ text: "Control Flow", link: "/tutorials/control-flow" },
{ text: "Sharding", link: "/tutorials/sharding" },
{ text: "Persistent Compilation Cache", link: "/tutorials/persistent_compile_cache" },
],
}
],
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ Reactant.addressable_devices
Reactant.ignore_derivatives
```

## Persistent Compilation Cache

```@docs
clear_compilation_cache!
```

## Internal utils

```@docs
Expand Down
10 changes: 6 additions & 4 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Tutorials

- [Profiling](@ref profiling).
- [Multi-Host Environments](@ref distributed).
- [Local build of ReactantExtra](@ref local-build).
- [Control flow](@ref control-flow).
- [Profiling](@ref profiling).
- [Multi-Host Environments](@ref distributed).
- [Local build of ReactantExtra](@ref local-build).
- [Control flow](@ref control-flow).
- [Sharding](@ref sharding).
- [Persistent Compilation Cache](@ref persistent_compile_cache).

We are currently working on adding more tutorials to Reactant!! Please check back soon!
28 changes: 28 additions & 0 deletions docs/src/tutorials/persistent_compile_cache.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# [Persistent Compilation Cache](@id persistent_compile_cache)

Reactant.jl supports a persistent compilation cache that caches compiled and autotuned
kernels on disk. We use [XLA's persisted autotuning](https://openxla.org/xla/persisted_autotuning)
for this purpose. By default, the autotuning cache is enabled.

## Preferences

- `persistent_cache_enabled`: Whether to enable the persistent compilation cache. Defaults
to `true`.
- `persistent_cache_directory`: The base directory to use for the persistent compilation
cache. It is recommended not to set this preference, since Reactant will otherwise create
a unique directory corresponding to XLA and Reactant_jll's version. If this preference is
set, it is the user's responsibility to ensure that the directory exists, is writable,
and is segregated based on XLA and Reactant_jll's version.
Defaults to `""`.
- `persistent_kernel_cache_enabled`: Whether to enable the kernel cache. Defaults to `false`.
- `persistent_autotune_cache_enabled`: Whether to enable the autotuning cache. Defaults to
`true`.

## Clearing the cache

To clear the cache, you can use [`Reactant.clear_compilation_cache!`](@ref):

```julia
using Reactant
clear_compilation_cache!()
```
91 changes: 91 additions & 0 deletions src/PersistentCompileCache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
    PersistentCompileCache

Configuration of the on-disk compilation cache. The module reads the
`persistent_cache_enabled`, `persistent_cache_directory`,
`persistent_kernel_cache_enabled` and `persistent_autotune_cache_enabled` preferences
when it is initialised and exposes the resulting settings through
`kernel_cache_enabled`, `get_kernel_cache_path`, `autotune_cache_enabled` and
`get_autotune_cache_directory`.
"""
module PersistentCompileCache

using ..Reactant: Reactant

using Preferences: load_preference
using Scratch: @get_scratch!
using Reactant_jll: Reactant_jll

# Base directory of the persistent cache; stays `nothing` while the cache is disabled.
const CACHE_DIR = Ref{Union{Nothing,String}}(nothing)
const KERNEL_CACHE_ENABLED = Ref(false)
const AUTOTUNE_CACHE_ENABLED = Ref(false)

# Name prefix shared by every scratch space that holds a (versioned) cache.
const SCRATCH_KEY_PREFIX = "xla_persistent_cache"

# Scratch-space key of the cache that belongs to the loaded Reactant_jll version.
# We version our cache directory based on Reactant_jll version (technically we need to
# version according to XLA, but this is a good enough proxy).
function default_scratch_key()
    version = pkgversion(Reactant_jll)
    return "$(SCRATCH_KEY_PREFIX)_$(version.major)_$(version.minor)_$(version.patch)"
end

# Reads the cache-related preferences and fills in `CACHE_DIR`, `KERNEL_CACHE_ENABLED`
# and `AUTOTUNE_CACHE_ENABLED`. When the persistent cache is disabled all three keep
# their initial values (`nothing` / `false`).
function __init__()
    persistent_cache_enabled = load_preference(Reactant, "persistent_cache_enabled", true)
    persistent_cache_directory = load_preference(Reactant, "persistent_cache_directory", "")

    if !persistent_cache_enabled
        @debug "Persistent compilation cache disabled..."
        return nothing
    end

    if isempty(persistent_cache_directory)
        CACHE_DIR[] = @get_scratch!(default_scratch_key())
    else
        # A user-provided directory is used verbatim.
        CACHE_DIR[] = persistent_cache_directory
    end
    @debug "Persistent compilation cache enabled. Using base directory: $(CACHE_DIR[])"

    KERNEL_CACHE_ENABLED[] = load_preference(
        Reactant, "persistent_kernel_cache_enabled", false
    )
    @debug "Kernel cache enabled: $(KERNEL_CACHE_ENABLED[])"

    AUTOTUNE_CACHE_ENABLED[] = load_preference(
        Reactant, "persistent_autotune_cache_enabled", true
    )
    @debug "Autotune cache enabled: $(AUTOTUNE_CACHE_ENABLED[])"

    return nothing
end

"""
    kernel_cache_enabled()

Return `true` when the kernel cache preference is enabled and a cache base directory has
been configured.
"""
function kernel_cache_enabled()
    return KERNEL_CACHE_ENABLED[] && CACHE_DIR[] !== nothing
end

"""
    get_kernel_cache_path()

Return the path of the kernel cache file inside the cache base directory, or `""` when
the kernel cache is disabled. The base directory is created if it does not exist.
"""
function get_kernel_cache_path()
    kernel_cache_enabled() || return ""
    # The base directory may have been removed by `clear_compilation_cache!`; recreate it
    # so that the returned file path always has an existing parent directory.
    mkpath(CACHE_DIR[])
    return joinpath(CACHE_DIR[], "xla_gpu_kernel_cache_file")
end

"""
    autotune_cache_enabled()

Return `true` when the autotune cache preference is enabled and a cache base directory
has been configured.
"""
function autotune_cache_enabled()
    return AUTOTUNE_CACHE_ENABLED[] && CACHE_DIR[] !== nothing
end

"""
    get_autotune_cache_directory()

Return the autotune cache directory inside the cache base directory, or `""` when the
autotune cache is disabled. The directory is created if it does not exist.
"""
function get_autotune_cache_directory()
    autotune_cache_enabled() || return ""
    dir = joinpath(CACHE_DIR[], "xla_gpu_per_fusion_autotune_cache_dir/")
    mkpath(dir)
    return dir
end

"""
    clear_compilation_cache!()

Deletes the compilation cache directory. This removes all cached compilation artifacts for
all past versions of Reactant_jll. If a custom `persistent_cache_directory` is in use,
that directory is deleted as well.
"""
function clear_compilation_cache!()
    cache_dir = CACHE_DIR[]
    cache_dir === nothing || rm(cache_dir; recursive=true, force=true)

    # Locate the scratch root through the current version's cache space. That space
    # matches the prefix below, so it is removed by the loop and no auxiliary scratch
    # directory is left behind.
    scratch_root = dirname(@get_scratch!(default_scratch_key()))
    for dir in readdir(scratch_root; join=true)
        if isdir(dir) && startswith(basename(dir), SCRATCH_KEY_PREFIX)
            @debug "Removing cache directory: $dir"
            rm(dir; recursive=true, force=true)
        end
    end

    return nothing
end

export clear_compilation_cache!

end

using .PersistentCompileCache

export clear_compilation_cache!
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export @allowscalar # re-exported from GPUArraysCore

is_extension_loaded(::Val) = false

include("PersistentCompileCache.jl")

# auxiliary types and functions
include("OrderedIdDict.jl")

Expand Down
5 changes: 5 additions & 0 deletions src/xla/IFRT/LoadedExecutable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ function XLA.compile(
num_replicas::Int64,
num_partitions::Int64,
is_sharded::Bool,
Reactant.PersistentCompileCache.kernel_cache_enabled()::Bool,
Reactant.PersistentCompileCache.get_kernel_cache_path()::Cstring,
Reactant.PersistentCompileCache.autotune_cache_enabled()::Bool,
Reactant.PersistentCompileCache.get_autotune_cache_directory()::Cstring,
Reactant.Distributed.local_rank()::Cint,
)::Ptr{Cvoid}
end
end
Expand Down
5 changes: 5 additions & 0 deletions src/xla/PJRT/LoadedExecutable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ function XLA.compile(
num_replicas::Int64,
num_partitions::Int64,
is_sharded::Bool,
Reactant.PersistentCompileCache.kernel_cache_enabled()::Bool,
Reactant.PersistentCompileCache.get_kernel_cache_path()::Cstring,
Reactant.PersistentCompileCache.autotune_cache_enabled()::Bool,
Reactant.PersistentCompileCache.get_autotune_cache_directory()::Cstring,
Reactant.Distributed.local_rank()::Cint,
)::Ptr{Cvoid}
end
end
Expand Down
14 changes: 14 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1510,3 +1510,17 @@ end
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y)
end
end

@testset "compilation cache" begin
    cache = Reactant.PersistentCompileCache
    # The check only runs when the autotune cache is enabled and the first device
    # identifies itself as CUDA; otherwise the testset is empty.
    if cache.autotune_cache_enabled() && contains(string(Reactant.devices()[1]), "CUDA")
        lhs = Reactant.to_rarray(rand(Float32, 2, 5))
        rhs = Reactant.to_rarray(rand(Float32, 5, 1000))
        @jit lhs * rhs # This should populate the cache dir

        cached_files = readdir(cache.get_autotune_cache_directory())
        @test any(endswith(".textproto"), cached_files)
    end
end
Loading