use ndarray::{arr1, Array1, ArrayView1}; use super::Layer; use crate::functions::activation_functions::*; pub struct ActivationLayer { input: Array1, output: Array1, activation: fn(&Array1) -> Array1, activation_prime: fn(&Array1) -> Array1, } impl ActivationLayer { pub fn new(activation_fn: Type) -> Self { let (activation, activation_prime) = parse_type(activation_fn); ActivationLayer { input: arr1(&[]), output: arr1(&[]), activation, activation_prime, } } } impl Layer for ActivationLayer { fn forward_pass(&mut self, input: ArrayView1) -> Array1 { self.input = input.to_owned(); self.output = (self.activation)(&self.input); self.output.clone() } fn backward_pass(&mut self, output_error: ArrayView1, _learning_rate: f64) -> Array1 { // (self.activation_prime)(&self.input).into_shape((1 as usize, output_error.len() as usize)).unwrap().dot(&output_error) // (self.activation_prime)(&self.input) * &output_error let mut temp = (self.activation_prime)(&self.input); temp.zip_mut_with(&output_error, |x, y| *x *= y); temp } }