@@ -1678,23 +1678,11 @@ instead.
1678
1678
@assert length (updates) == size (scatter_indices, 1 )
1679
1679
@assert size (scatter_indices, 2 ) == N
1680
1680
1681
- updates = convert (TracedRArray{T,1 }, updates)
1682
-
1683
- update_computation = MLIR. IR. Region ()
1684
- block = MLIR. IR. Block (
1685
- [mlir_type (TracedRNumber{T}), mlir_type (TracedRNumber{T})],
1686
- [MLIR. IR. Location (), MLIR. IR. Location ()],
1687
- )
1688
- return_op = MLIR. Dialects. stablehlo. return_ ([MLIR. IR. argument (block, 2 )])
1689
- MLIR. IR. rmfromparent! (return_op)
1690
- push! (block, return_op)
1691
- pushfirst! (update_computation, block)
1692
-
1693
1681
return scatter (
1682
+ (a, b) -> b,
1694
1683
[dest],
1695
1684
scatter_indices,
1696
- [updates];
1697
- update_computation,
1685
+ [convert (TracedRArray{T,1 }, updates)];
1698
1686
update_window_dims= Int64[],
1699
1687
inserted_window_dims= collect (Int64, 1 : N),
1700
1688
input_batching_dims= Int64[],
@@ -1705,6 +1693,36 @@ instead.
1705
1693
)[1 ]
1706
1694
end
1707
1695
1696
+ @noinline function scatter (
1697
+ f:: F ,
1698
+ dest:: Vector{TracedRArray{T,N}} ,
1699
+ scatter_indices:: TracedRArray{Int64} ,
1700
+ updates:: Vector{<:TracedRArray{T}} ;
1701
+ location= mlir_stacktrace (" scatter" , @__FILE__ , @__LINE__ ),
1702
+ kwargs... ,
1703
+ ) where {F,T,N}
1704
+ sample_inputs = (
1705
+ Reactant. TracedUtils. promote_to (TracedRNumber{T}, zero (T)),
1706
+ Reactant. TracedUtils. promote_to (TracedRNumber{T}, zero (T)),
1707
+ )
1708
+
1709
+ compiled_fn =
1710
+ Reactant. TracedUtils. make_mlir_fn (
1711
+ f,
1712
+ sample_inputs,
1713
+ (),
1714
+ " update_computation" ,
1715
+ false ;
1716
+ args_in_result= :result ,
1717
+ return_dialect= :stablehlo ,
1718
+ ). f
1719
+ update_computation = MLIR. IR. Region ()
1720
+ MLIR. API. mlirRegionTakeBody (update_computation, MLIR. IR. region (compiled_fn, 1 ))
1721
+ MLIR. IR. rmfromparent! (compiled_fn)
1722
+
1723
+ return scatter (dest, scatter_indices, updates; update_computation, location, kwargs... )
1724
+ end
1725
+
1708
1726
@noinline function scatter (
1709
1727
dest:: Vector{TracedRArray{T,N}} ,
1710
1728
scatter_indices:: TracedRArray{Int64} ,
0 commit comments