39 lines
1.2 KiB
Rust
39 lines
1.2 KiB
Rust
use ndarray::{arr1, Array1, ArrayView1};
|
|
|
|
use super::Layer;
|
|
use crate::functions::activation_functions::*;
|
|
|
|
pub struct ActivationLayer {
|
|
input: Array1<f64>,
|
|
output: Array1<f64>,
|
|
activation: fn(&Array1<f64>) -> Array1<f64>,
|
|
activation_prime: fn(&Array1<f64>) -> Array1<f64>,
|
|
}
|
|
|
|
impl ActivationLayer {
|
|
pub fn new(activation_fn: Type) -> Self {
|
|
let (activation, activation_prime) = parse_type(activation_fn);
|
|
ActivationLayer {
|
|
input: arr1(&[]),
|
|
output: arr1(&[]),
|
|
activation,
|
|
activation_prime,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Layer for ActivationLayer {
|
|
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
|
self.input = input.to_owned();
|
|
self.output = (self.activation)(&self.input);
|
|
self.output.clone()
|
|
}
|
|
|
|
fn backward_pass(&mut self, output_error: ArrayView1<f64>, _learning_rate: f64) -> Array1<f64> {
|
|
// (self.activation_prime)(&self.input).into_shape((1 as usize, output_error.len() as usize)).unwrap().dot(&output_error)
|
|
// (self.activation_prime)(&self.input) * &output_error
|
|
let mut temp = (self.activation_prime)(&self.input);
|
|
temp.zip_mut_with(&output_error, |x, y| *x *= y);
|
|
temp
|
|
}
|
|
}
|