Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 7408dd9

Browse files
authored
Merge pull request #645 from JuliaGPU/tb/multitasking
Support for Julia's multitasking.
2 parents 138ece7 + c3ba8b8 commit 7408dd9

File tree

18 files changed

+201
-151
lines changed

18 files changed

+201
-151
lines changed

.gitlab-ci.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,7 @@ pool:split:
138138
- nvidia
139139
variables:
140140
CUARRAYS_MEMORY_POOL: 'split'
141+
allow_failure: true
141142

142143
debug:
143144
extends:

Manifest.toml

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,28 @@ uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
3939
version = "6.2.0"
4040

4141
[[CUDAnative]]
42-
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "MacroTools", "Pkg", "Printf", "TimerOutputs"]
43-
git-tree-sha1 = "e6742ce88d11f1fdf6a9357ba738735f86ce67b5"
44-
repo-rev = "58c6755445c05ff26f1bdc5c12c7ae0aa6c39bc2"
45-
repo-url = "https://github.com/JuliaGPU/CUDAnative.jl.git"
42+
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "MacroTools", "Pkg", "Printf", "TimerOutputs"]
43+
git-tree-sha1 = "1ee71ece4332185ad49b93f7b6cf9d51017e40ef"
4644
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
47-
version = "2.10.2"
45+
version = "3.0.0"
46+
47+
[[CodeTracking]]
48+
deps = ["InteractiveUtils", "UUIDs"]
49+
git-tree-sha1 = "0becdab7e6fbbcb7b88d8de5b72e5bb2f28239f3"
50+
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
51+
version = "0.5.8"
52+
53+
[[Compat]]
54+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
55+
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
56+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
57+
version = "2.2.0"
58+
59+
[[Cthulhu]]
60+
deps = ["CodeTracking", "InteractiveUtils", "TerminalMenus", "Unicode"]
61+
git-tree-sha1 = "5e0f928ccaab1fa2911fc4e204e8a6f5b0213eaf"
62+
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
63+
version = "1.0.0"
4864

4965
[[DataStructures]]
5066
deps = ["InteractiveUtils", "OrderedCollections"]
@@ -56,6 +72,10 @@ version = "0.17.10"
5672
deps = ["Printf"]
5773
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
5874

75+
[[DelimitedFiles]]
76+
deps = ["Mmap"]
77+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
78+
5979
[[Distributed]]
6080
deps = ["Random", "Serialization", "Sockets"]
6181
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -100,6 +120,9 @@ version = "0.5.4"
100120
deps = ["Base64"]
101121
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
102122

123+
[[Mmap]]
124+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
125+
103126
[[NNlib]]
104127
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
105128
git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
@@ -146,6 +169,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
146169
[[Serialization]]
147170
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
148171

172+
[[SharedArrays]]
173+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
174+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
175+
149176
[[Sockets]]
150177
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
151178

@@ -157,6 +184,12 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
157184
deps = ["LinearAlgebra", "SparseArrays"]
158185
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
159186

187+
[[TerminalMenus]]
188+
deps = ["Compat", "REPL", "Test"]
189+
git-tree-sha1 = "9ae6ed0c94eee4d898e049820942af21daf15efc"
190+
uuid = "dc548174-15c3-5faf-af27-7997cfbde655"
191+
version = "0.1.0"
192+
160193
[[Test]]
161194
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
162195
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ Adapt = "1.0"
3030
CEnum = "0.2"
3131
CUDAapi = "3.0, 4.0"
3232
CUDAdrv = "6.0.1"
33-
CUDAnative = "2.10"
33+
CUDAnative = "3.0"
3434
DataStructures = "0.17"
3535
GPUArrays = "3.1"
3636
MacroTools = "0.5"

src/blas/CUBLAS.jl

Lines changed: 45 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -25,68 +25,74 @@ include("wrappers.jl")
2525
# high-level integrations
2626
include("linalg.jl")
2727

28-
const handles_lock = ReentrantLock()
29-
const created_handles = Dict{Tuple{UInt,Int},cublasHandle_t}()
30-
const created_xt_handles = Dict{Tuple{UInt,Int},cublasXtHandle_t}()
31-
const active_handles = Vector{Union{Nothing,cublasHandle_t}}()
32-
const active_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()
28+
# thread cache for task-local library handles
29+
const thread_handles = Vector{Union{Nothing,cublasHandle_t}}()
30+
const thread_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()
3331

3432
function handle()
3533
tid = Threads.threadid()
36-
if @inbounds active_handles[tid] === nothing
34+
if @inbounds thread_handles[tid] === nothing
3735
ctx = context()
38-
key = (objectid(ctx), tid)
39-
lock(handles_lock) do
40-
active_handles[tid] = get!(created_handles, key) do
41-
handle = cublasCreate_v2()
42-
atexit(()->CUDAdrv.isvalid(ctx) && cublasDestroy_v2(handle))
43-
44-
# enable tensor math mode if our device supports it, and fast math is enabled
45-
dev = CUDAdrv.device()
46-
if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
47-
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)
36+
thread_handles[tid] = get!(task_local_storage(), (:CUBLAS, ctx)) do
37+
handle = cublasCreate_v2()
38+
finalizer(current_task()) do task
39+
CUDAdrv.isvalid(ctx) || return
40+
context!(ctx) do
41+
cublasDestroy_v2(handle)
4842
end
43+
end
4944

50-
handle
45+
# enable tensor math mode if our device supports it, and fast math is enabled
46+
dev = CUDAdrv.device()
47+
if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
48+
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)
5149
end
50+
51+
handle
5252
end
5353
end
54-
@inbounds active_handles[tid]
54+
@inbounds thread_handles[tid]
5555
end
5656

5757
function xt_handle()
5858
tid = Threads.threadid()
59-
if @inbounds active_xt_handles[tid] === nothing
59+
if @inbounds thread_xt_handles[tid] === nothing
6060
ctx = context()
61-
key = (objectid(ctx), tid)
62-
lock(handles_lock) do
63-
active_xt_handles[tid] = get!(created_xt_handles, key) do
64-
handle = cublasXtCreate()
65-
atexit(()->CUDAdrv.isvalid(ctx) && cublasXtDestroy(handle))
66-
67-
# select the devices
68-
# TODO: this is weird, since we typically use a single device per thread/context
69-
devs = convert.(Cint, CUDAdrv.devices())
70-
cublasXtDeviceSelect(handle, length(devs), devs)
71-
72-
handle
61+
thread_xt_handles[tid] = get!(task_local_storage(), (:CUBLASxt, ctx)) do
62+
handle = cublasXtCreate()
63+
finalizer(current_task()) do task
64+
CUDAdrv.isvalid(ctx) || return
65+
context!(ctx) do
66+
cublasXtDestroy(handle)
67+
end
7368
end
69+
70+
# select the devices
71+
# TODO: this is weird, since we typically use a single device per thread/context
72+
devs = convert.(Cint, CUDAdrv.devices())
73+
cublasXtDeviceSelect(handle, length(devs), devs)
74+
75+
handle
7476
end
7577
end
76-
@inbounds active_xt_handles[tid]
78+
@inbounds thread_xt_handles[tid]
7779
end
7880

7981
function __init__()
80-
resize!(active_handles, Threads.nthreads())
81-
fill!(active_handles, nothing)
82+
resize!(thread_handles, Threads.nthreads())
83+
fill!(thread_handles, nothing)
8284

83-
resize!(active_xt_handles, Threads.nthreads())
84-
fill!(active_xt_handles, nothing)
85+
resize!(thread_xt_handles, Threads.nthreads())
86+
fill!(thread_xt_handles, nothing)
8587

8688
CUDAnative.atcontextswitch() do tid, ctx
87-
# we don't eagerly initialize handles, but do so lazily when requested
88-
active_handles[tid] = nothing
89-
active_xt_handles[tid] = nothing
89+
thread_handles[tid] = nothing
90+
thread_xt_handles[tid] = nothing
91+
end
92+
93+
CUDAnative.attaskswitch() do tid, task
94+
thread_handles[tid] = nothing
95+
thread_xt_handles[tid] = nothing
9096
end
9197
end
9298

src/blas/error.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,7 @@ end
4848
end
4949

5050
function initialize_api()
51-
# make sure the calling thread has an active context
52-
CUDAnative.initialize_context()
51+
CUDAnative.prepare_cuda_call()
5352
end
5453

5554
macro check(ex)

src/dnn/CUDNN.jl

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -39,33 +39,38 @@ include("nnlib.jl")
3939

4040
include("compat.jl")
4141

42-
const handles_lock = ReentrantLock()
43-
const created_handles = Dict{Tuple{UInt,Int},cudnnHandle_t}()
44-
const active_handles = Vector{Union{Nothing,cudnnHandle_t}}()
42+
# thread cache for task-local library handles
43+
const thread_handles = Vector{Union{Nothing,cudnnHandle_t}}()
4544

4645
function handle()
4746
tid = Threads.threadid()
48-
if @inbounds active_handles[tid] === nothing
47+
if @inbounds thread_handles[tid] === nothing
4948
ctx = context()
50-
key = (objectid(ctx), tid)
51-
lock(handles_lock) do
52-
active_handles[tid] = get!(created_handles, key) do
53-
handle = cudnnCreate()
54-
atexit(()->CUDAdrv.isvalid(ctx) && cudnnDestroy(handle))
55-
handle
49+
thread_handles[tid] = get!(task_local_storage(), (:CUDNN, ctx)) do
50+
handle = cudnnCreate()
51+
finalizer(current_task()) do task
52+
CUDAdrv.isvalid(ctx) || return
53+
context!(ctx) do
54+
cudnnDestroy(handle)
55+
end
5656
end
57+
58+
handle
5759
end
5860
end
59-
@inbounds active_handles[tid]
61+
@inbounds thread_handles[tid]
6062
end
6163

6264
function __init__()
63-
resize!(active_handles, Threads.nthreads())
64-
fill!(active_handles, nothing)
65+
resize!(thread_handles, Threads.nthreads())
66+
fill!(thread_handles, nothing)
6567

6668
CUDAnative.atcontextswitch() do tid, ctx
67-
# we don't eagerly initialize handles, but do so lazily when requested
68-
active_handles[tid] = nothing
69+
thread_handles[tid] = nothing
70+
end
71+
72+
CUDAnative.attaskswitch() do tid, task
73+
thread_handles[tid] = nothing
6974
end
7075
end
7176

src/dnn/error.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,7 @@ name(err::CUDNNError) = unsafe_string(cudnnGetErrorString(err))
2020
end
2121

2222
function initialize_api()
23-
# make sure the calling thread has an active context
24-
CUDAnative.initialize_context()
23+
CUDAnative.prepare_cuda_call()
2524
end
2625

2726
macro check(ex)

src/fft/error.jl

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,8 +62,7 @@ end
6262
end
6363

6464
function initialize_api()
65-
# make sure the calling thread has an active context
66-
CUDAnative.initialize_context()
65+
CUDAnative.prepare_cuda_call()
6766
end
6867

6968
macro check(ex)

src/memory.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ synchronized right before and after executing `ex` to exclude any external effec
299299
macro time(ex)
300300
quote
301301
# @time might surround an application, so be sure to initialize CUDA before that
302-
CUDAnative.initialize_context()
302+
CUDAnative.prepare_cuda_call()
303303

304304
# coarse synchronization to exclude effects from previously-executed code
305305
CUDAdrv.synchronize()

src/rand/CURAND.jl

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -23,33 +23,32 @@ include("wrappers.jl")
2323
# high-level integrations
2424
include("random.jl")
2525

26-
const handles_lock = ReentrantLock()
27-
const created_generators = Dict{Tuple{UInt,Int},RNG}()
28-
const active_generators = Vector{Union{Nothing,RNG}}()
26+
# thread cache for task-local library handles
27+
const thread_generators = Vector{Union{Nothing,RNG}}()
2928

3029
function generator()
3130
tid = Threads.threadid()
32-
if @inbounds active_generators[tid] === nothing
31+
if @inbounds thread_generators[tid] === nothing
3332
ctx = context()
34-
key = (objectid(ctx), tid)
35-
lock(handles_lock) do
36-
active_generators[tid] = get!(created_generators, key) do
37-
rng = RNG()
38-
Random.seed!(rng)
39-
rng
40-
end
33+
thread_generators[tid] = get!(task_local_storage(), (:CURAND, ctx)) do
34+
rng = RNG()
35+
Random.seed!(rng)
36+
rng
4137
end
4238
end
43-
@inbounds active_generators[tid]
39+
@inbounds thread_generators[tid]
4440
end
4541

4642
function __init__()
47-
resize!(active_generators, Threads.nthreads())
48-
fill!(active_generators, nothing)
43+
resize!(thread_generators, Threads.nthreads())
44+
fill!(thread_generators, nothing)
4945

5046
CUDAnative.atcontextswitch() do tid, ctx
51-
# we don't eagerly initialize handles, but do so lazily when requested
52-
active_generators[tid] = nothing
47+
thread_generators[tid] = nothing
48+
end
49+
50+
CUDAnative.attaskswitch() do tid, task
51+
thread_generators[tid] = nothing
5352
end
5453
end
5554

0 commit comments

Comments (0)