105 lines
2.3 KiB
Rust
105 lines
2.3 KiB
Rust
use ndarray::Array1;
|
|
use ndarray_rand::rand_distr::num_traits::Pow;
|
|
|
|
/// Activation functions understood by [`parse_type`].
///
/// The type is `Copy`, so a value can still be used after passing it to
/// [`parse_type`], which takes it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// `f(x) = x` (see [`identity`]).
    Identity,
    /// Logistic sigmoid, `f(x) = 1 / (1 + e^(-x))` (see [`logistic`]).
    Logistic,
    /// Hyperbolic tangent, `f(x) = tanh(x)` (see [`tanh`]).
    Tanh,
    /// Rectified linear unit, `f(x) = max(x, 0)` (see [`relu`]).
    Relu,
    /// Leaky ReLU with negative-side slope `0.001`, `f(x) = max(x, 0.001 * x)`
    /// (see [`leaky_relu`]).
    LeakyRelu,
}
|
|
|
|
pub fn parse_type(
|
|
t: Type,
|
|
) -> (
|
|
fn(&Array1<f64>) -> Array1<f64>,
|
|
fn(&Array1<f64>) -> Array1<f64>,
|
|
) {
|
|
match t {
|
|
Type::Identity => (identity, identity_prime),
|
|
Type::Logistic => (logistic, logistic_prime),
|
|
Type::Tanh => (tanh, tanh_prime),
|
|
Type::Relu => (relu, relu_prime),
|
|
Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
|
|
}
|
|
}
|
|
|
|
pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
|
|
matrix.to_owned()
|
|
}
|
|
|
|
/// Derivative of [`identity`]: an array with the same length as `matrix`,
/// filled with `1.0`.
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
    // d/dx x = 1 everywhere, so each input value is ignored.
    matrix.mapv(|_| 1.0)
}
|
|
|
|
/// Scalar logistic sigmoid, `1 / (1 + e^(-x))`.
///
/// For very negative `x`, `e^(-x)` overflows to infinity and the result is
/// exactly `0.0`. For very positive `x` it is exactly `1.0`.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-x);
    1.0 / denominator
}
|
|
|
|
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
|
|
let mut result = matrix.clone();
|
|
for x in result.iter_mut() {
|
|
*x = sigmoid(*x);
|
|
}
|
|
result
|
|
}
|
|
|
|
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
|
let mut result = matrix.clone();
|
|
for x in result.iter_mut() {
|
|
*x = sigmoid(*x * (1.0 - sigmoid(*x)));
|
|
}
|
|
result
|
|
}
|
|
|
|
/// Hyperbolic-tangent activation applied to every element of `matrix`.
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
    // `f64::tanh` is the primitive's method, not this function.
    matrix.mapv(f64::tanh)
}
|
|
|
|
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
|
let mut result = matrix.clone();
|
|
for x in result.iter_mut() {
|
|
*x = 1.0 as f64 - (*x).tanh().pow(2);
|
|
}
|
|
result
|
|
}
|
|
|
|
/// Rectified linear unit, element-wise `max(x, 0)`.
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
    // `f64::max` returns the non-NaN operand, so NaN inputs map to 0.0.
    matrix.mapv(|x| x.max(0.0))
}
|
|
|
|
/// Derivative of [`relu`]: `0.0` where `x <= 0`, `1.0` otherwise.
///
/// NaN fails the `<= 0.0` comparison and therefore maps to `1.0`.
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| if x <= 0.0 { 0.0 } else { 1.0 })
}
|
|
|
|
/// Leaky ReLU with negative-side slope `0.001`: element-wise `max(x, 0.001 * x)`.
///
/// For `x >= 0` this is `x`; for `x < 0` it is `0.001 * x`.
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
    const SLOPE: f64 = 0.001;
    matrix.mapv(|x| x.max(SLOPE * x))
}
|
|
|
|
/// Derivative of [`leaky_relu`]: `0.001` where `x <= 0`, `1.0` otherwise.
///
/// NaN fails the `<= 0.0` comparison and therefore maps to `1.0`.
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    const SLOPE: f64 = 0.001;
    matrix.mapv(|x| if x <= 0.0 { SLOPE } else { 1.0 })
}
|