Skip to content

Commit 6c43de9

Browse files
mcabbottMichael Abbottdlfivefifty
authored
PermutedDimsArray, and layout for dense 1st index (#12)
* add PermutedDimsArray, and FirstMajor layout * change to UnitStride{D} for any dimension D * use Compat.jl * tidy & fix * smarter ntuple, rm Compat.jl * more tests Co-authored-by: Michael Abbott <me@escbook> Co-authored-by: Sheehan Olver <[email protected]>
1 parent 5729d6a commit 6c43de9

File tree

3 files changed

+138
-5
lines changed

3 files changed

+138
-5
lines changed

src/ArrayLayouts.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ else
4848
end
4949

5050
export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, lmul, rmul, ldiv, rdiv, mul, MemoryLayout, AbstractStridedLayout,
51-
DenseColumnMajor, ColumnMajor, ZerosLayout, FillLayout, AbstractColumnMajor, RowMajor, AbstractRowMajor,
51+
DenseColumnMajor, ColumnMajor, ZerosLayout, FillLayout, AbstractColumnMajor, RowMajor, AbstractRowMajor, UnitStride,
5252
DiagonalLayout, ScalarLayout, SymTridiagonalLayout, HermitianLayout, SymmetricLayout, TriangularLayout,
5353
UnknownLayout, AbstractBandedLayout, ApplyBroadcastStyle, ConjLayout, AbstractFillLayout,
5454
colsupport, rowsupport, layout_getindex, QLayout, LayoutArray, LayoutMatrix, LayoutVector
@@ -213,4 +213,4 @@ end
213213
return A
214214
end
215215

216-
end
216+
end

src/memorylayout.jl

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ abstract type AbstractRowMajor <: AbstractDecreasingStrides end
1818
struct DenseRowMajor <: AbstractRowMajor end
1919
struct RowMajor <: AbstractRowMajor end
2020
struct DecreasingStrides <: AbstractIncreasingStrides end
21+
struct UnitStride{D} <: AbstractStridedLayout end
2122
struct StridedLayout <: AbstractStridedLayout end
2223
struct ScalarLayout <: MemoryLayout end
2324

@@ -42,6 +43,31 @@ dispatch to BLAS and LAPACK routines if the memory layout is BLAS compatible and
4243
the element type is a `Float32`, `Float64`, `ComplexF32`, or `ComplexF64`.
4344
In this case, one must implement the strided array interface, which requires
4445
overrides of `strides(A::MyMatrix)` and `unknown_convert(::Type{Ptr{T}}, A::MyMatrix)`.
46+
47+
The complete list of more specialised types is as follows:
48+
```
49+
julia> using ArrayLayouts, AbstractTrees
50+
51+
julia> AbstractTrees.children(x::Type) = subtypes(x)
52+
53+
julia> print_tree(AbstractStridedLayout)
54+
AbstractStridedLayout
55+
├─ AbstractDecreasingStrides
56+
│ └─ AbstractRowMajor
57+
│ ├─ DenseRowMajor
58+
│ └─ RowMajor
59+
├─ AbstractIncreasingStrides
60+
│ ├─ AbstractColumnMajor
61+
│ │ ├─ ColumnMajor
62+
│ │ └─ DenseColumnMajor
63+
│ ├─ DecreasingStrides
64+
│ └─ IncreasingStrides
65+
├─ StridedLayout
66+
└─ UnitStride
67+
68+
julia> Base.show_supertypes(AbstractStridedLayout)
69+
AbstractStridedLayout <: MemoryLayout <: Any
70+
```
4571
"""
4672
AbstractStridedLayout
4773

@@ -157,7 +183,7 @@ MemoryLayout(::Type{<:ReshapedArray{T,N,A,DIMS}}) where {T,N,A,DIMS} = reshapedl
157183
@inline reshapedlayout(::DenseColumnMajor, _) = DenseColumnMajor()
158184

159185

160-
@inline MemoryLayout(A::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I} =
186+
@inline MemoryLayout(A::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I} =
161187
sublayout(MemoryLayout(P), I)
162188
sublayout(_1, _2) = UnknownLayout()
163189
sublayout(_1, _2, _3)= UnknownLayout()
@@ -257,6 +283,59 @@ transposelayout(::ConjLayout{ML}) where ML = ConjLayout{typeof(transposelayout(M
257283
adjointlayout(::Type{T}, M::MemoryLayout) where T = transposelayout(conjlayout(T, M))
258284

259285

286+
# Layouts of PermutedDimsArrays
287+
"""
288+
UnitStride{D}()
289+
290+
is returned by `MemoryLayout(A)` for arrays of `ndims(A) >= 3` which have `stride(A,D) == 1`.
291+
292+
`UnitStride{1}` is weaker than `ColumnMajor` in that it does not demand that the other
293+
strides are increasing, hence it is not a subtype of `AbstractIncreasingStrides`.
294+
To ensure that `stride(A,1) == 1`, you may dispatch on `Union{UnitStride{1}, AbstractColumnMajor}`
295+
to allow for both options. (With complex numbers, you may also need their `ConjLayout` versions.)
296+
297+
Likewise, both `UnitStride{ndims(A)}` and `AbstractRowMajor` have `stride(A, ndims(A)) == 1`.
298+
"""
299+
UnitStride
300+
301+
MemoryLayout(::Type{PermutedDimsArray{T,N,P,Q,S}}) where {T,N,P,Q,S} = permutelayout(MemoryLayout(S), Val(P))
302+
303+
permutelayout(::Any, perm) = UnknownLayout()
304+
permutelayout(::StridedLayout, perm) = StridedLayout()
305+
permutelayout(::ConjLayout{ML}, perm) where ML = ConjLayout{typeof(permutelayout(ML(), perm))}()
306+
307+
function permutelayout(layout::AbstractColumnMajor, ::Val{perm}) where {perm}
308+
issorted(perm) && return layout
309+
issorted(reverse(perm)) && return reverse(layout)
310+
D = sum(ntuple(dim -> perm[dim] == 1 ? dim : 0, length(perm)))
311+
return UnitStride{D}()
312+
end
313+
function permutelayout(layout::AbstractRowMajor, ::Val{perm}) where {perm}
314+
issorted(perm) && return layout
315+
issorted(reverse(perm)) && return reverse(layout)
316+
N = length(perm) # == ndims(A)
317+
D = sum(ntuple(dim -> perm[dim] == N ? dim : 0, N))
318+
return UnitStride{D}()
319+
end
320+
function permutelayout(layout::UnitStride{D0}, ::Val{perm}) where {D0, perm}
321+
D = sum(ntuple(dim -> perm[dim] == D0 ? dim : 0, length(perm)))
322+
return UnitStride{D}()
323+
end
324+
function permutelayout(layout::Union{IncreasingStrides,DecreasingStrides}, ::Val{perm}) where {perm}
325+
issorted(perm) && return layout
326+
issorted(reverse(perm)) && return reverse(layout)
327+
return StridedLayout()
328+
end
329+
330+
Base.reverse(::DenseRowMajor) = DenseColumnMajor()
331+
Base.reverse(::RowMajor) = ColumnMajor()
332+
Base.reverse(::DenseColumnMajor) = DenseRowMajor()
333+
Base.reverse(::ColumnMajor) = RowMajor()
334+
Base.reverse(::IncreasingStrides) = DecreasingStrides()
335+
Base.reverse(::DecreasingStrides) = IncreasingStrides()
336+
Base.reverse(::AbstractStridedLayout) = StridedLayout()
337+
338+
260339
# MemoryLayout of Symmetric/Hermitian
261340
"""
262341
SymmetricLayout{layout}()

test/test_layouts.jl

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using ArrayLayouts, LinearAlgebra, FillArrays, Test
22
import ArrayLayouts: MemoryLayout, DenseRowMajor, DenseColumnMajor, StridedLayout,
3-
ConjLayout, RowMajor, ColumnMajor, UnknownLayout,
3+
ConjLayout, RowMajor, ColumnMajor, UnitStride,
44
SymmetricLayout, HermitianLayout, UpperTriangularLayout,
55
UnitUpperTriangularLayout, LowerTriangularLayout,
6-
UnitLowerTriangularLayout, ScalarLayout,
6+
UnitLowerTriangularLayout, ScalarLayout, UnknownLayout,
77
hermitiandata, symmetricdata, FillLayout, ZerosLayout,
88
DiagonalLayout, colsupport, rowsupport
99

@@ -178,4 +178,58 @@ struct FooNumber <: Number end
178178
@test colsupport(LowerTriangular(A),3) 3:5
179179
@test rowsupport(LowerTriangular(A),3) Base.OneTo(3)
180180
end
181+
182+
@testset "PermutedDimsArray" begin
183+
A = [1.0 2; 3 4]
184+
@test MemoryLayout(PermutedDimsArray(A, (1,2))) == DenseColumnMajor()
185+
@test MemoryLayout(PermutedDimsArray(A, (2,1))) == DenseRowMajor()
186+
@test MemoryLayout(transpose(PermutedDimsArray(A, (2,1)))) == DenseColumnMajor()
187+
@test MemoryLayout(adjoint(PermutedDimsArray(A, (2,1)))) == DenseColumnMajor()
188+
B = [1.0+im 2; 3 4]
189+
@test MemoryLayout(PermutedDimsArray(B, (2,1))) == DenseRowMajor()
190+
@test MemoryLayout(transpose(PermutedDimsArray(B, (2,1)))) == DenseColumnMajor()
191+
@test MemoryLayout(adjoint(PermutedDimsArray(B, (2,1)))) == ConjLayout{DenseColumnMajor}()
192+
193+
C = view(ones(10,20,30), 2:9, 3:18, 4:27);
194+
@test MemoryLayout(C) == ColumnMajor()
195+
@test MemoryLayout(PermutedDimsArray(C, (1,2,3))) == ColumnMajor()
196+
@test MemoryLayout(PermutedDimsArray(C, (1,3,2))) == UnitStride{1}()
197+
198+
@test MemoryLayout(PermutedDimsArray(C, (3,1,2))) == UnitStride{2}()
199+
@test MemoryLayout(PermutedDimsArray(C, (2,1,3))) == UnitStride{2}()
200+
201+
@test MemoryLayout(PermutedDimsArray(C, (3,2,1))) == RowMajor()
202+
@test MemoryLayout(PermutedDimsArray(C, (2,3,1))) == UnitStride{3}()
203+
204+
revC = PermutedDimsArray(C, (3,2,1));
205+
@test MemoryLayout(PermutedDimsArray(revC, (3,2,1))) == ColumnMajor()
206+
@test MemoryLayout(PermutedDimsArray(revC, (3,1,2))) == UnitStride{1}()
207+
208+
D = ones(10,20,30,40);
209+
@test MemoryLayout(D) == DenseColumnMajor()
210+
@test MemoryLayout(PermutedDimsArray(D, (1,2,3,4))) == DenseColumnMajor()
211+
@test MemoryLayout(PermutedDimsArray(D, (1,4,3,2))) == UnitStride{1}()
212+
213+
@test MemoryLayout(PermutedDimsArray(D, (4,1,3,2))) == UnitStride{2}()
214+
@test MemoryLayout(PermutedDimsArray(D, (2,1,4,3))) == UnitStride{2}()
215+
216+
@test MemoryLayout(PermutedDimsArray(D, (4,3,2,1))) == DenseRowMajor()
217+
@test MemoryLayout(PermutedDimsArray(D, (4,2,1,3))) == UnitStride{3}()
218+
219+
twoD = PermutedDimsArray(D, (3,1,2,4));
220+
MemoryLayout(PermutedDimsArray(twoD, (2,1,4,3))) == UnitStride{1}()
221+
222+
revD = PermutedDimsArray(D, (4,3,2,1));
223+
MemoryLayout(PermutedDimsArray(revD, (4,3,2,1))) == DenseColumnMajor()
224+
MemoryLayout(PermutedDimsArray(revD, (4,2,3,1))) == UnitStride{1}()
225+
226+
227+
issorted((1,2,3,4))
228+
# Fails on Julia 1.4, in tests. Could use BenchmarkTools.@ballocated instead.
229+
@test_skip 0 == @allocated issorted((1,2,3,4))
230+
reverse((1,2,3,4))
231+
@test_skip 0 == @allocated reverse((1,2,3,4))
232+
MemoryLayout(revD)
233+
@test 0 == @allocated MemoryLayout(revD)
234+
end
181235
end

0 commit comments

Comments
 (0)