From f2d3d00ce6760bfd7a16ad8b8ef69b3fb493691b Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 18:41:35 +0100 Subject: [PATCH] Optimized backward_pass in FCLayer --- src/layers/fc_layer.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/layers/fc_layer.rs b/src/layers/fc_layer.rs index 8f4da68..0728a84 100644 --- a/src/layers/fc_layer.rs +++ b/src/layers/fc_layer.rs @@ -85,13 +85,17 @@ impl Layer for FCLayer { fn backward_pass(&mut self, output_error: Array1, learning_rate: f64) -> Array1 { let input_error = output_error.dot(&self.weights.t()); - let delta_weights = self + let mut delta_weights = self .input .to_shape((self.input.len(), 1usize)) .unwrap() .dot(&output_error.to_shape((1usize, output_error.len())).unwrap()); - self.weights = &self.weights + learning_rate * &delta_weights; - self.biases = &self.biases + learning_rate * &output_error; + delta_weights *= learning_rate; + let delta_biases = output_error * learning_rate; + + self.weights += &delta_weights; + self.biases += &delta_biases; + input_error } }