From 98bc599dac1614b5aae537c678e1b8ab3633fa7f Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 18:24:19 +0100 Subject: [PATCH] Improved backward_pass --- src/layers/activation_layer.rs | 4 ++-- src/layers/fc_layer.rs | 13 ++++--------- src/layers/mod.rs | 4 ++-- src/lib.rs | 4 ++-- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/layers/activation_layer.rs b/src/layers/activation_layer.rs index a3c47f6..48e2502 100644 --- a/src/layers/activation_layer.rs +++ b/src/layers/activation_layer.rs @@ -1,4 +1,4 @@ -use ndarray::{arr1, Array1, ArrayView1}; +use ndarray::{arr1, Array1}; use super::Layer; use crate::functions::activation_functions::*; @@ -31,7 +31,7 @@ impl Layer for ActivationLayer { (self.activation)(&self.input) } - fn backward_pass(&mut self, output_error: ArrayView1, _learning_rate: f64) -> Array1 { + fn backward_pass(&mut self, output_error: Array1, _learning_rate: f64) -> Array1 { (self.activation_prime)(&self.input) * output_error } } diff --git a/src/layers/fc_layer.rs b/src/layers/fc_layer.rs index 9fb9848..8f4da68 100644 --- a/src/layers/fc_layer.rs +++ b/src/layers/fc_layer.rs @@ -1,6 +1,6 @@ extern crate ndarray; -use ndarray::{arr1, arr2, Array, Array1, Array2, ArrayView1, ShapeBuilder}; +use ndarray::{arr1, arr2, Array, Array1, Array2, ShapeBuilder}; use ndarray_rand::rand_distr::{Normal, Uniform}; use ndarray_rand::RandomExt; @@ -83,18 +83,13 @@ impl Layer for FCLayer { self.input.dot(&self.weights) + &self.biases } - fn backward_pass(&mut self, output_error: ArrayView1, learning_rate: f64) -> Array1 { + fn backward_pass(&mut self, output_error: Array1, learning_rate: f64) -> Array1 { let input_error = output_error.dot(&self.weights.t()); let delta_weights = self .input - .to_owned() - .into_shape((self.input.len(), 1usize)) + .to_shape((self.input.len(), 1usize)) .unwrap() - .dot( - &output_error - .into_shape((1usize, output_error.len())) - .unwrap(), - ); + .dot(&output_error.to_shape((1usize, output_error.len())).unwrap()); self.weights = &self.weights + learning_rate * &delta_weights; self.biases = &self.biases + learning_rate * &output_error; input_error diff --git a/src/layers/mod.rs b/src/layers/mod.rs index 8a3d5c0..cbcfb28 100644 --- a/src/layers/mod.rs +++ b/src/layers/mod.rs @@ -1,9 +1,9 @@ -use ndarray::{Array1, ArrayView1}; +use ndarray::Array1; pub mod activation_layer; pub mod fc_layer; pub trait Layer { fn forward_pass(&mut self, input: Array1) -> Array1; - fn backward_pass(&mut self, output_error: ArrayView1, learning_rate: f64) -> Array1; + fn backward_pass(&mut self, output_error: Array1, learning_rate: f64) -> Array1; } diff --git a/src/lib.rs b/src/lib.rs index 26520ec..aecc97a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,9 +68,9 @@ impl Network { let mut error = (self.loss_prime)(y_train[j].view(), output.view()); for layer in self.layers.iter_mut().rev() { if trivial_optimize { - error = layer.backward_pass(error.view(), learning_rate / (i + 1) as f64); + error = layer.backward_pass(error, learning_rate / (i + 1) as f64); } else { - error = layer.backward_pass(error.view(), learning_rate); + error = layer.backward_pass(error, learning_rate); } } }