Skip to content

Commit c1ccaa5

Browse files
committed
inference: accelerate type-limits under wide-recursion
when we hit union-splitting, we need to ensure type limits are very aggressive and preferably also independent of the height of the recursion chain fix #31572
1 parent c9786e6 commit c1ccaa5

File tree

2 files changed

+70
-26
lines changed

2 files changed

+70
-26
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
6363
nonbot = 0 # the index of the only non-Bottom inference result if > 0
6464
seen = 0 # number of signatures actually inferred
6565
istoplevel = sv.linfo.def isa Module
66+
any_splitunions = napplicable > 1
6667
for i in 1:napplicable
6768
match = applicable[i]::SimpleVector
6869
method = match[3]::Method
@@ -80,7 +81,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
8081
if splitunions
8182
splitsigs = switchtupleunion(sig)
8283
for sig_n in splitsigs
83-
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), sv)
84+
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), any_splitunions, sv)
8485
if edge !== nothing
8586
push!(edges, edge)
8687
end
@@ -89,7 +90,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
8990
this_rt === Any && break
9091
end
9192
else
92-
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, sv)
93+
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, any_splitunions, sv)
9394
edgecycle |= edgecycle1::Bool
9495
if edge !== nothing
9596
push!(edges, edge)
@@ -227,7 +228,7 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial
227228
return result
228229
end
229230

230-
function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, sv::InferenceState)
231+
function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
231232
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
232233
return Any, false, nothing
233234
end
@@ -266,30 +267,36 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
266267
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
267268
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
268269
if topmost === nothing && method2 === inf_method2
269-
# inspect the parent of this edge,
270-
# to see if they are the same Method as sv
271-
# in which case we'll need to ensure it is convergent
272-
# otherwise, we don't
273-
for parent in infstate.callers_in_cycle
274-
# check in the cycle list first
275-
# all items in here are mutual parents of all others
276-
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
277-
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
278-
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
279-
topmost = infstate
280-
edgecycle = true
281-
break
282-
end
283-
end
284-
let parent = infstate.parent
285-
# then check the parent link
286-
if topmost === nothing && parent !== nothing
287-
parent = parent::InferenceState
270+
if hardlimit
271+
topmost = infstate
272+
edgecycle = true
273+
else
274+
# if this is a soft limit,
275+
# also inspect the parent of this edge,
276+
# to see if they are the same Method as sv
277+
# in which case we'll need to ensure it is convergent
278+
# otherwise, we don't
279+
for parent in infstate.callers_in_cycle
280+
# check in the cycle list first
281+
# all items in here are mutual parents of all others
288282
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
289283
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
290-
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
284+
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
291285
topmost = infstate
292286
edgecycle = true
287+
break
288+
end
289+
end
290+
let parent = infstate.parent
291+
# then check the parent link
292+
if topmost === nothing && parent !== nothing
293+
parent = parent::InferenceState
294+
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
295+
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
296+
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
297+
topmost = infstate
298+
edgecycle = true
299+
end
293300
end
294301
end
295302
end
@@ -321,7 +328,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
321328
comparison = method.sig
322329
end
323330
# see if the type is actually too big (relative to the caller), and limit it if required
324-
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)
331+
newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)
325332

326333
if newsig !== sig
327334
# continue inference, but note that we've limited parameter complexity

test/compiler/inference.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,13 +1050,13 @@ copy_dims_out(out) = ()
10501050
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
10511051
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
10521052
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
1053-
@test all(m -> 20 < count_specializations(m) < 45, methods(copy_dims_out))
1053+
@test all(m -> 4 < count_specializations(m) < 15, methods(copy_dims_out)) # currently about 5
10541054

10551055
copy_dims_pair(out) = ()
10561056
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
10571057
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
10581058
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
1059-
@test all(m -> 10 < count_specializations(m) < 35, methods(copy_dims_pair))
1059+
@test all(m -> 5 < count_specializations(m) < 15, methods(copy_dims_pair)) # currently about 7
10601060

10611061
@test isdefined_tfunc(typeof(NamedTuple()), Const(0)) === Const(false)
10621062
@test isdefined_tfunc(typeof(NamedTuple()), Const(1)) === Const(false)
@@ -2325,3 +2325,40 @@ h28762(::Type{X}) where {X} = Array{f28762(X)}(undef, 0)
23252325
@inferred g28762(Array)
23262326
@inferred h28762(Array)
23272327
end
2328+
2329+
# issue #31572
2330+
struct MixedKeyDict{T<:Tuple} #<: AbstractDict{Any,Any}
2331+
dicts::T
2332+
end
2333+
Base.merge(f::Function, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
2334+
Base.merge(f, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
2335+
function _merge(f, res, d, others...)
2336+
ofsametype, remaining = _alloftype(Base.heads(d), ((),), others...)
2337+
return _merge(f, (res..., merge(f, ofsametype...)), Base.tail(d), remaining...)
2338+
end
2339+
_merge(f, res, ::Tuple{}, others...) = _merge(f, res, others...)
2340+
_merge(f, res, d) = MixedKeyDict((res..., d...))
2341+
_merge(f, res, ::Tuple{}) = MixedKeyDict(res)
2342+
function _alloftype(ofdesiredtype::Tuple{Vararg{D}}, accumulated, d::Tuple{D,Vararg}, others...) where D
2343+
return _alloftype((ofdesiredtype..., first(d)),
2344+
(Base.front(accumulated)..., (last(accumulated)..., Base.tail(d)...), ()),
2345+
others...)
2346+
end
2347+
function _alloftype(ofdesiredtype, accumulated, d, others...)
2348+
return _alloftype(ofdesiredtype,
2349+
(Base.front(accumulated)..., (last(accumulated)..., first(d))),
2350+
Base.tail(d), others...)
2351+
end
2352+
function _alloftype(ofdesiredtype, accumulated, ::Tuple{}, others...)
2353+
return _alloftype(ofdesiredtype,
2354+
(accumulated..., ()),
2355+
others...)
2356+
end
2357+
_alloftype(ofdesiredtype, accumulated) = ofdesiredtype, Base.front(accumulated)
2358+
let
2359+
d = MixedKeyDict((Dict(1 => 3), Dict(4. => 2)))
2360+
e = MixedKeyDict((Dict(1 => 7), Dict(5. => 9)))
2361+
@test merge(+, d, e).dicts == (Dict(1 => 10), Dict(4.0 => 2, 5.0 => 9))
2362+
f = MixedKeyDict((Dict(2 => 7), Dict(5. => 11)))
2363+
@test merge(+, d, e, f).dicts == (Dict(1 => 10, 2 => 7), Dict(4.0 => 2, 5.0 => 20))
2364+
end

0 commit comments

Comments
 (0)