From 4660547b0b4e1086987e0858d312cc78d995b992 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Apr 2021 00:03:04 -0400 Subject: [PATCH 1/3] branch, take 1 --- docs/src/basics.md | 17 +++++++++++++++++ src/macro.jl | 18 ++++++++++++------ test/four.jl | 5 +++-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/docs/src/basics.md b/docs/src/basics.md index 944c8f7..eaf60ea 100644 --- a/docs/src/basics.md +++ b/docs/src/basics.md @@ -345,6 +345,23 @@ true ``` This one can also be done as `reinterpret(reshape, Tri{Int64}, M)`. +But what would be smarter in the general case is to do one splat, not many: + +```julia-repl +julia> Tri.(eachrow(M)...) +4-element Vector{Tri{Int64}}: + Tri{Int64}(1, 2, 3) + Tri{Int64}(4, 5, 6) + Tri{Int64}(7, 8, 9) + Tri{Int64}(10, 11, 12) + +julia> @btime Base.splat(tuple).(eachcol(m)) setup=(m=rand(4,100)); + 38.041 μs (1411 allocations: 48.33 KiB) + +julia> @btime tuple.(eachrow(m)...) setup=(m=rand(4,100)); + 824.256 ns (12 allocations: 4.06 KiB) +``` + ## Arrays of functions Besides arrays of numbers (and arrays of arrays) you can also broadcast an array of functions, diff --git a/src/macro.jl b/src/macro.jl index 249956c..7a0d023 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -521,12 +521,6 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) @capture(ex, funs_[ijk__](args__) ) && return :( Core._apply($funs[$(ijk...)], $(args...) ) ) # splats - @capture(ex, fun_(pre__, arg_...)) && containsindexing(arg) && begin - @gensym splat ys - xs = [gensym(Symbol(:x, i)) for i in 1:length(pre)] - push!(store.main, :( local $splat($(xs...), $ys) = $fun($(xs...), $ys...) )) - return :( $splat($(pre...), $arg) ) - end # Apart from those, readycast acts only on lone tensors: @capture(ex, A_[ijk__]) || return ex @@ -658,6 +652,18 @@ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) ex = scalar ? :($name) : :($name[$(ind...)]) end + # Handle splatted slices -- need to be caught early + if @capture(ex, fun_(A_[ijk__]...)) && any(iscolon, ijk) && begin + # println("splat!") + revcode = map(i -> iscolon(i) ? :* : :(:), ijk) + fex = :( $fun.(TensorCast.sliceview($A, ($(revcode...),))...) ) + fA = maybepush(fex, store) + ijknew = filter(!iscolon, ijk) + newex = :($fA[$(ijknew...)]) + # @show ex newex + return newex + end + end # Tidy up indices, A[i,j][k] will be hit on different rounds... if @capture(ex, A_[ijk__]) if !(A isa Symbol) # this check allows some tests which have c[c] etc. diff --git a/test/four.jl b/test/four.jl index 229a19d..9def167 100644 --- a/test/four.jl +++ b/test/four.jl @@ -144,8 +144,9 @@ end @test z′[3,4] == Tuple(y[4,3,:]) # with other arguments, they have to come first at the moment: - @cast z2[j,i] := tuple(i, j, y[i,j,:]...) - @test z2[4,5] == (5, 4, y[5,4,:]...) + @test_skip @cast z2[j,i] := tuple(i, j, y[i,j,:]...) + # @test z2[4,5] == (5, 4, y[5,4,:]...) + struct Quad x; y; z; t; end @cast z3[i,j] := Quad(y[i,:,j]...) From b36d9ace83e7ce525517e4b52ba8ca597622e442 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Apr 2021 01:39:35 -0400 Subject: [PATCH 2/3] an attempt --- src/macro.jl | 23 +++++++++++++---------- test/four.jl | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/macro.jl b/src/macro.jl index 7a0d023..f190283 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -520,7 +520,6 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) # and arrays of functions, using apply: @capture(ex, funs_[ijk__](args__) ) && return :( Core._apply($funs[$(ijk...)], $(args...) ) ) - # splats # Apart from those, readycast acts only on lone tensors: @capture(ex, A_[ijk__]) || return ex @@ -653,17 +652,21 @@ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) end # Handle splatted slices -- need to be caught early - if @capture(ex, fun_(A_[ijk__]...)) && any(iscolon, ijk) && begin - # println("splat!") - revcode = map(i -> iscolon(i) ? :* : :(:), ijk) - fex = :( $fun.(TensorCast.sliceview($A, ($(revcode...),))...) ) - fA = maybepush(fex, store) - ijknew = filter(!iscolon, ijk) - newex = :($fA[$(ijknew...)]) - # @show ex newex - return newex + if @capture(ex, fun_(args__)) && any(a -> @capture(a, (A_[ijk__]...)), args) && any(iscolon, ijk) + newargs = map(args) do arg + if @capture(arg, (A_[ijk__]...)) && any(iscolon, ijk) + revcode = map(i -> iscolon(i) ? :* : :(:), ijk) + sliced = :( TensorCast.sliceview($A, ($(revcode...),)) ) + sym = maybepush(sliced, store) + indpost = filter(!iscolon, ijk) + :(($sym[$(indpost...)])...) + else + arg + end end + return :( $fun($(newargs...)) ) end + # Tidy up indices, A[i,j][k] will be hit on different rounds... if @capture(ex, A_[ijk__]) if !(A isa Symbol) # this check allows some tests which have c[c] etc. diff --git a/test/four.jl b/test/four.jl index 9def167..2ed6dee 100644 --- a/test/four.jl +++ b/test/four.jl @@ -144,8 +144,8 @@ end @test z′[3,4] == Tuple(y[4,3,:]) # with other arguments, they have to come first at the moment: - @test_skip @cast z2[j,i] := tuple(i, j, y[i,j,:]...) - # @test z2[4,5] == (5, 4, y[5,4,:]...) + @cast z2[j,i] := tuple(i, j, y[i,j,:]...) + @test z2[4,5] == (5, 4, y[5,4,:]...) struct Quad x; y; z; t; end From 75fcfefe7ccfd8682e1546901ca0d9ae5cbe9991 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Apr 2021 12:51:28 -0400 Subject: [PATCH 3/3] another --- src/macro.jl | 32 ++++++++++++++++++++++++-------- test/four.jl | 1 - 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/macro.jl b/src/macro.jl index f190283..f87ab09 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -620,11 +620,13 @@ end """ recursemacro(@reduce sum(i) A[i,j]) -> G[j] -Walks itself over RHS to look for `@reduce ...`, and replace with result, +Walks itself over RHS, originally to look for `@reduce ...`, and replace with result, pushing calculation steps into store. -Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`. -And to handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that. +Starts from the outside and works in, which makes it useful for other things: +* Handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that, stopping before this sees `A[i]`. +* Catch splats so that `f(M[:,c]...)` can become `f.(eachrow(M)...)` not `(splat(f)).(eachcol(M))`. +* Tidy all indices, including e.g. `fun(M[:,j], N[j]).same[i']`. """ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) @@ -651,17 +653,31 @@ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) ex = scalar ? :($name) : :($name[$(ind...)]) end - # Handle splatted slices -- need to be caught early + # Handle splatted slices -- walking from inside outwards would slice the wrong way. if @capture(ex, fun_(args__)) && any(a -> @capture(a, (A_[ijk__]...)), args) && any(iscolon, ijk) newargs = map(args) do arg if @capture(arg, (A_[ijk__]...)) && any(iscolon, ijk) - revcode = map(i -> iscolon(i) ? :* : :(:), ijk) - sliced = :( TensorCast.sliceview($A, ($(revcode...),)) ) - sym = maybepush(sliced, store) indpost = filter(!iscolon, ijk) + if indexin(indpost, canon) == 1:length(indpost) + Aperm = A + revcode = map(i -> iscolon(i) ? :* : :(:), ijk) + else + perm = indexin(canon, ijk) + while isnothing(last(perm)) # trim nothings off end + pop!(perm) + end + indpost = canon[1:length(perm)] + revcode = vcat(map(_ -> :*, perm), fill(:(:), count(iscolon, ijk))) + for (d,i) in enumerate(ijk) # append positions of colons + iscolon(i) && push!(perm, d) + end + Aperm = :( TensorCast.transmute($A, $(Tuple(perm))) ) + end + sliced = :( TensorCast.sliceview($Aperm, ($(revcode...),)) ) + sym = maybepush(sliced, store) :(($sym[$(indpost...)])...) else - arg + recursemacro(arg, canon, store, call) end end return :( $fun($(newargs...)) ) diff --git a/test/four.jl b/test/four.jl index 2ed6dee..229a19d 100644 --- a/test/four.jl +++ b/test/four.jl @@ -147,7 +147,6 @@ end @cast z2[j,i] := tuple(i, j, y[i,j,:]...) @test z2[4,5] == (5, 4, y[5,4,:]...) - struct Quad x; y; z; t; end @cast z3[i,j] := Quad(y[i,:,j]...) @test z3[2,3] == Quad(y[2,:,3]...)