Optimized backward_pass in FCLayer

This commit is contained in:
lluni 2023-02-04 18:41:35 +01:00
parent 7b12a054d5
commit f2d3d00ce6
Signed by: lluni
GPG key ID: ACEEB468BC325D35

View file

@ -85,13 +85,17 @@ impl Layer for FCLayer {
fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> { fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> {
let input_error = output_error.dot(&self.weights.t()); let input_error = output_error.dot(&self.weights.t());
let delta_weights = self let mut delta_weights = self
.input .input
.to_shape((self.input.len(), 1usize)) .to_shape((self.input.len(), 1usize))
.unwrap() .unwrap()
.dot(&output_error.to_shape((1usize, output_error.len())).unwrap()); .dot(&output_error.to_shape((1usize, output_error.len())).unwrap());
self.weights = &self.weights + learning_rate * &delta_weights; delta_weights *= learning_rate;
self.biases = &self.biases + learning_rate * &output_error; let delta_biases = output_error * learning_rate;
self.weights += &delta_weights;
self.biases += &delta_biases;
input_error input_error
} }
} }