Added (approximated) GELU activation function
parent e564f1de30
commit a39ac835e4
2 changed files with 16 additions and 2 deletions
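In short: rather than the exact GELU, x * Phi(x) with Phi the standard Gaussian CDF, this uses the well-known sigmoid approximation GELU(x) ≈ x * sigmoid(1.702 * x) (Hendrycks & Gimpel). By the product rule its derivative is sigmoid(1.702x) * (1 + 1.702x * (1 - sigmoid(1.702x))), which is what the new gelu_prime computes.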
@@ -49,7 +49,7 @@ fn main() {
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
-       activation_functions::Type::LeakyRelu,
+       activation_functions::Type::Gelu,
    )));
    network.add_layer(Box::new(FCLayer::new(
        8,
@@ -57,7 +57,7 @@ fn main() {
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
-       activation_functions::Type::LeakyRelu,
+       activation_functions::Type::Gelu,
    )));
    network.add_layer(Box::new(FCLayer::new(
        1,
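Net effect of these two hunks: both hidden-layer activations in the example network move from LeakyRelu to the new Gelu variant; the surrounding FCLayer stack is untouched.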
@@ -7,6 +7,7 @@ pub enum Type {
    Tanh,
    Relu,
    LeakyRelu,
+   Gelu,
}

type ActFuncTuple = (
@@ -21,6 +22,7 @@ pub fn parse_type(t: Type) -> ActFuncTuple {
        Type::Tanh => (tanh, tanh_prime),
        Type::Relu => (relu, relu_prime),
        Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
+       Type::Gelu => (gelu, gelu_prime),
    }
}
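With the enum variant and this match arm in place, parse_type(Type::Gelu) hands back the gelu / gelu_prime function pair defined below.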
@@ -83,3 +85,15 @@ pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 });
    result
}
+
+pub fn gelu(matrix: &Array1<f64>) -> Array1<f64> {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| x * sigmoid(1.702 * x));
+    result
+}
+
+pub fn gelu_prime(matrix: &Array1<f64>) -> Array1<f64> {
+    let mut result = matrix.clone();
+    result.mapv_inplace(|x| sigmoid(1.702 * x) * (1.0 + 1.702 * x * (1.0 - sigmoid(1.702 * x))));
+    result
+}
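For reference, a minimal standalone sketch of the approximated GELU and its derivative, cross-checked against central finite differences. Assumptions not shown in this diff: the ndarray crate, and a scalar sigmoid helper behaving like the logistic function (the module's own sigmoid is defined elsewhere in activation_functions).

use ndarray::Array1;

// Stand-in for the module's own scalar sigmoid helper (not shown in
// the diff); assumed to be the logistic function.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Approximated GELU: x * sigmoid(1.702 * x).
pub fn gelu(matrix: &Array1<f64>) -> Array1<f64> {
    let mut result = matrix.clone();
    result.mapv_inplace(|x| x * sigmoid(1.702 * x));
    result
}

// Product rule on x * sigmoid(1.702 * x):
// d/dx = sigmoid(1.702x) * (1 + 1.702 * x * (1 - sigmoid(1.702x))).
pub fn gelu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    let mut result = matrix.clone();
    result.mapv_inplace(|x| {
        let s = sigmoid(1.702 * x);
        s * (1.0 + 1.702 * x * (1.0 - s))
    });
    result
}

fn main() {
    let xs = Array1::from(vec![-2.0, -0.5, 0.0, 0.5, 2.0]);
    let analytic = gelu_prime(&xs);
    // Central finite differences: (f(x + h) - f(x - h)) / 2h.
    let h = 1e-6;
    let numeric = (gelu(&(&xs + h)) - gelu(&(&xs - h))) / (2.0 * h);
    for (a, n) in analytic.iter().zip(numeric.iter()) {
        assert!((a - n).abs() < 1e-5);
    }
    println!("analytic gelu_prime matches finite differences: {}", analytic);
}

With h = 1e-6 the analytic derivative and the numeric estimate agree to well under 1e-5 across the sampled points.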