From 9a88ef60716105c34d164db1e49292dd3a3187fd Mon Sep 17 00:00:00 2001
From: lluni
Date: Sat, 4 Feb 2023 20:01:04 +0100
Subject: [PATCH 1/3] Forgot an activation function

---
 src/functions/activation_functions.rs | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/functions/activation_functions.rs b/src/functions/activation_functions.rs
index c65f728..d233ffb 100644
--- a/src/functions/activation_functions.rs
+++ b/src/functions/activation_functions.rs
@@ -29,11 +29,7 @@ pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
 }
 
 pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
-    let mut result = matrix.clone();
-    for x in result.iter_mut() {
-        *x = 1.0;
-    }
-    result
+    Array1::ones(matrix.raw_dim())
 }
 
 fn sigmoid(x: f64) -> f64 {

From e564f1de30f49a513fc3ac5c610aefb515ca1c6e Mon Sep 17 00:00:00 2001
From: lluni
Date: Sat, 4 Feb 2023 20:22:00 +0100
Subject: [PATCH 2/3] Fixed logistic_prime

---
 src/functions/activation_functions.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/functions/activation_functions.rs b/src/functions/activation_functions.rs
index d233ffb..074166c 100644
--- a/src/functions/activation_functions.rs
+++ b/src/functions/activation_functions.rs
@@ -44,7 +44,7 @@ pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
 
 pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
     let mut result = matrix.clone();
-    result.mapv_inplace(|x| sigmoid(x * (1.0 - sigmoid(x))));
+    result.mapv_inplace(|x| sigmoid(x) * (1.0 - sigmoid(x)));
     result
 }
 

From a39ac835e4d3435a34c03ddc18d008d07113ce63 Mon Sep 17 00:00:00 2001
From: lluni
Date: Sat, 4 Feb 2023 20:35:17 +0100
Subject: [PATCH 3/3] Added (approximated) GELU activation function

---
 examples/example_sine.rs              |  4 ++--
 src/functions/activation_functions.rs | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/examples/example_sine.rs b/examples/example_sine.rs
index e968bcf..87ef3d7 100644
--- a/examples/example_sine.rs
+++ b/examples/example_sine.rs
@@ -49,7 +49,7 @@ fn main() {
         Initializer::GaussianWFactor(0.0, 1.0, 0.1),
     )));
     network.add_layer(Box::new(ActivationLayer::new(
-        activation_functions::Type::LeakyRelu,
+        activation_functions::Type::Gelu,
     )));
     network.add_layer(Box::new(FCLayer::new(
         8,
@@ -57,7 +57,7 @@ fn main() {
         Initializer::GaussianWFactor(0.0, 1.0, 0.1),
     )));
     network.add_layer(Box::new(ActivationLayer::new(
-        activation_functions::Type::LeakyRelu,
+        activation_functions::Type::Gelu,
     )));
     network.add_layer(Box::new(FCLayer::new(
         1,
diff --git a/src/functions/activation_functions.rs b/src/functions/activation_functions.rs
index 074166c..0ab9c8b 100644
--- a/src/functions/activation_functions.rs
+++ b/src/functions/activation_functions.rs
@@ -7,6 +7,7 @@ pub enum Type {
     Tanh,
     Relu,
     LeakyRelu,
+    Gelu,
 }
 
 type ActFuncTuple = (
@@ -18,6 +19,7 @@ pub fn parse_type(t: Type) -> ActFuncTuple {
         Type::Tanh => (tanh, tanh_prime),
         Type::Relu => (relu, relu_prime),
         Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
+        Type::Gelu => (gelu, gelu_prime),
     }
 }
 
@@ -83,3 +85,15 @@ pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
     result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 });
     result
 }
+
+pub fn gelu(matrix: &Array1<f64>) -> Array1<f64> {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| x * sigmoid(1.702 * x));
+    result
+}
+
+pub fn gelu_prime(matrix: &Array1<f64>) -> Array1<f64> {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| sigmoid(1.702 * x) * (1.0 + 1.702 * x * (1.0 - sigmoid(1.702 * x))));
+    result
+}
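---

A note on the GELU added in PATCH 3/3: it is the sigmoid approximation
GELU(x) ≈ x * sigmoid(1.702 * x) (Hendrycks & Gimpel). By the product and
chain rules its derivative is

    GELU'(x) ≈ s * (1 + 1.702 * x * (1 - s)),  where s = sigmoid(1.702 * x)

which is what gelu_prime applies elementwise; note the x factor inside the
parentheses. A quick way to check the analytic derivative is to compare it
against a central finite difference. The following is a minimal,
self-contained sketch: it mirrors the patched functions on plain f64 values
instead of ndarray's Array1<f64>, so these free functions are illustrative
stand-ins, not part of the patch itself.

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Approximated GELU: x * sigmoid(1.702 * x)
fn gelu(x: f64) -> f64 {
    x * sigmoid(1.702 * x)
}

// Analytic derivative: s * (1 + 1.702 * x * (1 - s))
fn gelu_prime(x: f64) -> f64 {
    let s = sigmoid(1.702 * x);
    s * (1.0 + 1.702 * x * (1.0 - s))
}

fn main() {
    let h = 1e-6;
    for x in [-3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0] {
        // Central finite difference as the numeric reference.
        let numeric = (gelu(x + h) - gelu(x - h)) / (2.0 * h);
        let analytic = gelu_prime(x);
        assert!((analytic - numeric).abs() < 1e-5);
        println!("x = {x:+.2}: analytic {analytic:+.6}, numeric {numeric:+.6}");
    }
}

Running this prints matching analytic and numeric values at every sampled
input; dropping the x factor from the derivative makes the assertion fail.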