Skip to content

Commit 9795fa6

Browse files
authored
fix: stack for numbers (#1576)
1 parent a3be565 commit 9795fa6

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/TracedRArray.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,17 +933,24 @@ end
933933

934934
# stack
935935
function overloaded_stack(dims::Union{Integer,Colon}, xs)
936-
@assert allequal(ndims.(xs)) "All arrays must have the same number of dimensions..."
936+
@assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \
937+
dimensions..."
937938
dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
938-
res = map(xs) do x
939+
res = []
940+
for x in xs
939941
new_shape = ntuple(
940942
i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
941943
)
942-
return materialize_traced_array(reshape(x, new_shape))
944+
push!(res, materialize_traced_array(internal_stack_reshape(x, new_shape)))
943945
end
944946
return cat(res...; dims)
945947
end
946948

949+
internal_stack_reshape(x, new_shape) = reshape(x, new_shape)
950+
function internal_stack_reshape(x::TracedRNumber{T}, new_shape) where {T}
951+
return internal_stack_reshape(TracedRArray{T,0}((), x.mlir_data, ()), new_shape)
952+
end
953+
947954
# sort
948955
function Base.sort(x::AnyTracedRArray; alg=missing, kwargs...)
949956
return sort!(copy(x); alg, kwargs...)

test/basic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,3 +1531,12 @@ end
15311531

15321532
@test @jit(sum(x_ra; dims=1:2)) sum(x; dims=1:2)
15331533
end
1534+
1535+
stack_numbers(x) = stack([sum(x[:, i]) for i in axes(x, 2)])
1536+
1537+
@testset "stack numbers" begin
1538+
x = rand(Float32, 2, 4)
1539+
x_ra = Reactant.to_rarray(x)
1540+
1541+
@test @jit(stack_numbers(x_ra)) stack_numbers(x)
1542+
end

0 commit comments

Comments
 (0)