Skip to content

Commit 7ba14d7

Browse files
committed
Introduce switch to make scalar indexing error and define necessary
methods to make tests pass without hitting scalar indexing.
1 parent f509f59 commit 7ba14d7

File tree

8 files changed

+161
-58
lines changed

8 files changed

+161
-58
lines changed

src/DistributedArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using LinearAlgebra
88

99
import Base: +, -, *, div, mod, rem, &, |, xor
1010
import Base.Callable
11-
import LinearAlgebra: axpy!, dot, norm,
11+
import LinearAlgebra: axpy!, dot, norm
1212

1313
import Primes
1414
import Primes: factor

src/core.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ end
5555
5656
Get the vector of processes storing pieces of DArray `d`.
5757
"""
58-
Distributed.procs(d::DArray) = d.pids
58+
Distributed.procs(d::DArray) = d.pids
59+
Distributed.procs(d::SubDArray) = procs(parent(d))
5960

6061
"""
6162
localpart(A)

src/darray.jl

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,41 @@ function localindices(d::DArray)
364364
return d.indices[lpidx]
365365
end
366366

367-
# find which piece holds index (I...)
368-
locate(d::DArray, I::Int...) =
369-
ntuple(i -> searchsortedlast(d.cuts[i], I[i]), ndims(d))
367+
# Equality
368+
function Base.:(==)(d::DArray{<:Any,<:Any,A}, a::AbstractArray) where A
369+
if size(d) != size(a)
370+
return false
371+
else
372+
b = asyncmap(procs(d)) do p
373+
remotecall_fetch(p) do
374+
localpart(d) == A(a[localindices(d)...])
375+
end
376+
end
377+
return all(b)
378+
end
379+
end
380+
Base.:(==)(d::SubDArray, a::AbstractArray) = copy(d) == a
381+
Base.:(==)(a::AbstractArray, d::DArray) = d == a
382+
Base.:(==)(a::AbstractArray, d::SubDArray) = d == a
383+
Base.:(==)(d1::DArray, d2::DArray) = invoke(==, Tuple{DArray, AbstractArray}, d1, d2)
384+
Base.:(==)(d1::SubDArray, d2::DArray) = copy(d1) == d2
385+
Base.:(==)(d1::DArray, d2::SubDArray) = d1 == copy(d2)
386+
Base.:(==)(d1::SubDArray, d2::SubDArray) = copy(d1) == copy(d2)
387+
388+
"""
389+
locate(d::DArray, I::Int...)
390+
391+
Determine the index of `procs(d)` that hold element `I`.
392+
"""
393+
function locate(d::DArray, I::Int...)
394+
ntuple(ndims(d)) do i
395+
fi = searchsortedlast(d.cuts[i], I[i])
396+
if fi >= length(d.cuts[i])
397+
throw(ArgumentError("element not contained in array"))
398+
end
399+
return fi
400+
end
401+
end
370402

371403
chunk(d::DArray{T,N,A}, i...) where {T,N,A} = remotecall_fetch(localpart, d.pids[i...], d)::A
372404

@@ -479,15 +511,15 @@ end
479511
function (::Type{Array{S,N}})(s::SubDArray{T,N}) where {S,T,N}
480512
I = s.indices
481513
d = s.parent
482-
if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S
514+
if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S && !isempty(s)
483515
l = locate(d, map(first, I)...)
484516
if isequal(d.indices[l...], I)
485517
# SubDArray corresponds to a chunk
486518
return chunk(d, l...)
487519
end
488520
end
489521
a = Array{S}(undef, size(s))
490-
a[[1:size(a,i) for i=1:N]...] .= s
522+
a[[1:size(a,i) for i=1:N]...] = s
491523
return a
492524
end
493525

@@ -540,15 +572,15 @@ end
540572

541573
function Base.getindex(d::DArray, i::Int)
542574
_scalarindexingallowed()
543-
return getindex_tuple(d, CartesianIndices(d)[i])
575+
return getindex_tuple(d, Tuple(CartesianIndices(d)[i]))
544576
end
545577
function Base.getindex(d::DArray, i::Int...)
546578
_scalarindexingallowed()
547579
return getindex_tuple(d, i)
548580
end
549581

550582
Base.getindex(d::DArray) = d[1]
551-
Base.getindex(d::DArray, I::Union{Int,UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) = view(d, I...)
583+
Base.getindex(d::SubOrDArray, I::Union{Int,UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) = view(d, I...)
552584

553585
function Base.isassigned(D::DArray, i::Integer...)
554586
try
@@ -564,15 +596,15 @@ function Base.isassigned(D::DArray, i::Integer...)
564596
end
565597

566598

567-
Base.copyto!(dest::SubOrDArray, src::SubOrDArray) = begin
599+
function Base.copyto!(dest::SubOrDArray, src::AbstractArray)
568600
asyncmap(procs(dest)) do p
569601
remotecall_fetch(p) do
570-
localpart(dest)[:] = src[localindices(dest)...]
602+
ldest = localpart(dest)
603+
ldest[:] = Array(view(src, localindices(dest)...))
571604
end
572605
end
573606
return dest
574607
end
575-
Base.copy!(dest::SubOrDArray, src::SubOrDArray) = copyto!(dest, src)
576608

577609
function Base.deepcopy(src::DArray)
578610
dest = similar(src)

src/linalg.jl

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
function Base.copy(D::Adjoint{T,<:DArray{T,2}}) where T
1+
function Base.copy(Dadj::Adjoint{T,<:DArray{T,2}}) where T
2+
D = parent(Dadj)
23
DArray(reverse(size(D)), procs(D)) do I
34
lp = Array{T}(undef, map(length, I))
45
rp = convert(Array, D[reverse(I)...])
56
adjoint!(lp, rp)
67
end
78
end
89

9-
function Base.copy(D::Transpose{T,<:DArray{T,2}}) where T
10+
function Base.copy(Dtr::Transpose{T,<:DArray{T,2}}) where T
11+
D = parent(Dtr)
1012
DArray(reverse(size(D)), procs(D)) do I
1113
lp = Array{T}(undef, map(length, I))
1214
rp = convert(Array, D[reverse(I)...])
@@ -49,7 +51,7 @@ function dot(x::DVector, y::DVector)
4951
return reduce(+, results)
5052
end
5153

52-
function norm(x::DVector, p::Real = 2)
54+
function norm(x::DArray, p::Real = 2)
5355
results = []
5456
@sync begin
5557
for pp in procs(x)
@@ -83,7 +85,7 @@ function add!(dest, src, scale = one(dest[1]))
8385
return dest
8486
end
8587

86-
function A_mul_B!::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVector)
88+
function mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Number = 1, β::Number = 0)
8789

8890
# error checks
8991
if size(A, 2) != length(x)
@@ -106,11 +108,14 @@ function A_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVec
106108

107109
# Scale y if necessary
108110
if β != one(β)
109-
@sync for p in y.pids
110-
if β != zero(β)
111-
@async remotecall_fetch(y -> (rmul!(localpart(y), β); nothing), p, y)
112-
else
113-
@async remotecall_fetch(y -> (fill!(localpart(y), 0); nothing), p, y)
111+
asyncmap(procs(y)) do p
112+
remotecall_fetch(p) do
113+
if !iszero(β)
114+
rmul!(localpart(y), β)
115+
else
116+
fill!(localpart(y), 0)
117+
end
118+
return nothing
114119
end
115120
end
116121
end
@@ -127,7 +132,9 @@ function A_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVec
127132
return y
128133
end
129134

130-
function Ac_mul_B!::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVector)
135+
function mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::AbstractVector, α::Number = 1, β::Number = 0)
136+
137+
A = parent(adjA)
131138

132139
# error checks
133140
if size(A, 1) != length(x)
@@ -148,11 +155,14 @@ function Ac_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVe
148155

149156
# Scale y if necessary
150157
if β != one(β)
151-
@sync for p in y.pids
152-
if β != zero(β)
153-
@async remotecall_fetch(() -> (rmul!(localpart(y), β); nothing), p)
154-
else
155-
@async remotecall_fetch(() -> (fill!(localpart(y), 0); nothing), p)
158+
asyncmap(procs(y)) do p
159+
remotecall_fetch(p) do
160+
if !iszero(β)
161+
rmul!(localpart(y), β)
162+
else
163+
fill!(localpart(y), 0)
164+
end
165+
return nothing
156166
end
157167
end
158168
end
@@ -189,7 +199,7 @@ function LinearAlgebra.rmul!(DA::DMatrix, D::Diagonal)
189199
end
190200

191201
# Level 3
192-
function _matmatmul!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix, tA)
202+
function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::Number, tA)
193203
# error checks
194204
Ad1, Ad2 = (tA == 'N') ? (1,2) : (2,1)
195205
mA, nA = (size(A, Ad1), size(A, Ad2))
@@ -254,40 +264,60 @@ function _matmatmul!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::D
254264
return C
255265
end
256266

257-
A_mul_B!::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'N')
258-
Ac_mul_B!::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'C')
259-
At_mul_B!::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'T')
260-
At_mul_B!(C::DMatrix, A::DMatrix, B::AbstractMatrix) = At_mul_B!(one(eltype(C)), A, B, zero(eltype(C)), C)
267+
mul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number = 1, β::Number = 0) = _matmatmul!(C, A, B, α, β, 'N')
268+
mul!(C::DMatrix, A::Adjoint{<:Number,<:DMatrix}, B::AbstractMatrix, α::Number = 1, β::Number = 0) = _matmatmul!(C, parent(A), B, α, β, 'C')
269+
mul!(C::DMatrix, A::Transpose{<:Number,<:DMatrix}, B::AbstractMatrix, α::Number = 1, β::Number = 0) = _matmatmul!(C, parent(A), B, α, β, 'T')
261270

262271
_matmul_op = (t,s) -> t*s + t*s
263272

264273
function Base.:*(A::DMatrix, x::AbstractVector)
265274
T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
266275
y = DArray(I -> Array{T}(undef, map(length, I)), (size(A, 1),), procs(A)[:,1], (size(procs(A), 1),))
267-
return A_mul_B!(one(T), A, x, zero(T), y)
276+
return mul!(y, A, x)
268277
end
269278
function Base.:*(A::DMatrix, B::AbstractMatrix)
270279
T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
271280
C = DArray(I -> Array{T}(undef, map(length, I)),
272281
(size(A, 1), size(B, 2)),
273282
procs(A)[:,1:min(size(procs(A), 2), size(procs(B), 2))],
274283
(size(procs(A), 1), min(size(procs(A), 2), size(procs(B), 2))))
275-
return A_mul_B!(one(T), A, B, zero(T), C)
284+
return mul!(C, A, B)
285+
end
286+
287+
function Base.:*(adjA::Adjoint{<:Any,<:DMatrix}, x::AbstractVector)
288+
A = parent(adjA)
289+
T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
290+
y = DArray(I -> Array{T}(undef, map(length, I)),
291+
(size(A, 2),),
292+
procs(A)[1,:],
293+
(size(procs(A), 2),))
294+
return mul!(y, adjA, x)
295+
end
296+
function Base.:*(adjA::Adjoint{<:Any,<:DMatrix}, B::AbstractMatrix)
297+
A = parent(adjA)
298+
T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
299+
C = DArray(I -> Array{T}(undef, map(length, I)), (size(A, 2),
300+
size(B, 2)),
301+
procs(A)[1:min(size(procs(A), 1), size(procs(B), 2)),:],
302+
(size(procs(A), 2), min(size(procs(A), 1), size(procs(B), 2))))
303+
return mul!(C, adjA, B)
276304
end
277305

278-
function Ac_mul_B(A::DMatrix, x::AbstractVector)
306+
function Base.:*(trA::Transpose{<:Any,<:DMatrix}, x::AbstractVector)
307+
A = parent(trA)
279308
T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
280309
y = DArray(I -> Array{T}(undef, map(length, I)),
281310
(size(A, 2),),
282311
procs(A)[1,:],
283312
(size(procs(A), 2),))
284-
return Ac_mul_B!(one(T), A, x, zero(T), y)
313+
return mul!(y, trA, x)
285314
end
286-
function Ac_mul_B(A::DMatrix, B::AbstractMatrix)
315+
function Base.:*(trA::Transpose{<:Any,<:DMatrix}, B::AbstractMatrix)
316+
A = parent(trA)
287317
T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
288318
C = DArray(I -> Array{T}(undef, map(length, I)), (size(A, 2),
289319
size(B, 2)),
290320
procs(A)[1:min(size(procs(A), 1), size(procs(B), 2)),:],
291321
(size(procs(A), 2), min(size(procs(A), 1), size(procs(B), 2))))
292-
return Ac_mul_B!(one(T), A, B, zero(T), C)
322+
return mul!(C, trA, B)
293323
end

src/mapreduce.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ import SparseArrays: nnz
55

66
Base.map(f, d0::DArray, ds::AbstractArray...) = broadcast(f, d0, ds...)
77

8-
function Base.map!(f::F, dest::DArray, src::DArray) where {F}
8+
function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
99
asyncmap(procs(dest)) do p
1010
remotecall_fetch(p) do
11-
map!(f, localpart(dest), src[localindices(dest)...])
11+
map!(f, localpart(dest), A(view(src, localindices(dest)...)))
1212
return nothing
1313
end
1414
end
@@ -53,7 +53,7 @@ rewrite_local(x) = x
5353

5454
function Base.reduce(f, d::DArray)
5555
results = asyncmap(procs(d)) do p
56-
remotecall_fetch(p, f, d) do (f, d)
56+
remotecall_fetch(p) do
5757
return reduce(f, localpart(d))
5858
end
5959
end
@@ -122,12 +122,39 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray)
122122
end
123123
region = tuple(collect(1:ndims(A))[[size(R)...] .!= [size(A)...]]...)
124124
if isempty(region)
125-
return copy!(R, A)
125+
return copyto!(R, A)
126126
end
127127
B = mapreducedim_within(f, op, A, region)
128128
return mapreducedim_between!(identity, op, R, B, region)
129129
end
130130

131+
function Base._all(f, A::DArray, ::Colon)
132+
B = asyncmap(procs(A)) do p
133+
remotecall_fetch(p) do
134+
all(f, localpart(A))
135+
end
136+
end
137+
return all(B)
138+
end
139+
140+
function Base._any(f, A::DArray, ::Colon)
141+
B = asyncmap(procs(A)) do p
142+
remotecall_fetch(p) do
143+
any(f, localpart(A))
144+
end
145+
end
146+
return any(B)
147+
end
148+
149+
function Base.count(f, A::DArray)
150+
B = asyncmap(procs(A)) do p
151+
remotecall_fetch(p) do
152+
count(f, localpart(A))
153+
end
154+
end
155+
return sum(B)
156+
end
157+
131158
function nnz(A::DArray)
132159
B = asyncmap(A.pids) do p
133160
remotecall_fetch(nnzlocalpart, p, A)

src/serialize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function Serialization.serialize(S::AbstractSerializer, d::DArray{T,N,A}) where
22
# Only send the ident for participating workers - we expect the DArray to exist in the
33
# remote registry. DO NOT send the localpart.
44
destpid = worker_id_from_socket(S.io)
5-
serialize_type(S, typeof(d))
5+
Serialization.serialize_type(S, typeof(d))
66
if (destpid in d.pids) || (destpid == d.id[1])
77
serialize(S, (true, d.id)) # (id_only, id)
88
else
@@ -64,7 +64,7 @@ function Serialization.serialize(S::AbstractSerializer, s::DestinationSerializer
6464
pid = worker_id_from_socket(S.io)
6565
pididx = findfirst(isequal(pid), s.pids)
6666
@assert pididx !== nothing
67-
serialize_type(S, typeof(s))
67+
Serialization.serialize_type(S, typeof(s))
6868
serialize(S, s.generate(pididx))
6969
end
7070

0 commit comments

Comments
 (0)