Skip to content

Commit 8639776

Browse files
authored
Generalize the element type of BlockedUnitRange (#337)
* Allow more general BlockUnitRange element types * Restrict element type * Get tests passing * Fix some tests * Fix some doctests * Skip broken test in Julia v1.6 * Better support for unitful numbers * Fix tests * Stricter types in _BlockedUnitRange * Improve tests coverage
1 parent d02efe6 commit 8639776

File tree

2 files changed

+186
-43
lines changed

2 files changed

+186
-43
lines changed

src/blockaxis.jl

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function findblockindex(b::AbstractVector, k::Integer)
1616
bl = blocklasts(b)
1717
blockidx = _searchsortedfirst(bl, k)
1818
@assert blockindex != lastindex(bl) + 1 # guaranteed by the @boundscheck above
19-
prevblocklast = blockidx == firstindex(bl) ? first(b)-1 : bl[blockidx-1]
19+
prevblocklast = blockidx == firstindex(bl) ? first(b)-oneunit(eltype(b)) : bl[blockidx-1]
2020
local_index = k - prevblocklast
2121
return BlockIndex(blockidx, local_index)
2222
end
@@ -36,7 +36,7 @@ a vector of block lengths to a `BlockedUnitRange`.
3636
# Examples
3737
```jldoctest
3838
julia> blockedrange(2, [2,2,3]) # first value and block lengths
39-
3-blocked 7-element BlockedUnitRange{Vector{Int64}}:
39+
3-blocked 7-element BlockedUnitRange{Int64, Vector{Int64}}:
4040
2
4141
3
4242
@@ -50,31 +50,59 @@ julia> blockedrange(2, [2,2,3]) # first value and block lengths
5050
5151
See also [`BlockedOneTo`](@ref).
5252
"""
53-
struct BlockedUnitRange{CS} <: AbstractBlockedUnitRange{Int,CS}
54-
first::Int
53+
struct BlockedUnitRange{T<:Integer,CS} <: AbstractBlockedUnitRange{T,CS}
54+
first::T
5555
lasts::CS
5656
# assume that lasts is sorted, no checks carried out here
57-
global function _BlockedUnitRange(f, cs::CS) where CS
57+
global function _BlockedUnitRange(f::T, cs::CS) where {T,CS<:AbstractVector{T}}
5858
Base.require_one_based_indexing(cs)
59-
new{CS}(f, cs)
59+
return new{T,CS}(f, cs)
60+
end
61+
global function _BlockedUnitRange(f::T, cs::CS) where {T,CS<:Tuple{T,Vararg{T}}}
62+
return new{T,CS}(f, cs)
63+
end
64+
global function _BlockedUnitRange(f::T, cs::Tuple{}) where {T}
65+
return new{T,Tuple{}}(f, cs)
6066
end
6167
end
6268

63-
@inline _BlockedUnitRange(cs) = _BlockedUnitRange(1,cs)
69+
@inline function _BlockedUnitRange(f::T, cs::AbstractVector{S}) where {T,S}
70+
U = promote_type(T, S)
71+
return _BlockedUnitRange(convert(U, f), convert.(U, cs))
72+
end
73+
@inline function _BlockedUnitRange(f::T, cs::Tuple{S,Vararg{S}}) where {T,S}
74+
U = promote_type(T, S)
75+
return _BlockedUnitRange(convert(U, f), convert.(U, cs))
76+
end
77+
@inline function _BlockedUnitRange(f, cs::Tuple)
78+
return _BlockedUnitRange(f, promote(cs...))
79+
end
80+
@inline _BlockedUnitRange(cs::AbstractVector) = _BlockedUnitRange(oneunit(eltype(cs)), cs)
81+
@inline _BlockedUnitRange(cs::NTuple) = _BlockedUnitRange(oneunit(eltype(cs)), cs)
82+
_BlockedUnitRange(cs::Tuple) = _BlockedUnitRange(promote(cs...))
6483

6584
first(b::BlockedUnitRange) = b.first
6685
@inline blocklasts(a::BlockedUnitRange) = a.lasts
6786

6887
BlockedUnitRange(::BlockedUnitRange) = throw(ArgumentError("Forbidden due to ambiguity"))
88+
# Use `accumulate` instead of `cumsum` because it preserves the element type of the block lengths
89+
_blocklengths2blocklasts(blocks) = accumulate(+, blocks) # extra level to allow changing default accumulate behaviour
90+
# Use `cumsum` for fill arrays to output lazy `StepRangeLen` representation
91+
_blocklengths2blocklasts(blocks::Fill) = cumsum(blocks)
92+
_blocklengths2blocklasts(blocks::Ones) = cumsum(blocks)
93+
94+
@inline blockfirsts(a::AbstractBlockedUnitRange) = [first(a); @views(blocklasts(a)[1:end-1]) .+ oneunit(eltype(a))]
6995

70-
@inline blockfirsts(a::AbstractBlockedUnitRange) = [first(a); @views(blocklasts(a)[1:end-1]) .+ 1]
7196
# optimize common cases
7297
@inline function blockfirsts(a::AbstractBlockedUnitRange{<:Any,<:Union{Vector, RangeCumsum{<:Any, <:UnitRange}}})
7398
v = Vector{eltype(a)}(undef, length(blocklasts(a)))
7499
v[1] = first(a)
75-
v[2:end] .= @views(blocklasts(a)[oneto(end-1)]) .+ 1
100+
v[2:end] .= @views(blocklasts(a)[oneto(end-1)]) .+ oneunit(eltype(a))
76101
return v
77102
end
103+
@inline function blockfirsts(a::AbstractBlockedUnitRange{<:Any,<:Tuple})
104+
return (first(a), (blocklasts(a)[oneto(end-1)] .+ oneunit(eltype(a)))...)
105+
end
78106

79107
"""
80108
BlockedOneTo
@@ -124,8 +152,6 @@ BlockedOneTo(::BlockedOneTo) = throw(ArgumentError("Forbidden due to ambiguity")
124152

125153
axes(b::BlockedOneTo) = (b,)
126154

127-
_blocklengths2blocklasts(blocks) = cumsum(blocks) # extra level to allow changing default cumsum behaviour
128-
129155
"""
130156
blockedrange(blocklengths::Union{Tuple, AbstractVector})
131157
blockedrange(first::Int, blocklengths::Union{Tuple, AbstractVector})
@@ -144,27 +170,27 @@ julia> blockedrange([1,2])
144170
3
145171
146172
julia> blockedrange(2, (1,2))
147-
2-blocked 3-element BlockedUnitRange{Tuple{Int64, Int64}}:
173+
2-blocked 3-element BlockedUnitRange{Int64, Tuple{Int64, Int64}}:
148174
2
149175
150176
3
151177
4
152178
```
153179
"""
154180
@inline blockedrange(blocks::Union{Tuple,AbstractVector}) = BlockedOneTo(_blocklengths2blocklasts(blocks))
155-
@inline blockedrange(f::Int, blocks::Union{Tuple,AbstractVector}) = _BlockedUnitRange(f, f-1 .+ _blocklengths2blocklasts(blocks))
181+
@inline blockedrange(f::Integer, blocks::Union{Tuple,AbstractVector}) = _BlockedUnitRange(f, f-oneunit(f) .+ _blocklengths2blocklasts(blocks))
156182

157183
_diff(a::AbstractVector) = diff(a)
158184
_diff(a::Tuple) = diff(collect(a))
159-
@inline _blocklengths(a, bl, dbl) = isempty(bl) ? [dbl;] : [first(bl)-first(a)+1; dbl]
185+
@inline _blocklengths(a, bl, dbl) = isempty(bl) ? [dbl;] : [first(bl)-first(a)+oneunit(eltype(a)); dbl]
160186
@inline function _blocklengths(a::BlockedOneTo, bl::RangeCumsum, ::OrdinalRange)
161187
# the 1:0 is hardcoded here to enable conversions to a Base.OneTo
162188
isempty(bl) ? oftype(bl.range, 1:0) : bl.range
163189
end
164190
@inline _blocklengths(a, bl) = _blocklengths(a, bl, _diff(bl))
165191
@inline blocklengths(a::AbstractBlockedUnitRange) = _blocklengths(a, blocklasts(a))
166192

167-
length(a::AbstractBlockedUnitRange) = isempty(blocklasts(a)) ? 0 : Integer(last(blocklasts(a))-first(a)+1)
193+
length(a::AbstractBlockedUnitRange) = isempty(blocklasts(a)) ? zero(eltype(a)) : Integer(last(blocklasts(a))-first(a)+oneunit(eltype(a)))
168194

169195
"""
170196
blockisequal(a::AbstractUnitRange{Int}, b::AbstractUnitRange{Int})
@@ -217,15 +243,15 @@ function Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int})
217243
f = first(axis)
218244
_BlockedUnitRange(f, _shift_blocklengths(axis, bl, f))
219245
end
220-
function Base.convert(::Type{BlockedUnitRange{CS}}, axis::AbstractUnitRange{Int}) where CS
246+
function Base.convert(::Type{BlockedUnitRange{T,CS}}, axis::AbstractUnitRange{Int}) where {T,CS}
221247
bl = blocklasts(axis)
222248
f = first(axis)
223-
_BlockedUnitRange(f, convert(CS, _shift_blocklengths(axis, bl, f)))
249+
_BlockedUnitRange(convert(T, f), convert(CS, _shift_blocklengths(axis, bl, f)))
224250
end
225251

226252
Base.unitrange(b::AbstractBlockedUnitRange) = first(b):last(b)
227253

228-
Base.promote_rule(::Type{<:AbstractBlockedUnitRange}, ::Type{Base.OneTo{Int}}) = UnitRange{Int}
254+
Base.promote_rule(::Type{<:AbstractBlockedUnitRange{T}}, ::Type{Base.OneTo{Int}}) where {T} = UnitRange{promote_type(T, Int)}
229255

230256
function Base.convert(::Type{BlockedOneTo}, axis::AbstractUnitRange{Int})
231257
first(axis) == 1 || throw(ArgumentError("first element of range is not 1"))
@@ -352,10 +378,10 @@ julia> blocksizes(A,2)
352378
blocksizes(A) = map(blocklengths, axes(A))
353379
blocksizes(A,i) = blocklengths(axes(A,i))
354380

355-
axes(b::AbstractBlockedUnitRange) = (BlockedOneTo(blocklasts(b) .- (first(b)-1)),)
381+
axes(b::AbstractBlockedUnitRange) = (BlockedOneTo(blocklasts(b) .- (first(b)-oneunit(eltype(b)))),)
356382
unsafe_indices(b::AbstractBlockedUnitRange) = axes(b)
357383
# ::Integer works around case where blocklasts might return different type
358-
last(b::AbstractBlockedUnitRange)::Integer = isempty(blocklasts(b)) ? first(b)-1 : last(blocklasts(b))
384+
last(b::AbstractBlockedUnitRange)::Integer = isempty(blocklasts(b)) ? first(b)-oneunit(eltype(b)) : last(blocklasts(b))
359385

360386
# view and indexing are identical for a unitrange
361387
view(b::AbstractBlockedUnitRange, K::Block{1}) = b[K]
@@ -367,19 +393,19 @@ view(b::AbstractBlockedUnitRange, K::Block{1}) = b[K]
367393
@boundscheck K in bax || throw(BlockBoundsError(b, k))
368394
S = first(bax)
369395
K == S && return first(b):first(cs)
370-
return cs[k-1]+1:cs[k]
396+
return cs[k-1]+oneunit(eltype(b)):cs[k]
371397
end
372398

373399
@propagate_inbounds function getindex(b::AbstractBlockedUnitRange, KR::BlockRange{1})
374400
cs = blocklasts(b)
375-
isempty(KR) && return _BlockedUnitRange(1,cs[1:0])
401+
isempty(KR) && return _BlockedUnitRange(oneunit(eltype(b)),cs[1:0])
376402
K,J = first(KR),last(KR)
377403
k,j = Integer(K),Integer(J)
378404
bax = blockaxes(b,1)
379405
@boundscheck K in bax || throw(BlockBoundsError(b,K))
380406
@boundscheck J in bax || throw(BlockBoundsError(b,J))
381407
K == first(bax) && return _BlockedUnitRange(first(b),cs[k:j])
382-
_BlockedUnitRange(cs[k-1]+1,cs[k:j])
408+
_BlockedUnitRange(cs[k-1]+oneunit(eltype(b)),cs[k:j])
383409
end
384410

385411
@propagate_inbounds function getindex(b::AbstractBlockedUnitRange, KR::BlockRange{1,Tuple{Base.OneTo{Int}}})
@@ -442,7 +468,7 @@ function findblock(b::AbstractUnitRange{Int}, k::Integer)
442468
end
443469

444470
"""
445-
blockfirsts(a::AbstractUnitRange{Int})
471+
blockfirsts(a::AbstractUnitRange{<:Integer})
446472
447473
Return the first index of each block of `a`.
448474
@@ -466,9 +492,9 @@ julia> blockfirsts(b)
466492
4
467493
```
468494
"""
469-
blockfirsts(a::AbstractUnitRange{Int}) = Ones{Int}(1)
495+
blockfirsts(a::AbstractUnitRange{<:Integer}) = Ones{eltype(a)}(1)
470496
"""
471-
blocklasts(a::AbstractUnitRange{Int})
497+
blocklasts(a::AbstractUnitRange{<:Integer})
472498
473499
Return the last index of each block of `a`.
474500
@@ -492,9 +518,9 @@ julia> blocklasts(b)
492518
6
493519
```
494520
"""
495-
blocklasts(a::AbstractUnitRange{Int}) = Fill(length(a),1)
521+
blocklasts(a::AbstractUnitRange{<:Integer}) = Fill(eltype(a)(length(a)),1)
496522
"""
497-
blocklengths(a::AbstractUnitRange{Int})
523+
blocklengths(a::AbstractUnitRange{<:Integer})
498524
499525
Return the length of each block of `a`.
500526
@@ -518,7 +544,7 @@ julia> blocklengths(b)
518544
3
519545
```
520546
"""
521-
blocklengths(a::AbstractUnitRange) = blocklasts(a) .- blockfirsts(a) .+ 1
547+
blocklengths(a::AbstractUnitRange{<:Integer}) = blocklasts(a) .- blockfirsts(a) .+ oneunit(eltype(a))
522548

523549
Base.summary(io::IO, a::AbstractBlockedUnitRange) = _block_summary(io, a)
524550

@@ -538,7 +564,6 @@ blockaxes(S::Base.Slice) = blockaxes(S.indices)
538564
_broadcaststyle(_) = Broadcast.DefaultArrayStyle{1}()
539565
Base.BroadcastStyle(::Type{<:AbstractBlockedUnitRange{<:Any,R}}) where R = _broadcaststyle(Base.BroadcastStyle(R))
540566

541-
542567
###
543568
# Special Fill/Range cases
544569
#
@@ -548,22 +573,22 @@ Base.BroadcastStyle(::Type{<:AbstractBlockedUnitRange{<:Any,R}}) where R = _broa
548573
_blocklengths2blocklasts(blocks::AbstractRange) = RangeCumsum(blocks)
549574
function blockfirsts(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{Int}})
550575
first(a) == 1 || error("Offset axes not supported")
551-
Base.OneTo{Int}(length(blocklasts(a)))
576+
Base.OneTo{eltype(a)}(length(blocklasts(a)))
552577
end
553578
function blocklengths(a::AbstractBlockedUnitRange{<:Any,Base.OneTo{Int}})
554579
first(a) == 1 || error("Offset axes not supported")
555-
Ones{Int}(length(blocklasts(a)))
580+
Ones{eltype(a)}(length(blocklasts(a)))
556581
end
557582
function blockfirsts(a::AbstractBlockedUnitRange{<:Any,<:AbstractRange})
558583
st = step(blocklasts(a))
559584
first(a) == 1 || error("Offset axes not supported")
560-
@assert first(blocklasts(a))-first(a)+1 == st
561-
range(1; step=st, length=length(blocklasts(a)))
585+
@assert first(blocklasts(a))-first(a)+oneunit(eltype(a)) == st
586+
range(oneunit(eltype(a)); step=st, length=eltype(a)(length(blocklasts(a))))
562587
end
563588
function blocklengths(a::AbstractBlockedUnitRange{<:Any,<:AbstractRange})
564589
st = step(blocklasts(a))
565590
first(a) == 1 || error("Offset axes not supported")
566-
@assert first(blocklasts(a))-first(a)+1 == st
591+
@assert first(blocklasts(a))-first(a)+oneunit(eltype(a)) == st
567592
Fill(st,length(blocklasts(a)))
568593
end
569594

0 commit comments

Comments
 (0)