rust-nn/src/functions/activation_functions.rs
2023-01-21 15:19:55 +01:00

105 lines
2.3 KiB
Rust

use ndarray::Array1;
use ndarray_rand::rand_distr::num_traits::Pow;
/// Selects which activation function (and matching derivative) a layer uses.
///
/// Pass a variant to [`parse_type`] to obtain the `(activation, derivative)`
/// function pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// `f(x) = x`.
    Identity,
    /// Logistic sigmoid, `f(x) = 1 / (1 + e^(-x))`.
    Logistic,
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit, `f(x) = max(x, 0)`.
    Relu,
    /// Leaky ReLU: like ReLU but with a small fixed slope for negative inputs.
    LeakyRelu,
}
/// Maps an activation [`Type`] to its `(activation, derivative)` function pair.
///
/// Both returned functions map a 1-D array element-wise to a newly allocated
/// array of the same length.
pub fn parse_type(
    t: Type,
) -> (
    fn(&Array1<f64>) -> Array1<f64>,
    fn(&Array1<f64>) -> Array1<f64>,
) {
    // Every entry in the table shares this signature; naming it keeps the
    // annotation on the destructuring `let` short.
    type ElementWise = fn(&Array1<f64>) -> Array1<f64>;

    let (forward, derivative): (ElementWise, ElementWise) = match t {
        Type::Identity => (identity, identity_prime),
        Type::Logistic => (logistic, logistic_prime),
        Type::Tanh => (tanh, tanh_prime),
        Type::Relu => (relu, relu_prime),
        Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
    };
    (forward, derivative)
}
/// Identity activation: returns an unchanged copy of `matrix`.
pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.clone()
}
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = 1.0;
}
result
}
/// Scalar logistic sigmoid, `σ(x) = 1 / (1 + e^(-x))`.
///
/// For large-magnitude inputs the result saturates to exactly `0.0` or `1.0`
/// rather than producing NaN: `exp` overflows to infinity and `1 / ∞ == 0`.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
let mut result = matrix.clone();
for x in result.iter_mut() {
*x = sigmoid(*x);
}
result
}
/// Derivative of the logistic activation, evaluated element-wise at the
/// pre-activation input `matrix` (the same input the activation receives).
///
/// Uses the identity `σ'(x) = σ(x) · (1 − σ(x))`. For example, at `x = 0` the
/// result is `0.25`.
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| {
        // Compute σ(x) once and reuse it. The previous code wrapped the whole
        // product in another sigmoid call, which is not the derivative.
        let s = sigmoid(x);
        s * (1.0 - s)
    })
}
/// Hyperbolic-tangent activation, applied element-wise.
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(f64::tanh)
}
/// Derivative of the tanh activation, computed element-wise as `1 - tanh(x)^2`.
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| -> f64 {
        let t: f64 = x.tanh();
        1.0_f64 - t.pow(2)
    })
}
/// ReLU activation: clamps every negative element of `matrix` to `0.0`.
///
/// `f64::max` returns the non-NaN operand, so NaN inputs become `0.0`.
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|v| v.max(0.0))
}
/// Derivative of ReLU: `0.0` where the input is `<= 0.0`, otherwise `1.0`.
///
/// NaN fails the `<=` test and therefore maps to `1.0`.
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|v| if v <= 0.0 { 0.0 } else { 1.0 })
}
/// Slope that the leaky ReLU applies to negative inputs.
///
/// Both `leaky_relu` and `leaky_relu_prime` read this constant, so the
/// activation and its derivative can never drift apart.
const LEAKY_RELU_SLOPE: f64 = 0.001;

/// Leaky ReLU activation: `x` for non-negative inputs, `LEAKY_RELU_SLOPE * x`
/// for negative inputs, applied element-wise.
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
    // Relies on 0 < LEAKY_RELU_SLOPE < 1. For x < 0 the scaled value is larger,
    // so `max` picks it; for x >= 0, `max` picks x itself.
    matrix.mapv(|x| x.max(LEAKY_RELU_SLOPE * x))
}

/// Derivative of `leaky_relu`, applied element-wise.
///
/// Inputs `<= 0.0` get `LEAKY_RELU_SLOPE`. At exactly `0.0`, where the function
/// is not differentiable, the left-hand slope is used. All other inputs get
/// `1.0`.
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| if x <= 0.0 { LEAKY_RELU_SLOPE } else { 1.0 })
}