85 lines
2 KiB
Rust
85 lines
2 KiB
Rust
use ndarray::Array1;
|
|
use ndarray_rand::rand_distr::num_traits::Pow;
|
|
|
|
/// Activation function kinds understood by [`parse_type`].
///
/// Derives the common value traits so a `Type` can be copied, compared,
/// hashed (e.g. used as a map key) and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// `f(x) = x`.
    Identity,
    /// Logistic sigmoid `f(x) = 1 / (1 + e^-x)`.
    Logistic,
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit `f(x) = max(x, 0)`.
    Relu,
    /// Leaky ReLU; negative inputs are scaled by a small slope (see [`leaky_relu`]).
    LeakyRelu,
}
|
|
|
|
type ActFuncTuple = (
|
|
fn(&Array1<f64>) -> Array1<f64>,
|
|
fn(&Array1<f64>) -> Array1<f64>,
|
|
);
|
|
|
|
pub fn parse_type(t: Type) -> ActFuncTuple {
|
|
match t {
|
|
Type::Identity => (identity, identity_prime),
|
|
Type::Logistic => (logistic, logistic_prime),
|
|
Type::Tanh => (tanh, tanh_prime),
|
|
Type::Relu => (relu, relu_prime),
|
|
Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
|
|
}
|
|
}
|
|
|
|
/// Identity activation: returns an owned copy of the input unchanged.
pub fn identity(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.clone()
}
|
|
|
|
pub fn identity_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
|
Array1::ones(matrix.raw_dim())
|
|
}
|
|
|
|
/// Scalar logistic sigmoid `1 / (1 + e^-x)`.
///
/// For very negative `x`, `e^-x` overflows to infinity and the result is `0.0`.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-x);
    denominator.recip()
}
|
|
|
|
pub fn logistic(matrix: &Array1<f64>) -> Array1<f64> {
|
|
let mut result = matrix.clone();
|
|
result.mapv_inplace(sigmoid);
|
|
result
|
|
}
|
|
|
|
pub fn logistic_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
|
let mut result = matrix.clone();
|
|
result.mapv_inplace(|x| sigmoid(x * (1.0 - sigmoid(x))));
|
|
result
|
|
}
|
|
|
|
/// Hyperbolic-tangent activation applied element-wise, returning a new array.
pub fn tanh(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(f64::tanh)
}
|
|
|
|
/// Derivative of [`tanh`]: `1 − tanh²(x)` for each input element.
pub fn tanh_prime(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|x| {
        let t = x.tanh();
        1.0 - t.pow(2)
    })
}
|
|
|
|
/// Rectified linear unit: clamps each element to be at least `0.0`.
///
/// Built on `f64::max`, so a NaN element becomes `0.0`.
pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
    matrix.mapv(|v| f64::max(v, 0.0))
}
|
|
|
|
/// Derivative of [`relu`]: `0.0` where the input is `<= 0.0`, otherwise `1.0`.
///
/// The derivative at exactly zero is taken to be `0.0`.
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    let step = |v: f64| if v <= 0.0 { 0.0 } else { 1.0 };
    matrix.mapv(step)
}
|
|
|
|
/// Leaky ReLU: positive inputs pass through, negative inputs are scaled by `0.001`.
pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
    const SLOPE: f64 = 0.001;
    // For x < 0, SLOPE * x > x, so `max` picks the scaled value; otherwise x itself.
    matrix.mapv(|x| x.max(SLOPE * x))
}
|
|
|
|
/// Derivative of [`leaky_relu`]: `0.001` where the input is `<= 0.0`, otherwise `1.0`.
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
    const NEGATIVE_SLOPE: f64 = 0.001;
    matrix.mapv(|x| if x <= 0.0 { NEGATIVE_SLOPE } else { 1.0 })
}
|