Optimized activation functions

This commit is contained in:
lluni 2023-02-04 19:15:00 +01:00
parent f2d3d00ce6
commit a02110f2db
Signed by: lluni
GPG key ID: ACEEB468BC325D35

View file

@ -42,64 +42,48 @@ fn sigmoid(x: f64) -> f64 {
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = sigmoid(*x);
}
result.mapv_inplace(sigmoid);
result
}
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = sigmoid(*x * (1.0 - sigmoid(*x)));
}
result.mapv_inplace(|x| sigmoid(x * (1.0 - sigmoid(x))));
result
}
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = (*x).tanh();
}
result.mapv_inplace(|x| x.tanh());
result
}
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = 1.0 - (*x).tanh().pow(2);
}
result.mapv_inplace(|x| 1.0 - x.tanh().pow(2));
result
}
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = (*x).max(0.0);
}
result.mapv_inplace(|x| x.max(0.0));
result
}
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = if (*x) <= 0.0 { 0.0 } else { 1.0 };
}
result.mapv_inplace(|x| if x <= 0.0 { 0.0 } else { 1.0 });
result
}
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = (*x).max(0.001 * (*x));
}
result.mapv_inplace(|x| x.max(0.001 * x));
result
}
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = if (*x) <= 0.0 { 0.001 } else { 1.0 };
}
result.mapv_inplace(|x| if x <= 0.0 { 0.001 } else { 1.0 });
result
}