From cc662f749b0c5153989596d87e2a5286e6178aaa Mon Sep 17 00:00:00 2001 From: matsueushi Date: Wed, 4 Mar 2020 05:08:08 +0000 Subject: [PATCH 1/2] Wrapper functions for NNlib --- Manifest.toml | 4 ++-- src/nnlib.jl | 44 +++++++++++++++++++++++++------------------- test/dnn.jl | 5 ++++- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 97b96ac6..a031af43 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -103,9 +103,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[NNlib]] deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" +git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.5" +version = "0.6.6" [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] diff --git a/src/nnlib.jl b/src/nnlib.jl index a2e7844a..f62e901e 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -2,30 +2,36 @@ using NNlib # Activation functions -@cufunc σ(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x))) +@cufunc σ(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x))) -@cufunc function logσ(x) - max_v = max(zero(x), -x) - z = exp(-max_v) + exp(-x-max_v) - -(max_v + log(z)) -end +@cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) -@cufunc elu(x, α = one(x)) = - ifelse(x ≥ 0, x/1, α * (exp(x) - one(x))) +@cufunc logσ(x::Real) = -softplus(-x) -@cufunc swish(x) = x * σ(x) +@cufunc elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x))) -@cufunc function gelu(x) - λ = oftype(x/1, √(2/π)) - α = oftype(x/1, 0.044715) - h = oftype(x/1, 0.5) - h * x * (one(x) + tanh(λ * (x + α * x^3))) +@cufunc function gelu(x::Real) + p = oftype(x / 1, π) + λ = oftype(x / 1, √(2 / p)) + α = oftype(x / 1, 0.044715) + h = oftype(x / 1, 0.5) + h * x * (one(x) + tanh(λ * (x + α * x^3))) end -@cufunc function selu(x) - λ = oftype(x/1, 1.0507009873554804934193349852946) - α = oftype(x/1, 1.6732632423543772848170429916717) - λ * ifelse(x > 0, x/1, α * (exp(x) - 1)) +@cufunc swish(x::Real) = x * σ(x) + +@cufunc lisht(x::Real) = x * tanh(x) + +@cufunc function selu(x::Real) + λ = oftype(x / 1, 1.0507009873554804934193349852946) + α = oftype(x / 1, 1.6732632423543772848170429916717) + λ * ifelse(x > 0, x / one(x), α * (exp(x) - one(x))) end -@cufunc softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) +@cufunc celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) + +@cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2)) + +@cufunc mish(x::Real) = x * tanh(softplus(x)) + +@cufunc tanhshrink(x::Real) = x - tanh(x) diff --git a/test/dnn.jl b/test/dnn.jl index 413feab0..c437b459 100644 --- a/test/dnn.jl +++ b/test/dnn.jl @@ -80,8 +80,11 @@ end @test testf(CUDNN.cudnnActivationBackward, cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1))) # activations defined in src/nnlib.jl + ACTIVATION_FUNCTIONS = [σ, logσ, hardσ, hardtanh, relu, leakyrelu, relu6, rrelu, + elu, gelu, celu, swish, lisht, selu, trelu, softplus, + softsign, logcosh, mish, tanhshrink, softshrink]; for dims in ((5,5), (5,)) - for f in (σ, logσ, elu, swish, gelu, selu, softplus) + for f in filter(x -> x != rrelu, ACTIVATION_FUNCTIONS) @test testf(x -> f.(x), rand(Float64, dims)) end end From 2a96cacc11242db96d53b975075d396661509886 Mon Sep 17 00:00:00 2001 From: matsueushi Date: Sat, 14 Mar 2020 20:23:28 +0000 Subject: [PATCH 2/2] Remove wrappers --- src/nnlib.jl | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/nnlib.jl b/src/nnlib.jl index 07d036fe..57cdd36a 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -1,14 +1,10 @@ using NNlib # Activation functions -@cufunc σ(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x))) - @cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) @cufunc logσ(x::Real) = -softplus(-x) -@cufunc elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x))) - @cufunc function gelu(x::Real) p = oftype(x / 1, π) λ = oftype(x / 1, √(2 / p)) @@ -17,24 +13,15 @@ using NNlib h * x * (one(x) + tanh(λ * (x + α * x^3))) end -@cufunc swish(x::Real) = x * σ(x) - @cufunc lisht(x::Real) = x * tanh(x) -@cufunc function selu(x::Real) - λ = oftype(x / 1, 1.0507009873554804934193349852946) - α = oftype(x / 1, 1.6732632423543772848170429916717) - λ * ifelse(x > 0, x / one(x), α * (exp(x) - one(x))) -end - -@cufunc celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) - @cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2)) @cufunc mish(x::Real) = x * tanh(softplus(x)) @cufunc tanhshrink(x::Real) = x - tanh(x) + # Batched matrix multiplication _BATCHED_GEMM_LIST = [