Skip to content

Commit 8fd16ef

Browse files
committed
add max_maximum and min_minimum
1 parent c1f5569 commit 8fd16ef

File tree

2 files changed

+57
-22
lines changed

2 files changed

+57
-22
lines changed

base/reduce.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -452,31 +452,66 @@ julia> prod(1:20)
452452
prod(a) = mapreduce(identity, mul_prod, a)
453453

454454
## maximum & minimum
455-
_fast(::typeof(max), x, y) = ifelse(x == x,
456-
ifelse(x > y, x, y),
457-
x)
458-
_fast(::typeof(min), x, y) = ifelse(x == x,
459-
ifelse(x < y, x, y),
460-
x)
461-
462-
function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
455+
456+
"""
457+
max_maximum(x,y)
458+
459+
Compute the bigger of two elements. This function dose the same thing as `max(x,y)`, except that
460+
* It is slightly faster
461+
* The order of `-0.0` and `0.0` is undefined.
462+
463+
See also [`max`](@ref), [`min_minimum`](@ref)
464+
"""
465+
function max_maximum(x,y)
466+
ifelse(isnan(x),
467+
x,
468+
ifelse(x > y, x, y))
469+
end
470+
471+
"""
472+
min_minimum(x,y)
473+
474+
Variant of [`max_maximum`](@ref) for computing minima.
475+
"""
476+
function min_minimum(x,y)
477+
ifelse(isnan(x),
478+
x,
479+
ifelse(x < y, x, y))
480+
end
481+
482+
function mapreduce_impl(f, op::Union{typeof(max_maximum), typeof(min_minimum)},
463483
A::AbstractArray, first::Int, last::Int)
464484
a1 = @inbounds A[first]
465-
v = mapreduce_first(f, op, a1)
485+
v1 = mapreduce_first(f, op, a1)
486+
v2 = v3 = v4 = v1
466487
chunk_len = 256
467-
for start in (first + 1):chunk_len:last
468-
v == v || return v
469-
stop = min(start + chunk_len-1, last)
470-
@simd for i in start:stop
471-
@inbounds ai = A[i]
472-
v = _fast(op, v, f(ai))
488+
start = first
489+
stop = start + chunk_len - 4
490+
while stop <= last
491+
isnan(v1) && return v1
492+
isnan(v2) && return v2
493+
isnan(v3) && return v3
494+
isnan(v4) && return v4
495+
@inbounds for i in start:4:stop
496+
v1 = op(v1, f(A[i+1]))
497+
v2 = op(v2, f(A[i+2]))
498+
v3 = op(v3, f(A[i+3]))
499+
v4 = op(v4, f(A[i+4]))
473500
end
501+
start = stop
502+
stop = start + chunk_len - 4
503+
end
504+
v = op(op(v1,v2),op(v3,v4))
505+
start += 1
506+
for i in start:last
507+
@inbounds ai = A[i]
508+
v = op(v, f(A[i]))
474509
end
475510
v
476511
end
477512

478-
maximum(f, a) = mapreduce(f, max, a)
479-
minimum(f, a) = mapreduce(f, min, a)
513+
maximum(f, a) = mapreduce(f, max_maximum, a)
514+
minimum(f, a) = mapreduce(f, min_minimum, a)
480515

481516
"""
482517
maximum(itr)
@@ -492,7 +527,7 @@ julia> maximum([1,2,3])
492527
3
493528
```
494529
"""
495-
maximum(a) = mapreduce(identity, max, a)
530+
maximum(a) = mapreduce(identity, max_maximum, a)
496531

497532
"""
498533
minimum(itr)
@@ -508,7 +543,7 @@ julia> minimum([1,2,3])
508543
1
509544
```
510545
"""
511-
minimum(a) = mapreduce(identity, min, a)
546+
minimum(a) = mapreduce(identity, min_minimum, a)
512547

513548
## all & any
514549

base/reducedim.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function _reducedim_init(f, op, fv, fop, A, region)
125125
end
126126

127127
# initialization when computing minima and maxima requires a little care
128-
for (f1, f2, initval) in ((:min, :max, :Inf), (:max, :min, :(-Inf)))
128+
for (f1, f2, initval) in ((:min_minimum, :max_maximum, :Inf), (:max_maximum, :min_minimum, :(-Inf)))
129129
@eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region)
130130
# First compute the reduce indices. This will throw an ArgumentError
131131
# if any region is invalid
@@ -642,7 +642,7 @@ julia> any!([1 1], A)
642642
any!(r, A)
643643

644644
for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod),
645-
(:maximum, :_maximum, :max), (:minimum, :_minimum, :min)]
645+
(:maximum, :_maximum, :max_maximum), (:minimum, :_minimum, :min_minimum)]
646646
@eval begin
647647
# User-facing methods with keyword arguments
648648
@inline ($fname)(a::AbstractArray; dims=:) = ($_fname)(a, dims)
@@ -662,7 +662,7 @@ all(f::Function, a::AbstractArray; dims=:) = _all(f, a, dims)
662662
_all(a, ::Colon) = _all(identity, a, :)
663663

664664
for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
665-
(:maximum, :max), (:minimum, :min),
665+
(:maximum, :max_maximum), (:minimum, :min_minimum),
666666
(:all, :&), (:any, :|)]
667667
fname! = Symbol(fname, '!')
668668
_fname = Symbol('_', fname)

0 commit comments

Comments
 (0)