Skip to content

Commit 17b124a

Browse files
authored
feat: add support for DLFP8Types (#1491)
* feat: add support for DLFP8Types * Update Project.toml
1 parent 40d98a5 commit 17b124a

File tree

5 files changed

+59
-32
lines changed

5 files changed

+59
-32
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
3030
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3131
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3232
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
33+
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
34+
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
3335
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3436
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3537
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
@@ -43,13 +45,15 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4345
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4446
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
4547

46-
[sources.ReactantCore]
47-
path = "lib/ReactantCore"
48+
[sources]
49+
ReactantCore = {path = "lib/ReactantCore"}
4850

4951
[extensions]
5052
ReactantAbstractFFTsExt = "AbstractFFTs"
5153
ReactantArrayInterfaceExt = "ArrayInterface"
5254
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
55+
ReactantDLFP8TypesExt = "DLFP8Types"
56+
ReactantFloat8sExt = "Float8s"
5357
ReactantKernelAbstractionsExt = "KernelAbstractions"
5458
ReactantMPIExt = "MPI"
5559
ReactantNNlibExt = ["NNlib", "Statistics"]
@@ -67,10 +71,12 @@ Adapt = "4.1"
6771
ArrayInterface = "7.17.1"
6872
CEnum = "0.5"
6973
CUDA = "5.6"
74+
DLFP8Types = "0.1"
7075
Downloads = "1.6"
7176
EnumX = "1"
7277
Enzyme = "0.13.49"
7378
EnzymeCore = "0.8.11"
79+
Float8s = "0.1"
7480
Functors = "0.5"
7581
GPUArraysCore = "0.2"
7682
GPUCompiler = "1.3"

ext/ReactantDLFP8TypesExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module ReactantDLFP8TypesExt
2+
3+
using DLFP8Types: Float8_E4M3FN, Float8_E4M3FNUZ, Float8_E5M2, Float8_E5M2FNUZ
4+
using Reactant: Reactant
5+
6+
Reactant.reactant_primitive(::Type{Float8_E4M3FN}) = Reactant.F8E4M3FN
7+
Reactant.reactant_primitive(::Type{Float8_E4M3FNUZ}) = Reactant.F8E4M3FNUZ
8+
Reactant.reactant_primitive(::Type{Float8_E5M2}) = Reactant.F8E5M2
9+
Reactant.reactant_primitive(::Type{Float8_E5M2FNUZ}) = Reactant.F8E5M2FNUZ
10+
11+
end

ext/ReactantFloat8sExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module ReactantFloat8sExt
2+
3+
using Float8s: Float8_4
4+
using Reactant: Reactant
5+
6+
Reactant.reactant_primitive(::Type{Float8_4}) = Reactant.F8E4M3FN
7+
8+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
44
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5+
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
56
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
67
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
78
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
@@ -34,6 +35,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435
Adapt = "4.1"
3536
ArrayInterface = "7.17.1"
3637
CUDA = "5.5"
38+
DLFP8Types = "0.1"
3739
Distributions = "0.25"
3840
Enzyme = "0.13.28"
3941
FFTW = "1.8"

test/custom_number_types.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
using Float8s, Reactant
1+
using Float8s, DLFP8Types, Reactant
22
using Reactant: TracedRNumber
33

4-
Reactant.reactant_primitive(::Type{Float8_4}) = Reactant.F8E4M3FN
4+
@testset "Custom number types: $(T)" for T in [Float8_4, Float8_E4M3FN, Float8_E4M3FNUZ]
5+
x = T[
6+
-1.125 -0.21875 1.12
7+
1.875 0.4375 1.0
8+
0.5625 -1.0 0.937
9+
-0.375 -0.34375 -0.6875
10+
0.46875 0.75 -0.23437
11+
-0.6875 -0.203125 0.375
12+
0.875 -0.8125 2.5
13+
-0.6875 -0.1171875 -1.625
14+
0.75 0.9375 1.0
15+
0.5 0.203125 1.75
16+
]
17+
x_64 = Float64.(x)
18+
x_ra = Reactant.to_rarray(x)
519

6-
x = Float8_4[
7-
-1.125 -0.21875 1.12
8-
1.875 0.4375 1.0
9-
0.5625 -1.0 0.937
10-
-0.375 -0.34375 -0.6875
11-
0.46875 0.75 -0.23437
12-
-0.6875 -0.203125 0.375
13-
0.875 -0.8125 2.5
14-
-0.6875 -0.1171875 -1.625
15-
0.75 0.9375 1.0
16-
0.5 0.203125 1.75
17-
]
18-
x_64 = Float64.(x)
19-
x_ra = Reactant.to_rarray(x)
20+
@testset "Reductions" begin
21+
sumall(x) = TracedRNumber{Float64}(sum(x))
2022

21-
@testset "Reductions" begin
22-
sumall(x) = TracedRNumber{Float64}(sum(x))
23+
@test @jit(sumall(x_ra)) sum(x_64) atol = 1e-1 rtol = 1e-1
2324

24-
@test @jit(sumall(x_ra)) sum(x_64) atol = 1e-1 rtol = 1e-1
25+
sum1(x) = TracedRNumber{Float64}.(sum(x; dims=1))
26+
sum2(x) = TracedRNumber{Float64}.(sum(x; dims=2))
27+
sum12(x) = TracedRNumber{Float64}.(sum(x; dims=(1, 2)))
2528

26-
sum1(x) = TracedRNumber{Float64}.(sum(x; dims=1))
27-
sum2(x) = TracedRNumber{Float64}.(sum(x; dims=2))
28-
sum12(x) = TracedRNumber{Float64}.(sum(x; dims=(1, 2)))
29+
@test @jit(sum1(x_ra)) sum(x_64; dims=1) atol = 1e-1 rtol = 1e-1
30+
@test @jit(sum2(x_ra)) sum(x_64; dims=2) atol = 1e-1 rtol = 1e-1
31+
@test @jit(sum12(x_ra)) sum(x_64; dims=(1, 2)) atol = 1e-1 rtol = 1e-1
32+
end
2933

30-
@test @jit(sum1(x_ra)) sum(x_64; dims=1) atol = 1e-1 rtol = 1e-1
31-
@test @jit(sum2(x_ra)) sum(x_64; dims=2) atol = 1e-1 rtol = 1e-1
32-
@test @jit(sum12(x_ra)) sum(x_64; dims=(1, 2)) atol = 1e-1 rtol = 1e-1
33-
end
34-
35-
@testset "Broadcasting" begin
36-
fn(x) = TracedRNumber{Float64}.(x .+ 1)
37-
@test @jit(fn(x_ra)) (x_64 .+ 1) atol = 1e-1 rtol = 1e-1
34+
@testset "Broadcasting" begin
35+
fn(x) = TracedRNumber{Float64}.(x .+ 1)
36+
@test @jit(fn(x_ra)) (x_64 .+ 1) atol = 1e-1 rtol = 1e-1
37+
end
3838
end

0 commit comments

Comments
 (0)