rust-nn/src/functions/activation_functions.rs

85 lines
2 KiB
Rust

use ndarray::Array1;
use ndarray_rand::rand_distr::num_traits::Pow;
/// Selects which activation function (and matching derivative) a layer uses.
///
/// Pass a variant to [`parse_type`] to obtain the `(activation, derivative)`
/// function pair. The type is `Copy`, so a variant can be reused after being
/// handed to `parse_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// `f(x) = x`; derivative is 1 everywhere.
    Identity,
    /// Logistic sigmoid `f(x) = 1 / (1 + e^(-x))`.
    Logistic,
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit `f(x) = max(x, 0)`.
    Relu,
    /// Leaky ReLU with a fixed negative-side slope of `0.001`.
    LeakyRelu,
}
/// An `(activation, derivative)` pair of element-wise vector functions.
///
/// Both entries borrow the input vector and return a newly owned vector.
type ActFuncTuple = (fn(&Array1<f64>) -> Array1<f64>, fn(&Array1<f64>) -> Array1<f64>);
/// Maps an activation [`Type`] to its function pair.
///
/// The first element of the returned tuple is the activation itself. The
/// second is the matching `*_prime` function defined in this module.
pub fn parse_type(t: Type) -> ActFuncTuple {
    match t {
        Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
        Type::Relu => (relu, relu_prime),
        Type::Tanh => (tanh, tanh_prime),
        Type::Logistic => (logistic, logistic_prime),
        Type::Identity => (identity, identity_prime),
    }
}
/// Identity activation: returns an owned copy of `matrix` with every element unchanged.
pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.clone()
}
/// Derivative of [`identity`]: a vector of `1.0` with the same length as `matrix`.
///
/// The input values themselves are never read; only the length matters.
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
    let len = matrix.len();
    Array1::from_elem(len, 1.0)
}
/// Scalar logistic sigmoid, `1 / (1 + e^(-x))`.
///
/// For very negative `x`, `e^(-x)` overflows to infinity and the result is
/// `0.0`. For `x = +inf` the result is `1.0`. NaN propagates.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
/// Logistic (sigmoid) activation applied element-wise.
///
/// Returns a new vector; `matrix` is left untouched.
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(sigmoid)
}
/// Derivative of [`logistic`], applied element-wise.
///
/// Uses the identity `sigma'(x) = sigma(x) * (1 - sigma(x))`. For example,
/// the value at `x = 0` is `0.25`. Returns a new vector; `matrix` is left
/// untouched.
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
    let mut result = matrix.clone();
    result.mapv_inplace(|x| {
        // Evaluate the sigmoid once and reuse it for both factors. The earlier
        // form `sigmoid(x * (1 - sigmoid(x)))` was not the derivative.
        let s = sigmoid(x);
        s * (1.0 - s)
    });
    result
}
/// Hyperbolic-tangent activation applied element-wise.
///
/// Returns a new vector; `matrix` is left untouched.
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(f64::tanh)
}
/// Derivative of [`tanh`], `1 - tanh(x)^2`, applied element-wise.
///
/// Returns a new vector; `matrix` is left untouched.
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|v| {
        let t = v.tanh();
        // `pow` comes from the file-level `num_traits::Pow` import.
        1.0 - t.pow(2)
    })
}
/// Rectified linear unit applied element-wise: `max(x, 0)`.
///
/// Returns a new vector; `matrix` is left untouched.
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
    // Keep the operand order `v.max(0.0)` so signed-zero/NaN handling is unchanged.
    matrix.mapv(|v| v.max(0.0))
}
/// Derivative of [`relu`], applied element-wise.
///
/// Produces `0.0` where the input is `<= 0` (including exactly zero) and
/// `1.0` otherwise. A NaN input fails the `<=` test and therefore maps to
/// `1.0`.
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|v| if v <= 0.0 { 0.0 } else { 1.0 })
}
/// Leaky ReLU applied element-wise: `x` for non-negative inputs, `0.001 * x` otherwise.
///
/// Returns a new vector; `matrix` is left untouched.
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
    const NEGATIVE_SLOPE: f64 = 0.001;
    // For x < 0, `NEGATIVE_SLOPE * x` exceeds `x`, so `max` picks the scaled value.
    matrix.mapv(|v| v.max(NEGATIVE_SLOPE * v))
}
/// Derivative of [`leaky_relu`], applied element-wise.
///
/// Produces `0.001` where the input is `<= 0` and `1.0` otherwise. A NaN
/// input fails the `<=` test and therefore maps to `1.0`.
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    const NEGATIVE_SLOPE: f64 = 0.001;
    matrix.mapv(|v| if v <= 0.0 { NEGATIVE_SLOPE } else { 1.0 })
}