use ndarray::Array1;
use ndarray_rand::rand_distr::num_traits::Pow;

/// Slope applied to non-positive inputs by [`leaky_relu`] and [`leaky_relu_prime`].
const LEAKY_RELU_SLOPE: f64 = 0.001;

/// Signature shared by every activation function and derivative in this module:
/// an element-wise map from an input vector to a freshly allocated output vector.
pub type ActivationFn = fn(&Array1<f64>) -> Array1<f64>;

/// Activation functions selectable through [`parse_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    /// `f(x) = x`
    Identity,
    /// `f(x) = 1 / (1 + e^-x)`
    Logistic,
    /// `f(x) = tanh(x)`
    Tanh,
    /// `f(x) = max(x, 0)`
    Relu,
    /// `f(x) = x` for `x > 0`, `LEAKY_RELU_SLOPE * x` otherwise.
    LeakyRelu,
}

/// Returns the `(activation, derivative)` function pair for the activation `t`.
///
/// Each derivative is evaluated on the same input `x` that is passed to the
/// activation (not on the activation's output).
pub fn parse_type(t: Type) -> (ActivationFn, ActivationFn) {
    match t {
        Type::Identity => (identity, identity_prime),
        Type::Logistic => (logistic, logistic_prime),
        Type::Tanh => (tanh, tanh_prime),
        Type::Relu => (relu, relu_prime),
        Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
    }
}

/// Identity activation: returns an owned copy of `matrix`.
pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.to_owned()
}

/// Derivative of [`identity`]: a vector of ones with the same length as `matrix`.
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|_| 1.0)
}

/// Scalar logistic sigmoid `1 / (1 + e^-x)`.
///
/// For large negative `x`, `(-x).exp()` overflows to `inf` and the result is
/// `0.0` rather than NaN, so no special-casing is needed.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Logistic (sigmoid) activation applied element-wise.
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(sigmoid)
}

/// Derivative of [`logistic`] applied element-wise:
/// `sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))`.
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| {
        // Evaluate the sigmoid once and reuse it for both factors.
        let s = sigmoid(x);
        s * (1.0 - s)
    })
}

/// Hyperbolic tangent activation applied element-wise.
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(f64::tanh)
}

/// Derivative of [`tanh`] applied element-wise: `1 - tanh(x)^2`.
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
    // Explicit `i32` exponent selects the `Pow<i32> for f64` impl unambiguously.
    matrix.mapv(|x| 1.0 - x.tanh().pow(2_i32))
}

/// Rectified linear unit applied element-wise: `max(x, 0)`.
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| x.max(0.0))
}

/// Derivative of [`relu`] applied element-wise.
///
/// Returns `0.0` for `x <= 0` (the derivative at exactly 0 is taken as 0) and
/// `1.0` otherwise.
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| if x <= 0.0 { 0.0 } else { 1.0 })
}

/// Leaky ReLU applied element-wise: `x` for positive inputs,
/// `LEAKY_RELU_SLOPE * x` otherwise.
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
    // With a slope in (0, 1), `max(x, slope * x)` picks `x` when x > 0
    // and `slope * x` when x <= 0.
    matrix.mapv(|x| x.max(LEAKY_RELU_SLOPE * x))
}

/// Derivative of [`leaky_relu`] applied element-wise.
///
/// Returns `LEAKY_RELU_SLOPE` for `x <= 0` and `1.0` otherwise.
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| if x <= 0.0 { LEAKY_RELU_SLOPE } else { 1.0 })
}