[CUDNN] Support BFloat16 #2987
base: master
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:
diff --git a/lib/cudnn/src/util.jl b/lib/cudnn/src/util.jl
index 8923ff9b5..c7ec0c2bd 100644
--- a/lib/cudnn/src/util.jl
+++ b/lib/cudnn/src/util.jl
@@ -4,13 +4,13 @@ using BFloat16s: BFloat16
cptr(x,a::DenseCuArray{Float64})=Float64[x]
cptr(x,a::DenseCuArray{Float32})=Float32[x]
cptr(x,a::DenseCuArray{Float16})=Float32[x]
-cptr(x,a::DenseCuArray{BFloat16})=Float32[x]
+cptr(x, a::DenseCuArray{BFloat16}) = Float32[x]
# Conversion between Julia and cuDNN datatypes
cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
cudnnDataType(::Type{Float32})=CUDNN_DATA_FLOAT
cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
-cudnnDataType(::Type{BFloat16})=CUDNN_DATA_BFLOAT16
+cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16
cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
@@ -21,7 +21,7 @@ cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
a==CUDNN_DATA_FLOAT ? Float32 :
a==CUDNN_DATA_DOUBLE ? Float64 :
- a==CUDNN_DATA_BFLOAT16 ? BFloat16 :
+ a == CUDNN_DATA_BFLOAT16 ? BFloat16 :
a==CUDNN_DATA_INT8 ? Int8 :
a==CUDNN_DATA_UINT8 ? UInt8 :
a==CUDNN_DATA_INT32 ? Int32 : error())
diff --git a/lib/cudnn/test/activation.jl b/lib/cudnn/test/activation.jl
index e25cf4c7c..33f76ed94 100644
--- a/lib/cudnn/test/activation.jl
+++ b/lib/cudnn/test/activation.jl
@@ -62,9 +62,9 @@ activationtest(alpha=2)
activationtest(beta=2)
if capability(device()) >= v"8.0"
- (ax,ay) = randn.(BFloat16, (10,10))
- (cx,cy) = CuArray.((ax,ay))
- activationtest(mode=CUDNN_ACTIVATION_SIGMOID)
- activationtest(mode=CUDNN_ACTIVATION_RELU)
- activationtest(mode=CUDNN_ACTIVATION_TANH)
+ (ax, ay) = randn.(BFloat16, (10, 10))
+ (cx, cy) = CuArray.((ax, ay))
+ activationtest(mode = CUDNN_ACTIVATION_SIGMOID)
+ activationtest(mode = CUDNN_ACTIVATION_RELU)
+ activationtest(mode = CUDNN_ACTIVATION_TANH)
end
diff --git a/lib/cudnn/test/softmax.jl b/lib/cudnn/test/softmax.jl
index 2102d6d02..74befeeae 100644
--- a/lib/cudnn/test/softmax.jl
+++ b/lib/cudnn/test/softmax.jl
@@ -46,8 +46,8 @@ softmaxtest(algo=CUDNN_SOFTMAX_ACCURATE)
softmaxtest(algo=CUDNN_SOFTMAX_LOG)
if capability(device()) >= v"8.0"
- ax,ay = randn(BFloat16,10,10),randn(BFloat16,10,10)
- cx,cy = CuArray.((ax,ay))
+ ax, ay = randn(BFloat16, 10, 10), randn(BFloat16, 10, 10)
+ cx, cy = CuArray.((ax, ay))
softmaxtest()
- softmaxtest(algo=CUDNN_SOFTMAX_LOG)
+ softmaxtest(algo = CUDNN_SOFTMAX_LOG)
end
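For context, a hedged sketch (not part of the PR) of what these reformatted tests exercise, i.e. cuDNN's softmax and activation calls on BFloat16 inputs. Treating cudnnSoftmaxForward / cudnnActivationForward as the wrapper's high-level entry points is an assumption on my part; BFloat16 additionally requires a compute-capability 8.0+ GPU.

```julia
# Hedged sketch: run cuDNN softmax and ReLU on BFloat16 CuArrays.
# Assumes CUDA.jl, cuDNN.jl and BFloat16s.jl are available and an sm_80+ GPU is present.
using CUDA, cuDNN
using BFloat16s: BFloat16
using cuDNN: cudnnSoftmaxForward, cudnnActivationForward, CUDNN_ACTIVATION_RELU

ax = BFloat16.(randn(Float32, 10, 10))  # host reference data
cx = CuArray(ax)                        # BFloat16 array on the GPU

cy_soft = cudnnSoftmaxForward(cx)                                   # softmax
cy_relu = cudnnActivationForward(cx; mode = CUDNN_ACTIVATION_RELU)  # elementwise ReLU
```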
|
Hm, duplicate of #1092? That one doesn't define the
|
1.12 failure unrelated, retried CI.
|
Second CI failure also seems unrelated, rerunning. If that succeeds you should rebase on top of master.
Codecov Report
❌ Patch coverage is …
Additional details and impacted files

@@            Coverage Diff            @@
##           master    #2987       +/-  ##
===========================================
+ Coverage   76.53%   89.22%   +12.68%
===========================================
  Files         148      148
  Lines       12860    12950      +90
===========================================
+ Hits         9842    11554    +1712
+ Misses       3018     1396    -1622

☔ View full report in Codecov by Sentry.
|
Force-pushed from 7f8d47b to 6d2bea8.
|
Thanks, Katharine. I’ve updated the BFloat16s compat entry to align with CUDA.jl’s.
CUDA.jl Benchmarks
| Benchmark suite | Current: a23ee46 | Previous: 5d9474a | Ratio |
|---|---|---|---|
| latency/precompile | 55341296562.5 ns | 55510377029.5 ns | 1.00 |
| latency/ttfp | 7795931886 ns | 7790703567 ns | 1.00 |
| latency/import | 4120273995 ns | 4122189304 ns | 1.00 |
| integration/volumerhs | 9609087 ns | 9624973 ns | 1.00 |
| integration/byval/slices=1 | 146813 ns | 147064 ns | 1.00 |
| integration/byval/slices=3 | 425709 ns | 425893 ns | 1.00 |
| integration/byval/reference | 144869 ns | 145082 ns | 1.00 |
| integration/byval/slices=2 | 286216 ns | 286384 ns | 1.00 |
| integration/cudadevrt | 103582 ns | 103602 ns | 1.00 |
| kernel/indexing | 14109 ns | 14225 ns | 0.99 |
| kernel/indexing_checked | 14715 ns | 14969 ns | 0.98 |
| kernel/occupancy | 719.9230769230769 ns | 732.5227272727273 ns | 0.98 |
| kernel/launch | 2487 ns | 2249.4444444444443 ns | 1.11 |
| kernel/rand | 14840 ns | 18642 ns | 0.80 |
| array/reverse/1d | 19661 ns | 19990 ns | 0.98 |
| array/reverse/2dL_inplace | 66712 ns | 66917 ns | 1.00 |
| array/reverse/1dL | 69754 ns | 70158 ns | 0.99 |
| array/reverse/2d | 21612 ns | 21954 ns | 0.98 |
| array/reverse/1d_inplace | 9523 ns | 9677 ns | 0.98 |
| array/reverse/2d_inplace | 10951 ns | 11077 ns | 0.99 |
| array/reverse/2dL | 73712 ns | 74051.5 ns | 1.00 |
| array/reverse/1dL_inplace | 66734 ns | 66880 ns | 1.00 |
| array/copy | 20406 ns | 20660 ns | 0.99 |
| array/iteration/findall/int | 156259 ns | 158373 ns | 0.99 |
| array/iteration/findall/bool | 139121.5 ns | 140139 ns | 0.99 |
| array/iteration/findfirst/int | 160318 ns | 161271 ns | 0.99 |
| array/iteration/findfirst/bool | 160750.5 ns | 162049 ns | 0.99 |
| array/iteration/scalar | 71772 ns | 72812.5 ns | 0.99 |
| array/iteration/logical | 214982 ns | 216894.5 ns | 0.99 |
| array/iteration/findmin/1d | 51356 ns | 50981 ns | 1.01 |
| array/iteration/findmin/2d | 95639 ns | 96704 ns | 0.99 |
| array/reductions/reduce/Int64/1d | 42833 ns | 43491 ns | 0.98 |
| array/reductions/reduce/Int64/dims=1 | 44548.5 ns | 52642.5 ns | 0.85 |
| array/reductions/reduce/Int64/dims=2 | 61336 ns | 61484 ns | 1.00 |
| array/reductions/reduce/Int64/dims=1L | 88872 ns | 88879 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2L | 87447 ns | 87977 ns | 0.99 |
| array/reductions/reduce/Float32/1d | 35926 ns | 37248.5 ns | 0.96 |
| array/reductions/reduce/Float32/dims=1 | 47245 ns | 43278 ns | 1.09 |
| array/reductions/reduce/Float32/dims=2 | 59452 ns | 60066 ns | 0.99 |
| array/reductions/reduce/Float32/dims=1L | 52057 ns | 52282 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 71472 ns | 72365.5 ns | 0.99 |
| array/reductions/mapreduce/Int64/1d | 42652 ns | 43561 ns | 0.98 |
| array/reductions/mapreduce/Int64/dims=1 | 44200 ns | 44306 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 61326 ns | 61482 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1L | 88603 ns | 89001 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 87623 ns | 88320 ns | 0.99 |
| array/reductions/mapreduce/Float32/1d | 36241 ns | 38092.5 ns | 0.95 |
| array/reductions/mapreduce/Float32/dims=1 | 45113 ns | 41962 ns | 1.08 |
| array/reductions/mapreduce/Float32/dims=2 | 59389 ns | 60039 ns | 0.99 |
| array/reductions/mapreduce/Float32/dims=1L | 52306 ns | 52636 ns | 0.99 |
| array/reductions/mapreduce/Float32/dims=2L | 71604 ns | 72310 ns | 0.99 |
| array/broadcast | 19787 ns | 20127 ns | 0.98 |
| array/copyto!/gpu_to_gpu | 12685 ns | 12738 ns | 1.00 |
| array/copyto!/cpu_to_gpu | 213773 ns | 217857 ns | 0.98 |
| array/copyto!/gpu_to_cpu | 286784 ns | 287088 ns | 1.00 |
| array/accumulate/Int64/1d | 124050 ns | 124778 ns | 0.99 |
| array/accumulate/Int64/dims=1 | 83524.5 ns | 83708 ns | 1.00 |
| array/accumulate/Int64/dims=2 | 157928 ns | 158367 ns | 1.00 |
| array/accumulate/Int64/dims=1L | 1709757 ns | 1710164 ns | 1.00 |
| array/accumulate/Int64/dims=2L | 965855 ns | 967254 ns | 1.00 |
| array/accumulate/Float32/1d | 108765 ns | 109314 ns | 0.99 |
| array/accumulate/Float32/dims=1 | 80132 ns | 80184 ns | 1.00 |
| array/accumulate/Float32/dims=2 | 147436 ns | 147922 ns | 1.00 |
| array/accumulate/Float32/dims=1L | 1618684 ns | 1618786 ns | 1.00 |
| array/accumulate/Float32/dims=2L | 697887 ns | 698724 ns | 1.00 |
| array/construct | 1307.2 ns | 1295.5 ns | 1.01 |
| array/random/randn/Float32 | 47414.5 ns | 47861 ns | 0.99 |
| array/random/randn!/Float32 | 24572 ns | 24875 ns | 0.99 |
| array/random/rand!/Int64 | 27209 ns | 27408 ns | 0.99 |
| array/random/rand!/Float32 | 8733.333333333334 ns | 8909.666666666666 ns | 0.98 |
| array/random/rand/Int64 | 38056 ns | 30055 ns | 1.27 |
| array/random/rand/Float32 | 13097 ns | 13184 ns | 0.99 |
| array/permutedims/4d | 55080.5 ns | 55109 ns | 1.00 |
| array/permutedims/2d | 53606 ns | 53832 ns | 1.00 |
| array/permutedims/3d | 54539 ns | 54841 ns | 0.99 |
| array/sorting/1d | 2757516.5 ns | 2757534 ns | 1.00 |
| array/sorting/by | 3344110 ns | 3344541 ns | 1.00 |
| array/sorting/2d | 1080569 ns | 1081521 ns | 1.00 |
| cuda/synchronization/stream/auto | 1007 ns | 1036.5 ns | 0.97 |
| cuda/synchronization/stream/nonblocking | 7651.9 ns | 7410.8 ns | 1.03 |
| cuda/synchronization/stream/blocking | 808.6 ns | 820.6336633663366 ns | 0.99 |
| cuda/synchronization/context/auto | 1154.2 ns | 1154.3 ns | 1.00 |
| cuda/synchronization/context/nonblocking | 7125.8 ns | 7124.4 ns | 1.00 |
| cuda/synchronization/context/blocking | 936.6078431372549 ns | 887.4107142857143 ns | 1.06 |
This comment was automatically generated by workflow using github-action-benchmark.
|
The cuDNN run on CUDA 13 appears to fail due to running on SM75.
|
rebase on master? |
Force-pushed from e7e97ca to 6359bf0.
|
Done!
Force-pushed from 213affc to 2931a8f.
|
CI failures look related.
|
Hi Tim. The test outcome seems to depend on the compute capability / architecture: the cuDNN tests on CUDA 13.0 ran on an A100 and passed, but the cuDNN tests on CUDA 12.0 happened to run on a Turing GPU and failed. See also my earlier comment.
Should the tests be conditionally skipped instead? Is there a way to require a certain compute capability for, e.g., the cuDNN tests?
|
Yes, the tests should take care not to cover unsupported code paths, e.g. by inspecting the device's compute capability.
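Following this suggestion, a minimal sketch of the gating, matching the pattern already used in the test diffs above; softmaxtest and CUDNN_SOFTMAX_LOG come from lib/cudnn/test/softmax.jl and the cuDNN bindings, and the skip message is illustrative only.

```julia
# Only exercise the BFloat16 cuDNN paths on devices that support them
# (compute capability 8.0, i.e. Ampere, or newer); otherwise skip.
using CUDA
using BFloat16s: BFloat16

if CUDA.capability(CUDA.device()) >= v"8.0"
    ax, ay = randn(BFloat16, 10, 10), randn(BFloat16, 10, 10)
    cx, cy = CuArray.((ax, ay))
    softmaxtest()                          # helper defined in lib/cudnn/test/softmax.jl
    softmaxtest(algo = CUDNN_SOFTMAX_LOG)
else
    @info "Skipping BFloat16 cuDNN tests: compute capability >= 8.0 required"
end
```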
This PR defines methods for making cuDNN work with BFloat16s.BFloat16. In the following example, I show how the new methods fix the BFloat16 backward pass of Flux.logitcrossentropy:
Before
Note: Core.BFloat16 === BFloat16s.BFloat16, but I didn't explicitly import it in this REPL session.
After defining cudnnDataType(::Type{BFloat16})
After defining scalingParameter(::Type{BFloat16}, val)
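A hedged sketch of the kind of reproducer described here, under the assumption that the gradient of Flux.logitcrossentropy on CuArrays routes through cuDNN's softmax; the shapes and data are illustrative only.

```julia
# Reproducer sketch: the BFloat16 backward pass of Flux.logitcrossentropy,
# which previously errored once it reached cuDNN (missing BFloat16 methods).
using CUDA, Flux
using BFloat16s: BFloat16
using Flux: logitcrossentropy, gradient

x = CuArray(BFloat16.(randn(Float32, 10, 16)))  # logits (classes × batch)
y = CuArray(BFloat16.(rand(Float32, 10, 16)))   # soft targets, same shape

g, = gradient(x -> logitcrossentropy(x, y), x)  # works once the new methods are defined
```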
I also define a cptr method for consistency, but it appears the function isn't used anywhere. Tests are added for softmax, activations, and pooling. I initially also tested convolutions, normalization, RNNs, and MHA, but they don't appear to support BFloat16.
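For reference, a hedged sketch of the three methods: the cudnnDataType and cptr definitions are taken from the diff above, while the scalingParameter body is an assumption that mirrors the existing Float16 method (promoting the scalar to Float32, as cuDNN expects for reduced-precision types).

```julia
# Sketch of the BFloat16 methods added in lib/cudnn/src/util.jl.
using CUDA: DenseCuArray
using BFloat16s: BFloat16
import cuDNN: cudnnDataType, cptr, scalingParameter
using cuDNN: CUDNN_DATA_BFLOAT16

cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16                 # Julia type -> cuDNN enum
cptr(x, a::DenseCuArray{BFloat16}) = Float32[x]                       # scalar passed as a Float32 array
scalingParameter(::Type{BFloat16}, val) = Ref{Float32}(Float32(val))  # assumed: same promotion as Float16
```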
Along with my proposed fix in FluxML/Optimisers.jl#215, this has allowed me to train LLMs in BFloat16 with Flux.jl in Julia v1.12. I am still tinkering with Optimisers.jl, but these together would be a significant unlock for my lab.