Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ Fine-grained control over the compilation options for the Reactant compiler.

## Sharding Options

- `shardy_passes`: Defaults to `:to_mhlo_shardings`. Other options are:
- `shardy_passes`: Defaults to `:post_sdy_propagation`. Other options are:
- `:none`: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
- `:post_sdy_propagation`: Runs the Shardy propagation passes. MHLO shardings are
handled by XLA.
Expand Down
15 changes: 11 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,7 @@ function __get_compile_options_and_kwargs(;
raise_first::Bool=false,
legalize_chlo_to_stablehlo::Bool=false,
cudnn_hlo_optimize::Bool=false,
shardy_passes::Union{Symbol,ShardyPropagationOptions}=:to_mhlo_shardings,
shardy_passes::Union{Symbol,ShardyPropagationOptions}=:post_sdy_propagation,
optimize_then_pad::Bool=true,
optimize_communications::Union{Bool,OptimizeCommunicationOptions}=true,
assert_nonallocating::Bool=false,
Expand Down Expand Up @@ -2132,6 +2132,7 @@ function compile_mlir!(
get_optimize_comms_passes(
compile_options.optimize_communications
)...,
"func.func(sdy-reshard-to-collectives)",
],
",",
),
Expand Down Expand Up @@ -2165,6 +2166,7 @@ function compile_mlir!(
get_optimize_comms_passes(
compile_options.optimize_communications
)...,
"func.func(sdy-reshard-to-collectives)",
"xla-sdy-stablehlo-export-pipeline",
],
",",
Expand Down Expand Up @@ -2319,7 +2321,7 @@ function get_common_compile_options()
:client => nothing,
:raise => false,
:raise_first => false,
:shardy_passes => :(:to_mhlo_shardings),
:shardy_passes => :(:post_sdy_propagation),
:assert_nonallocating => false,
:donated_args => :(:auto),
:transpose_propagate => :(:up),
Expand Down Expand Up @@ -2362,7 +2364,10 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
"""
macro code_hlo(args...)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_mlir, get_common_compile_options(), args...
__module__,
compile_mlir,
merge(get_common_compile_options(), Dict{Symbol,Any}(:shardy_passes => :(:none))),
args...,
)
#! format: off
return esc(
Expand Down Expand Up @@ -2391,7 +2396,9 @@ macro code_mhlo(args...)
compile_mlir,
merge(
get_common_compile_options(),
Dict{Symbol,Any}(:legalize_stablehlo_to_mhlo => true),
Dict{Symbol,Any}(
:legalize_stablehlo_to_mhlo => true, :shardy_passes => :(:to_mhlo_shardings)
),
),
args...,
)
Expand Down
7 changes: 2 additions & 5 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,8 @@ function write_to_host_buffer!(data::Array, X::ConcretePJRTArray{T,N}) where {T,
completed = Set{eltype(X.sharding.device_to_array_slices)}()
for idx in 1:length(X.data)
slice = X.sharding.device_to_array_slices[idx]
if slice ∉ completed
push!(completed, slice)
else
continue
end
slice ∈ completed && continue
push!(completed, slice)
data_slice = data[slice...]
XLA.to_host(X.data[idx], data_slice, Reactant.Sharding.NoSharding())
data[slice...] .= data_slice
Expand Down
10 changes: 2 additions & 8 deletions src/xla/IFRT/Array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,7 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
Reactant.Compiler.run_pass_pipeline!(
mod,
join(
[
"sdy-propagation-pipeline",
"sdy-close-shardings",
"xla-sdy-stablehlo-export-pipeline",
"canonicalize",
"cse",
],
["sdy-propagation-pipeline", "sdy-close-shardings", "canonicalize", "cse"],
",",
),
)
Expand All @@ -375,7 +369,7 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
num_partitions=length(mesh.device_ids),
num_outputs=1, # unused
num_parameters=1, # unused
use_shardy_partitioner=false, # unused
use_shardy_partitioner=true, # unused
)

only(XLA.execute(exec, (array.buffer,), (UInt8(0),), Val(1)))
Expand Down
4 changes: 2 additions & 2 deletions test/sharding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ end
@test contains(repr(hlo), "sharding_constraint")
hlo = @code_hlo shardy_passes = :to_mhlo_shardings fn_with_constraint(x_ra)
@test !contains(repr(hlo), "sharding_constraint")
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 6
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5

z = Reactant.to_rarray(x; sharding=constraint)
res = @jit fn_with_constraint(x_ra)
Expand All @@ -234,7 +234,7 @@ end
x_ra_no_sharding
)
@test !contains(repr(hlo), "sharding_constraint")
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 6
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5

res = @jit fn_with_constraint(x_ra_no_sharding)
@test x .+ x ≈ Array(res)
Expand Down
Loading