diff --git a/Project.toml b/Project.toml
index d852c27..cec5f57 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,21 +5,21 @@ version = "0.2.1"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
-LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 Compat = "2.2, 3"
-LazyArrays = "0.12, 0.13, 0.14, 0.15, 0.16"
 LazyStack = "0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8"
 MacroTools = "0.5"
 OffsetArrays = "0.11, 1.0"
+Requires = "0.5, 1"
 StaticArrays = "0.10, 0.11, 0.12"
 ZygoteRules = "0.1, 0.2"
 julia = "1"
@@ -28,10 +28,11 @@ julia = "1"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
 JuliennedArrays = "5cadff95-7770-533d-a838-a1bf817ee6e0"
+LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Compat", "Einsum", "JuliennedArrays", "LoopVectorization", "Statistics", "Strided"]
+test = ["Test", "Compat", "Einsum", "JuliennedArrays", "LazyArrays", "LoopVectorization", "Statistics", "Strided"]
diff --git a/docs/src/options.md b/docs/src/options.md
index d23b32b..322f9fc 100644
--- a/docs/src/options.md
+++ b/docs/src/options.md
@@ -79,11 +79,13 @@
 In the following example, the product `V .* V' .* V3` contains about 1GB of data
 the writing of which is avoided by giving the option `lazy`:
 
 ```julia
+using LazyArrays # you must now load this package
 V = rand(500); V3 = reshape(V,1,1,:);
 @time @reduce W[i] := sum(j,k) V[i]*V[j]*V[k];        # 0.6 seconds, 950 MB
 @time @reduce W[i] := sum(j,k) V[i]*V[j]*V[k]  lazy;  # 0.025 s, 5 KB
 ```
+However, right now this gives `3.7 s (250 M allocations, 9 GB)`, something is broken!
 
 The package [Strided.jl](https://github.com/Jutho/Strided.jl) can apply multi-threading to broadcasting,
 and some other magic. You can enable it with the option `strided`, like this:
diff --git a/src/TensorCast.jl b/src/TensorCast.jl
index ded64a6..6950525 100644
--- a/src/TensorCast.jl
+++ b/src/TensorCast.jl
@@ -1,6 +1,12 @@
 module TensorCast
 
+# This speeds up loading a bit, on Julia 1.5, about 1s in my test.
+# https://github.com/JuliaPlots/Plots.jl/pull/2544/files
+if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@optlevel"))
+    @eval Base.Experimental.@optlevel 1
+end
+
 export @cast, @reduce, @matmul, @pretty
 
 using MacroTools, StaticArrays, Compat
 
@@ -10,11 +16,28 @@ include("macro.jl")
 include("pretty.jl")
 include("string.jl")
 
-include("slice.jl")  # slice, glue, etc
-include("view.jl")   # orient, Reverse{d} etc
-include("lazy.jl")   # LazyCast
-include("static.jl") # StaticArrays
+module Fast # shield non-macro code from @optlevel 1
+    using LinearAlgebra, StaticArrays, Compat
+
+    include("slice.jl")  # slice, glue, etc
+    export sliceview, slicecopy, glue, glue!, red_glue, cat_glue, copy_glue, lazy_glue, iscodesorted, countcolons
+
+    include("view.jl")   # orient, Reverse{d} etc
+    export diagview, orient, rview, mul!, star, PermuteDims, Reverse, Shuffle
+
+    include("static.jl") # StaticArrays
+    export static_slice, static_glue
+
+end
+using .Fast
+const mul! = Fast.mul!
+
+using Requires
+
+@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
+    include("lazy.jl")   # LazyCast   # this costs about 3s in my test, 3.8s -> 7.7s
+end
 
-include("warm.jl")
+include("warm.jl") # worth 2s in my test
 end # module
diff --git a/src/lazy.jl b/src/lazy.jl
index 68dba91..acf9b9c 100644
--- a/src/lazy.jl
+++ b/src/lazy.jl
@@ -1,5 +1,5 @@
-import LazyArrays
+import .LazyArrays
 
 #=
 The macro option "lazy" always produces things like sum(@__dot__(lazy(x+y)))
 
diff --git a/src/macro.jl b/src/macro.jl
index eb7ed2f..cf6e61f 100644
--- a/src/macro.jl
+++ b/src/macro.jl
@@ -234,6 +234,8 @@ This mostly aims to re-work the given expression into `some(steps(A))[i,j]`,
 but also pushes `A = f(x)` into `store.top`.
 """
 function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false)
+    @nospecialize ex
+
     # This acts only on single indexing expressions:
     if @capture(ex, A_{ijk__})
         static=true
@@ -378,6 +380,7 @@ target dims not correctly handled yet -- what do I want? TODO
 Simple glue / stand. does not permutedims, but broadcasting may have to... avoid twice?
 """
 function standardglue(ex, target, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
 
     # The sole target here is indexing expressions:
     if @capture(ex, A_[inner__])
@@ -469,6 +472,7 @@ This beings the expression to have target indices, by permutedims and
 if necessary broadcasting, always using `readycast()`.
 """
 function targetcast(ex, target, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
 
     # If just one naked expression, then we won't broadcast:
     if @capture(ex, A_[ijk__])
@@ -503,6 +507,7 @@ This is walked over the expression to prepare for `@__dot__` etc,
 by `targetcast()`.
 """
 function readycast(ex, target, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
 
     # Scalar functions can be protected entirely from broadcasting:
     # TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function!
@@ -578,6 +583,7 @@ If there are more than two factors, it recurses, and you get
 `(A*B) * C`, or perhaps tuple `(A*B, C)`.
 """
 function matmultarget(ex, target, parsed, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
 
     @capture(ex, A_ * B_ * C__ | *(A_, B_, C__) ) ||
        throw(MacroError("can't @matmul that!", call))
@@ -631,6 +637,7 @@ pushing calculation steps into store.
 Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
 """
 function recursemacro(ex, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
 
     # Actually look for recursion
     if @capture(ex, @reduce(subex__) )
@@ -675,6 +682,8 @@ This saves to `store` the sizes of all input tensors, and their sub-slices if any,
 however it should not destroy this so that `sz_j` can be got later.
 """
 function rightsizes(ex, store::NamedTuple, call::CallInfo)
+    @nospecialize ex
+
     :recurse in call.flags && return nothing  # outer version took care of this
 
     if @capture(ex, A_[outer__][inner__] | A_[outer__]{inner__} )
@@ -1115,8 +1124,7 @@ end
 
 tensorprimetidy(v::Vector) = Any[ tensorprimetidy(x) for x in v ]
 function tensorprimetidy(ex)
-    MacroTools.postwalk(ex) do x
-
+    MacroTools.postwalk(ex) do @nospecialize x
         @capture(x, ((ij__,) \ k_) ) && return :( ($(ij...),$k) )
         @capture(x, i_ \ j_ ) && return :( ($i,$j) )
 
@@ -1172,7 +1180,7 @@ containsindexing(s) = false
 function containsindexing(ex::Expr)
     flag = false
     # MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex)
-    MacroTools.postwalk(ex) do x
+    MacroTools.postwalk(ex) do @nospecialize x
         # @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true)
         if @capture(x, A_[ijk__])
             # @show x ijk  # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j])
@@ -1185,7 +1193,7 @@ end
 listindices(s::Symbol) = []
 function listindices(ex::Expr)
     list = []
-    MacroTools.postwalk(ex) do x
+    MacroTools.postwalk(ex) do @nospecialize x
         if @capture(x, A_[ijk__])
             flat, _ = indexparse(nothing, ijk)
             push!(list, flat)
diff --git a/test/runtests.jl b/test/runtests.jl
index 22bf376..8abf1a9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,6 +6,7 @@ using StaticArrays
 using OffsetArrays
 using Einsum
 using Strided
+using LazyArrays
 using Compat
 if VERSION >= v"1.1"
     using LoopVectorization