diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 18009fe625..91c0595b6d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -895,17 +895,24 @@ end # stack function overloaded_stack(dims::Union{Integer,Colon}, xs) - @assert allequal(ndims.(xs)) "All arrays must have the same number of dimensions..." + @assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \ + dimensions..." dims = dims isa Colon ? ndims(first(xs)) + 1 : dims - res = map(xs) do x + res = [] + for x in xs new_shape = ntuple( i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1 ) - return materialize_traced_array(reshape(x, new_shape)) + push!(res, materialize_traced_array(internal_stack_reshape(x, new_shape))) end return cat(res...; dims) end +internal_stack_reshape(x, new_shape) = reshape(x, new_shape) +function internal_stack_reshape(x::TracedRNumber{T}, new_shape) where {T} + return internal_stack_reshape(TracedRArray{T,0}((), x.mlir_data, ()), new_shape) +end + # sort function Base.sort(x::AnyTracedRArray; alg=missing, kwargs...) return sort!(copy(x); alg, kwargs...) diff --git a/test/basic.jl b/test/basic.jl index 36ee00bd16..5451fc2b11 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1531,3 +1531,12 @@ end @test @jit(sum(x_ra; dims=1:2)) ≈ sum(x; dims=1:2) end + +stack_numbers(x) = stack([sum(x[:, i]) for i in axes(x, 2)]) + +@testset "stack numbers" begin + x = rand(Float32, 2, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(stack_numbers(x_ra)) ≈ stack_numbers(x) +end