Compare commits

...

3 commits

Author SHA1 Message Date
a39ac835e4
Added (approximated) GELU activation function 2023-02-04 20:35:17 +01:00
e564f1de30
Fixed logistic_prime 2023-02-04 20:22:00 +01:00
9a88ef6071
Forgot an activation function 2023-02-04 20:01:04 +01:00
2 changed files with 18 additions and 8 deletions

View file

@ -49,7 +49,7 @@ fn main() {
Initializer::GaussianWFactor(0.0, 1.0, 0.1), Initializer::GaussianWFactor(0.0, 1.0, 0.1),
))); )));
network.add_layer(Box::new(ActivationLayer::new( network.add_layer(Box::new(ActivationLayer::new(
activation_functions::Type::LeakyRelu, activation_functions::Type::Gelu,
))); )));
network.add_layer(Box::new(FCLayer::new( network.add_layer(Box::new(FCLayer::new(
8, 8,
@ -57,7 +57,7 @@ fn main() {
Initializer::GaussianWFactor(0.0, 1.0, 0.1), Initializer::GaussianWFactor(0.0, 1.0, 0.1),
))); )));
network.add_layer(Box::new(ActivationLayer::new( network.add_layer(Box::new(ActivationLayer::new(
activation_functions::Type::LeakyRelu, activation_functions::Type::Gelu,
))); )));
network.add_layer(Box::new(FCLayer::new( network.add_layer(Box::new(FCLayer::new(
1, 1,

View file

@ -7,6 +7,7 @@ pub enum Type {
Tanh, Tanh,
Relu, Relu,
LeakyRelu, LeakyRelu,
Gelu,
} }
type ActFuncTuple = ( type ActFuncTuple = (
@ -21,6 +22,7 @@ pub fn parse_type(t: Type) -> ActFuncTuple {
Type::Tanh => (tanh, tanh_prime), Type::Tanh => (tanh, tanh_prime),
Type::Relu => (relu, relu_prime), Type::Relu => (relu, relu_prime),
Type::LeakyRelu => (leaky_relu, leaky_relu_prime), Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
Type::Gelu => (gelu, gelu_prime),
} }
} }
@ -29,11 +31,7 @@ pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
} }
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> { pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone(); Array1::ones(matrix.raw_dim())
for x in result.iter_mut() {
*x = 1.0;
}
result
} }
fn sigmoid(x: f64) -> f64 { fn sigmoid(x: f64) -> f64 {
@ -48,7 +46,7 @@ pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> { pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone(); let mut result = matrix.clone();
result.mapv_inplace(|x| sigmoid(x * (1.0 - sigmoid(x)))); result.mapv_inplace(|x| sigmoid(x) * (1.0 - sigmoid(x)));
result result
} }
@ -87,3 +85,15 @@ pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 }); result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 });
result result
} }
pub fn gelu(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
result.mapv_inplace(|x| x * sigmoid(1.702 * x));
result
}
pub fn gelu_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
result.mapv_inplace(|x| sigmoid(1.702 * x) * (1.0 + 1.702 * (1.0 - sigmoid(1.702 * x))));
result
}