Skip to content

Commit 3df641a

Browse files
authored
toplevel: Parallelize report_package signature analysis (#785)
Parallelize the analysis of method signatures in `report_package` using `Threads.@spawn`. Each `analyze_method_signature!` task gets its own independent `AnalyzerState` with local `inf_cache` and `analysis_results`, avoiding data races without complex cross-task cache sharing. Design notes: https://publish.obsidian.md/jetls/work/JETLS/JET/JET+Make+%60report_package%60+parallelized Key changes: - Add `AtomicContainers` module providing thread-safe container types (`SWContainer`, `LWContainer`, `CASContainer`) for concurrent access - Introduce `CASDict` for thread-safe analyzer cache (`JET_ANALYZER_CACHE` and `OPT_ANALYZER_CACHE`) - Add `PackageAnalysisProgress` struct for tracking parallel analysis progress with atomic counters - Parallelize `analyze_from_definitions!` in virtualprocess.jl Benchmarks: > `julia --startup-file=no --threads=4,2 -e 'using JET; report_package(JET; target_modules=(JET,), sourceinfo=:compact);'` | Approach | Time | |------------|--------| | sequential | 52.07s | | parallel | 17.75s | > ``julia --startup-file=no --threads=4,2 -e 'using JET; using Pkg; Pkg.activate(; temp=true); Pkg.add("CSV"); using CSV; report_package(CSV; target_modules=(CSV,), sourceinfo=:compact);'`` | Approach | Time | |------------|--------| | sequential | 44.23s | | parallel | 19.57s |
1 parent 7147a40 commit 3df641a

File tree

7 files changed

+189
-62
lines changed

7 files changed

+189
-62
lines changed

.github/workflows/ci.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ jobs:
8383
- uses: julia-actions/julia-processcoverage@v1
8484
- uses: codecov/codecov-action@v5
8585

86+
multithreading-test:
87+
name: JULIA_NUM_THREADS=4,2
88+
runs-on: ubuntu-latest
89+
steps:
90+
- uses: actions/checkout@v6
91+
- uses: julia-actions/setup-julia@v2
92+
with:
93+
version: "1"
94+
arch: x64
95+
- uses: julia-actions/cache@v2
96+
- name: set preferences
97+
working-directory: .
98+
run: |
99+
echo '[JET]
100+
JET_DEV_MODE = true' > LocalPreferences.toml
101+
- uses: julia-actions/julia-buildpkg@latest
102+
- uses: julia-actions/julia-runtest@latest
103+
env:
104+
JULIA_NUM_THREADS: "4,2"
105+
86106
empty-loading-test:
87107
runs-on: ubuntu-latest
88108
steps:

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5353
<!-- links end -->
5454

5555
## [Unreleased]
56+
### Changed
57+
- **Parallelized `report_package`**: Method signature analysis in `report_package`
58+
is now parallelized using Julia's multithreading, providing significant
59+
speedup on multi-core systems.
60+
61+
> With `--threads=4,2`
62+
- Benchmark on `report_package(JET)`: 52.07s → 17.75s (~3x faster)
63+
- Benchmark on `report_package(CSV)`: 44.23s → 19.57s (~2x faster)
64+
5665
### Internal
5766
- Refactored the project file to use the [`[workspace]`](https://pkgdocs.julialang.org/v1/toml-files/#The-%5Bworkspace%5D-section) for the docs/test environment of JET.
5867
This allows running e.g. `julia --project=./test test/runtests.jl` or

src/JETBase.jl

Lines changed: 103 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,34 @@ const JET_LOGGER_LEVELS_DESC = let
294294
end
295295
jet_logger_level(@nospecialize io::IO) = get(io, JET_LOGGER_LEVEL, DEFAULT_LOGGER_LEVEL)::Int
296296

297+
# multithreading
298+
299+
"""
300+
CASDict{K,V}
301+
302+
A thread-safe dictionary using Compare-And-Swap (CAS) operations for lock-free updates.
303+
Reads are always lock-free via atomic load. Writes use a CAS retry loop, making this
304+
suitable for lightweight, pure update functions that are safe to retry.
305+
306+
Currently implements [`get!`](@ref) only.
307+
"""
308+
mutable struct CASDict{K,V}
309+
@atomic data::Dict{K,V}
310+
CASDict{K,V}() where {K,V} = new{K,V}(Dict{K,V}())
311+
end
312+
313+
function Base.get!(f::Base.Callable, d::CASDict{K,V}, key::K) where {K,V}
314+
old = @atomic :acquire d.data
315+
val = get(old, key, nothing)
316+
val !== nothing && return val::V
317+
while true
318+
new = copy(old)
319+
val = get!(f, new, key)::V
320+
old, success = @atomicreplace :acquire_release :monotonic d.data old => new
321+
success && return val
322+
end
323+
end
324+
297325
# analysis core
298326
# =============
299327

@@ -327,6 +355,20 @@ Prints a report of the top-level error `report` to the given `io`.
327355
"""
328356
function print_report end
329357

358+
mutable struct PackageAnalysisProgress
359+
const reports::Vector{InferenceErrorReport}
360+
const reports_lock::ReentrantLock
361+
@atomic done::Int
362+
@atomic analyzed::Int
363+
@atomic cached::Int
364+
const interval::Int
365+
@atomic next_interval::Int
366+
function PackageAnalysisProgress(n_sigs::Int)
367+
interval = max(n_sigs ÷ 25, 1)
368+
new(InferenceErrorReport[], ReentrantLock(), 0, 0, 0, interval, interval)
369+
end
370+
end
371+
330372
include("toplevel/virtualprocess.jl")
331373

332374
# results
@@ -968,6 +1010,11 @@ struct SigAnalysisResult
9681010
codeinst::CodeInstance
9691011
end
9701012

1013+
struct SigWorkItem
1014+
siginfos::Vector{Revise.SigInfo}
1015+
index::Int
1016+
end
1017+
9711018
"""
9721019
analyze_and_report_package!(analyzer::AbstractAnalyzer, package::Module; jetconfigs...) -> JETToplevelResult
9731020
@@ -991,7 +1038,6 @@ function analyze_and_report_package!(analyzer::AbstractAnalyzer, pkgmod::Module;
9911038
end
9921039

9931040
start = time()
994-
counter, analyzed, cached = Ref(0), Ref(0), Ref(0)
9951041
res = VirtualProcessResult(nothing)
9961042
jetconfigs = set_if_missing(jetconfigs, :toplevel_logger, IOContext(stdout, JET_LOGGER_LEVEL => DEFAULT_LOGGER_LEVEL))
9971043
config = ToplevelConfig(; jetconfigs...)
@@ -1001,66 +1047,89 @@ function analyze_and_report_package!(analyzer::AbstractAnalyzer, pkgmod::Module;
10011047
newstate = AnalyzerState(AnalyzerState(analyzer); world=Base.get_world_counter())
10021048
analyzer = AbstractAnalyzer(analyzer, newstate)
10031049

1004-
n_sigs = 0
1005-
for fi in pkgdata.fileinfos, (_, exsigs) in fi.modexsigs, (_, siginfos) in exsigs
1006-
isnothing(siginfos) && continue
1007-
n_sigs += length(siginfos)
1008-
end
1050+
workitems = SigWorkItem[]
10091051
for fi in pkgdata.fileinfos, (_, exsigs) in fi.modexsigs, (_, siginfos) in exsigs
10101052
isnothing(siginfos) && continue
10111053
for (i, siginfo) in enumerate(siginfos)
1012-
toplevel_logger(config) do @nospecialize(io::IO)
1013-
clearline(io)
1014-
end
1015-
counter[] += 1
1016-
inf_world = CC.get_inference_world(analyzer)
1054+
push!(workitems, SigWorkItem(siginfos, i))
1055+
end
1056+
end
1057+
1058+
n_sigs = length(workitems)
1059+
progress = PackageAnalysisProgress(n_sigs)
1060+
inf_world = CC.get_inference_world(analyzer)
1061+
1062+
toplevel_logger(config) do @nospecialize(io::IO)
1063+
print(io, "Analyzing top-level definitions (progress: 0/$n_sigs | interval: $(progress.interval))")
1064+
end
1065+
1066+
tasks = map(workitems) do workitem
1067+
(; siginfos, index) = workitem
1068+
siginfo = siginfos[index]
1069+
Threads.@spawn :default try
10171070
ext = Revise.get_extended_data(siginfo, :JET)
1071+
local reports::Vector{InferenceErrorReport}
10181072
if ext !== nothing && ext.data isa SigAnalysisResult
10191073
prev_result = ext.data::SigAnalysisResult
10201074
if (CC.cache_owner(analyzer) === prev_result.codeinst.owner &&
10211075
prev_result.codeinst.max_world ≥ inf_world ≥ prev_result.codeinst.min_world)
1022-
toplevel_logger(config) do @nospecialize(io::IO)
1023-
(counter[] == n_sigs ? println : print)(io, "Skipped analysis for cached definition ($(counter[])/$n_sigs)")
1024-
end
1025-
cached[] += 1
1076+
@atomic progress.cached += 1
10261077
reports = prev_result.reports
10271078
@goto gotreports
10281079
end
10291080
end
1081+
# Create a new analyzer with fresh local caches (`inf_cache` and `analysis_results`)
1082+
# to avoid data races between concurrent signature analysis tasks
1083+
task_analyzer = AbstractAnalyzer(analyzer,
1084+
AnalyzerState(AnalyzerState(analyzer), #=refresh_local_cache=#true))
10301085
match = Base._which(siginfo.sig;
1031-
method_table = CC.method_table(analyzer),
1086+
method_table = CC.method_table(task_analyzer),
10321087
world = inf_world,
10331088
raise = false)
10341089
if match !== nothing
1035-
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
1036-
if jet_logger_level(io) ≥ JET_LOGGER_LEVEL_DEBUG
1037-
print(io, "Analyzing top-level definition `")
1038-
Base.show_tuple_as_call(io, Symbol(""), siginfo.sig)
1039-
print(io, "` (progress: $(counter[])/$n_sigs)")
1040-
else
1041-
print(io, "Analyzing top-level definition (progress: $(counter[])/$n_sigs)")
1042-
end
1043-
end
1044-
result = analyze_method_signature!(analyzer,
1090+
result = analyze_method_signature!(task_analyzer,
10451091
match.method, match.spec_types, match.sparams)
1046-
analyzed[] += 1
1047-
reports = get_reports(analyzer, result)
1048-
siginfos[i] = Revise.replace_extended_data(siginfo, :JET, SigAnalysisResult(reports, result.ci))
1049-
@label gotreports
1050-
append!(res.inference_error_reports, reports)
1092+
@atomic progress.analyzed += 1
1093+
reports = get_reports(task_analyzer, result)
1094+
siginfos[index] = Revise.replace_extended_data(siginfo, :JET, SigAnalysisResult(reports, result.ci))
10511095
else
1052-
toplevel_logger(config) do @nospecialize(io::IO)
1096+
toplevel_logger(config; pre=println) do @nospecialize(io::IO)
10531097
print(io, "Couldn't find a single matching method for the signature `")
10541098
Base.show_tuple_as_call(io, Symbol(""), siginfo.sig)
1055-
println(io, "` (progress: $(counter[])/$n_sigs)")
1099+
println(io, "`")
1100+
end
1101+
reports = InferenceErrorReport[]
1102+
end
1103+
@label gotreports
1104+
isempty(reports) || @lock progress.reports_lock append!(progress.reports, reports)
1105+
catch err
1106+
@error "Error analyzing method signature" siginfo.sig
1107+
Base.showerror(stderr, err, catch_backtrace())
1108+
finally
1109+
done = (@atomic progress.done += 1)
1110+
current_next = @atomic progress.next_interval
1111+
if done >= current_next
1112+
@atomicreplace progress.next_interval current_next => current_next + progress.interval
1113+
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
1114+
print(io, "Analyzing top-level definitions (progress: $done/$n_sigs)")
10561115
end
10571116
end
10581117
end
10591118
end
10601119

1120+
waitall(tasks)
1121+
1122+
append!(res.inference_error_reports, progress.reports)
1123+
1124+
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
1125+
done = @atomic progress.done
1126+
print(io, "Analyzing top-level definitions (progress: $done/$n_sigs)")
1127+
end
10611128
toplevel_logger(config; pre=println) do @nospecialize(io::IO)
10621129
sec = round(time() - start; digits = 3)
1063-
println(io, "Analyzed all top-level definitions (all: $(counter[]) | analyzed: $(analyzed[]) | cached: $(cached[]) | took: $sec sec)")
1130+
analyzed = @atomic progress.analyzed
1131+
cached = @atomic progress.cached
1132+
println(io, "Analyzed all top-level definitions (all: $n_sigs | analyzed: $analyzed | cached: $cached | took: $sec sec)")
10641133
end
10651134

10661135
unique!(aggregation_policy(analyzer), res.inference_error_reports)

src/abstractinterpret/abstractanalyzer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ end
295295
# -----------
296296
# 2. `AbstractAnalyzer(analyzer::NewAnalyzer, state::AnalyzerState) -> NewAnalyzer`
297297

298-
@noinline function AbstractAnalyzer(analyzer::AbstractAnalyzer, state::AnalyzerState)
298+
@noinline function AbstractAnalyzer(analyzer::AbstractAnalyzer, ::AnalyzerState)
299299
AnalyzerType = nameof(typeof(analyzer))
300300
error(lazy"""
301301
Missing `$AbstractAnalyzer` API:

src/analyzers/jetanalyzer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ JETInterface.typeinf_world(::BasicJETAnalyzer) = JET_TYPEINF_WORLD[]
123123
JETInterface.typeinf_world(::SoundJETAnalyzer) = JET_TYPEINF_WORLD[]
124124
JETInterface.typeinf_world(::TypoJETAnalyzer) = JET_TYPEINF_WORLD[]
125125

126-
const JET_ANALYZER_CACHE = Dict{UInt, AnalysisToken}()
126+
const JET_ANALYZER_CACHE = CASDict{UInt,AnalysisToken}()
127127

128128
JETAnalyzerConfig(analyzer::JETAnalyzer) = analyzer.config
129129

src/analyzers/optanalyzer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ JETInterface.AnalysisToken(analyzer::OptAnalyzer) = analyzer.analysis_token
160160
JETInterface.typeinf_world(::OptAnalyzer) = JET_TYPEINF_WORLD[]
161161
JETInterface.vscode_diagnostics_order(::OptAnalyzer) = false
162162

163-
const OPT_ANALYZER_CACHE = Dict{UInt, AnalysisToken}()
163+
const OPT_ANALYZER_CACHE = CASDict{UInt,AnalysisToken}()
164164

165165
# overloads
166166
# =========

src/toplevel/virtualprocess.jl

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -734,42 +734,71 @@ gen_virtual_module(parent::Module = Main; name = VIRTUAL_MODULE_NAME) =
734734
# if code generation has failed given the entry method signature, the overload of
735735
# `InferenceState(..., ::AbstractAnalyzer)` will collect `GeneratorErrorReport`
736736
function analyze_from_definitions!(interp::ConcreteInterpreter, config::ToplevelConfig)
737-
succeeded = Ref(0)
738737
start = time()
739-
analyzer = ToplevelAbstractAnalyzer(interp, non_toplevel_concretized; refresh_local_cache = false)
740738
entrypoint = config.analyze_from_definitions
741739
res = InterpretationState(interp).res
742740
n_sigs = length(res.toplevel_signatures)
743-
for i = 1:n_sigs
744-
tt = res.toplevel_signatures[i]
745-
match = Base._which(tt;
746-
# NOTE use the latest world counter with `method_table(analyzer)` unwrapped,
747-
# otherwise it may use a world counter when this method isn't defined yet
748-
method_table = CC.method_table(analyzer),
749-
world = CC.get_inference_world(analyzer),
750-
raise = false)
751-
if (match !== nothing &&
752-
(!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true`
753-
match.method.name === entrypoint))
754-
succeeded[] += 1
755-
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
756-
print(io, "analyzing from top-level definitions ($(succeeded[])/$n_sigs)")
741+
n_sigs == 0 && return nothing
742+
743+
progress = PackageAnalysisProgress(n_sigs)
744+
745+
toplevel_logger(config) do @nospecialize(io::IO)
746+
print(io, "analyzing from top-level definitions (0/$n_sigs)")
747+
end
748+
749+
tasks = map(1:n_sigs) do i
750+
Threads.@spawn begin
751+
tt = res.toplevel_signatures[i]
752+
# Create a new analyzer with fresh local caches (`inf_cache` and `analysis_results`)
753+
# to avoid data races between concurrent signature analysis tasks
754+
analyzer = ToplevelAbstractAnalyzer(interp, non_toplevel_concretized;
755+
refresh_local_cache = true)
756+
match = Base._which(tt;
757+
# NOTE use the latest world counter with `method_table(analyzer)` unwrapped,
758+
# otherwise it may use a world counter when this method isn't defined yet
759+
method_table = CC.method_table(analyzer),
760+
world = CC.get_inference_world(analyzer),
761+
raise = false)
762+
if (match !== nothing &&
763+
(!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true`
764+
match.method.name === entrypoint))
765+
@atomic progress.analyzed += 1
766+
result = analyze_method_signature!(analyzer,
767+
match.method, match.spec_types, match.sparams)
768+
reports = get_reports(analyzer, result)
769+
isempty(reports) || @lock progress.reports_lock append!(progress.reports, reports)
770+
else
771+
# something went wrong
772+
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
773+
println(io, "couldn't find a single method matching the signature `", tt, "`")
774+
end
757775
end
758-
result = analyze_method_signature!(analyzer,
759-
match.method, match.spec_types, match.sparams)
760-
reports = get_reports(analyzer, result)
761-
append!(res.inference_error_reports, reports)
762-
else
763-
# something went wrong
764-
toplevel_logger(config; filter=≥(JET_LOGGER_LEVEL_DEBUG), pre=clearline) do @nospecialize(io::IO)
765-
println(io, "couldn't find a single method matching the signature `", tt, "`")
776+
done = (@atomic progress.done += 1)
777+
current_next = @atomic progress.next_interval
778+
if done >= current_next
779+
@atomicreplace progress.next_interval current_next => current_next + progress.interval
780+
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
781+
analyzed = @atomic progress.analyzed
782+
print(io, "analyzing from top-level definitions ($analyzed/$n_sigs)")
783+
end
766784
end
767785
end
768786
end
787+
788+
waitall(tasks)
789+
790+
append!(res.inference_error_reports, progress.reports)
791+
792+
toplevel_logger(config; pre=clearline) do @nospecialize(io::IO)
793+
done = @atomic progress.done
794+
print(io, "analyzing from top-level definitions ($done/$n_sigs)")
795+
end
769796
toplevel_logger(config; pre=println) do @nospecialize(io::IO)
770797
sec = round(time() - start; digits = 3)
771-
println(io, "analyzed $(succeeded[]) top-level definitions (took $sec sec)")
798+
analyzed = @atomic progress.analyzed
799+
println(io, "analyzed $analyzed top-level definitions (took $sec sec)")
772800
end
801+
773802
return nothing
774803
end
775804

0 commit comments

Comments
 (0)