diff --git a/examples/example_sine.rs b/examples/example_sine.rs
index e968bcf..87ef3d7 100644
--- a/examples/example_sine.rs
+++ b/examples/example_sine.rs
@@ -49,7 +49,7 @@ fn main() {
         Initializer::GaussianWFactor(0.0, 1.0, 0.1),
     )));
     network.add_layer(Box::new(ActivationLayer::new(
-        activation_functions::Type::LeakyRelu,
+        activation_functions::Type::Gelu,
     )));
     network.add_layer(Box::new(FCLayer::new(
         8,
@@ -57,7 +57,7 @@ fn main() {
         Initializer::GaussianWFactor(0.0, 1.0, 0.1),
     )));
     network.add_layer(Box::new(ActivationLayer::new(
-        activation_functions::Type::LeakyRelu,
+        activation_functions::Type::Gelu,
     )));
     network.add_layer(Box::new(FCLayer::new(
         1,
diff --git a/src/functions/activation_functions.rs b/src/functions/activation_functions.rs
index 074166c..0ab9c8b 100644
--- a/src/functions/activation_functions.rs
+++ b/src/functions/activation_functions.rs
@@ -7,6 +7,7 @@ pub enum Type {
     Tanh,
     Relu,
     LeakyRelu,
+    Gelu,
 }
 
 type ActFuncTuple = (
@@ -21,6 +22,7 @@ pub fn parse_type(t: Type) -> ActFuncTuple {
         Type::Tanh => (tanh, tanh_prime),
         Type::Relu => (relu, relu_prime),
         Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
+        Type::Gelu => (gelu, gelu_prime),
     }
 }
 
@@ -83,3 +85,21 @@ pub fn leaky_relu_prime(matrix: &Array1) -> Array1 {
     result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 });
     result
 }
+
+/// Sigmoid-approximated GELU: gelu(x) ~= x * sigmoid(1.702 * x).
+pub fn gelu(matrix: &Array1) -> Array1 {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| x * sigmoid(1.702 * x));
+    result
+}
+
+/// Derivative of the sigmoid-approximated GELU. By the product rule, with
+/// s = sigmoid(1.702 * x): d/dx [x * s] = s + 1.702 * x * s * (1 - s).
+pub fn gelu_prime(matrix: &Array1) -> Array1 {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| {
+        let s = sigmoid(1.702 * x);
+        // Note the factor `x` in the second term: omitting it gives wrong gradients.
+        s * (1.0 + 1.702 * x * (1.0 - s))
+    });
+    result
+}