Skip to content

Commit a645d94

Browse files
authored
feat: generalize Ops.scatter to handle tracing (#1392)
* feat: generalize Ops.scatter to handle tracing * Update src/TracedRArray.jl
1 parent 24f9acb commit a645d94

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

src/Ops.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,23 +1678,11 @@ instead.
16781678
@assert length(updates) == size(scatter_indices, 1)
16791679
@assert size(scatter_indices, 2) == N
16801680

1681-
updates = convert(TracedRArray{T,1}, updates)
1682-
1683-
update_computation = MLIR.IR.Region()
1684-
block = MLIR.IR.Block(
1685-
[mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],
1686-
[MLIR.IR.Location(), MLIR.IR.Location()],
1687-
)
1688-
return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
1689-
MLIR.IR.rmfromparent!(return_op)
1690-
push!(block, return_op)
1691-
pushfirst!(update_computation, block)
1692-
16931681
return scatter(
1682+
(a, b) -> b,
16941683
[dest],
16951684
scatter_indices,
1696-
[updates];
1697-
update_computation,
1685+
[convert(TracedRArray{T,1}, updates)];
16981686
update_window_dims=Int64[],
16991687
inserted_window_dims=collect(Int64, 1:N),
17001688
input_batching_dims=Int64[],
@@ -1705,6 +1693,36 @@ instead.
17051693
)[1]
17061694
end
17071695

1696+
@noinline function scatter(
1697+
f::F,
1698+
dest::Vector{TracedRArray{T,N}},
1699+
scatter_indices::TracedRArray{Int64},
1700+
updates::Vector{<:TracedRArray{T}};
1701+
location=mlir_stacktrace("scatter", @__FILE__, @__LINE__),
1702+
kwargs...,
1703+
) where {F,T,N}
1704+
sample_inputs = (
1705+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, zero(T)),
1706+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, zero(T)),
1707+
)
1708+
1709+
compiled_fn =
1710+
Reactant.TracedUtils.make_mlir_fn(
1711+
f,
1712+
sample_inputs,
1713+
(),
1714+
"update_computation",
1715+
false;
1716+
args_in_result=:result,
1717+
return_dialect=:stablehlo,
1718+
).f
1719+
update_computation = MLIR.IR.Region()
1720+
MLIR.API.mlirRegionTakeBody(update_computation, MLIR.IR.region(compiled_fn, 1))
1721+
MLIR.IR.rmfromparent!(compiled_fn)
1722+
1723+
return scatter(dest, scatter_indices, updates; update_computation, location, kwargs...)
1724+
end
1725+
17081726
@noinline function scatter(
17091727
dest::Vector{TracedRArray{T,N}},
17101728
scatter_indices::TracedRArray{Int64},

src/TracedRArray.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,11 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
439439
]
440440
updates = Ops.reshape(updates, updates_shape)
441441

442-
# simply set the 2nd block argument as a result
443-
update_computation = MLIR.IR.Region()
444-
block = MLIR.IR.Block(
445-
[Ops.mlir_type(TracedRNumber{T}), Ops.mlir_type(TracedRNumber{T})],
446-
[MLIR.IR.Location(), MLIR.IR.Location()],
447-
)
448-
return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
449-
MLIR.IR.rmfromparent!(return_op)
450-
push!(block, return_op)
451-
pushfirst!(update_computation, block)
452-
453442
res = Ops.scatter(
443+
(xᵢ, xⱼ) -> xⱼ,
454444
[a],
455445
gather_dims.start_indices,
456446
[updates];
457-
update_computation,
458447
update_window_dims=gather_dims.offset_dims,
459448
inserted_window_dims=gather_dims.collapsed_slice_dims,
460449
input_batching_dims=Int64[],

0 commit comments

Comments
 (0)