rust-nn/src/layers/activation_layer.rs
2023-01-21 15:19:55 +01:00

39 lines
1.2 KiB
Rust

use ndarray::{arr1, Array1, ArrayView1};
use super::Layer;
use crate::functions::activation_functions::*;
/// Signature shared by an activation function and its derivative: both take the
/// layer's cached vector by reference and produce a freshly allocated vector.
type VectorMap = fn(&Array1<f64>) -> Array1<f64>;

/// A parameter-free layer that applies an activation function element by element
/// (assuming the selected function is element-wise — TODO confirm for every `Type`).
///
/// The layer remembers the vector it last saw in `forward_pass` so the derivative
/// can be evaluated at that point during `backward_pass`.
pub struct ActivationLayer {
    /// Copy of the most recent forward-pass input; empty until the first forward pass.
    input: Array1<f64>,
    /// Result of the most recent forward pass; empty until the first forward pass.
    output: Array1<f64>,
    /// The activation applied in the forward direction.
    activation: VectorMap,
    /// Derivative of `activation`, evaluated at the cached input when propagating errors.
    activation_prime: VectorMap,
}
impl ActivationLayer {
    /// Creates an activation layer for the given activation kind.
    ///
    /// `parse_type` resolves `activation_fn` into the pair
    /// (activation, derivative). Both the input and output caches start as
    /// zero-length arrays and are filled by the first `forward_pass`.
    pub fn new(activation_fn: Type) -> Self {
        let (forward_fn, derivative_fn) = parse_type(activation_fn);
        Self {
            activation: forward_fn,
            activation_prime: derivative_fn,
            input: arr1(&[]),
            output: arr1(&[]),
        }
    }
}
impl Layer for ActivationLayer {
    /// Applies the activation function to `input`.
    ///
    /// The input is copied into the layer so `backward_pass` can later evaluate the
    /// derivative at the same point. The activated result is cached in `self.output`,
    /// and a copy of it is returned to the caller.
    fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
        self.input = input.to_owned();
        self.output = (self.activation)(&self.input);
        self.output.clone()
    }

    /// Propagates `output_error` back through the activation.
    ///
    /// Returns the element-wise product of the activation derivative and
    /// `output_error`. The derivative is evaluated at the input cached by the most
    /// recent `forward_pass`. This layer has no trainable parameters, so
    /// `_learning_rate` is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `output_error` cannot be broadcast to the shape of the derivative
    /// array, i.e. on a length mismatch unless `output_error` has length 1. Calling
    /// this before any `forward_pass` evaluates the derivative on the empty initial
    /// cache, which will typically hit that panic.
    fn backward_pass(
        &mut self,
        output_error: ArrayView1<f64>,
        _learning_rate: f64,
    ) -> Array1<f64> {
        let mut input_error = (self.activation_prime)(&self.input);
        // In-place `*=` reuses the derivative's buffer rather than allocating another
        // array. It broadcasts `output_error` onto the derivative's shape.
        input_error *= &output_error;
        input_error
    }
}