diff --git a/examples/example_sine.rs b/examples/example_sine.rs index 87ef3d7..e968bcf 100644 --- a/examples/example_sine.rs +++ b/examples/example_sine.rs @@ -49,7 +49,7 @@ fn main() { Initializer::GaussianWFactor(0.0, 1.0, 0.1), ))); network.add_layer(Box::new(ActivationLayer::new( - activation_functions::Type::Gelu, + activation_functions::Type::LeakyRelu, ))); network.add_layer(Box::new(FCLayer::new( 8, @@ -57,7 +57,7 @@ fn main() { Initializer::GaussianWFactor(0.0, 1.0, 0.1), ))); network.add_layer(Box::new(ActivationLayer::new( - activation_functions::Type::Gelu, + activation_functions::Type::LeakyRelu, ))); network.add_layer(Box::new(FCLayer::new( 1, diff --git a/src/functions/activation_functions.rs b/src/functions/activation_functions.rs index 0ab9c8b..c65f728 100644 --- a/src/functions/activation_functions.rs +++ b/src/functions/activation_functions.rs @@ -7,7 +7,6 @@ pub enum Type { Tanh, Relu, LeakyRelu, - Gelu, } type ActFuncTuple = ( @@ -22,7 +21,6 @@ pub fn parse_type(t: Type) -> ActFuncTuple { Type::Tanh => (tanh, tanh_prime), Type::Relu => (relu, relu_prime), Type::LeakyRelu => (leaky_relu, leaky_relu_prime), - Type::Gelu => (gelu, gelu_prime), } } @@ -31,7 +29,11 @@ pub fn identity(matrix: &Array1) -> Array1 { } pub fn identity_prime(matrix: &Array1) -> Array1 { - Array1::ones(matrix.raw_dim()) + let mut result = matrix.clone(); + for x in result.iter_mut() { + *x = 1.0; + } + result } fn sigmoid(x: f64) -> f64 { @@ -46,7 +48,7 @@ pub fn logistic(matrix: &Array1) -> Array1 { pub fn logistic_prime(matrix: &Array1) -> Array1 { let mut result = matrix.clone(); - result.mapv_inplace(|x| sigmoid(x) * (1.0 - sigmoid(x))); + result.mapv_inplace(|x| sigmoid(x * (1.0 - sigmoid(x)))); result } @@ -85,15 +87,3 @@ pub fn leaky_relu_prime(matrix: &Array1) -> Array1 { result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 }); result } - -pub fn gelu(matrix: &Array1) -> Array1 { - let mut result = matrix.clone(); - result.mapv_inplace(|x| x * sigmoid(1.702 * x)); - result -} - -pub fn gelu_prime(matrix: &Array1) -> Array1 { - let mut result = matrix.clone(); - result.mapv_inplace(|x| sigmoid(1.702 * x) * (1.0 + 1.702 * (1.0 - sigmoid(1.702 * x)))); - result -}