Optimized backward_pass in FCLayer
This commit is contained in:
parent
7b12a054d5
commit
f2d3d00ce6
1 changed files with 7 additions and 3 deletions
|
@ -85,13 +85,17 @@ impl Layer for FCLayer {
|
||||||
|
|
||||||
fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> {
|
fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> {
|
||||||
let input_error = output_error.dot(&self.weights.t());
|
let input_error = output_error.dot(&self.weights.t());
|
||||||
let delta_weights = self
|
let mut delta_weights = self
|
||||||
.input
|
.input
|
||||||
.to_shape((self.input.len(), 1usize))
|
.to_shape((self.input.len(), 1usize))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.dot(&output_error.to_shape((1usize, output_error.len())).unwrap());
|
.dot(&output_error.to_shape((1usize, output_error.len())).unwrap());
|
||||||
self.weights = &self.weights + learning_rate * &delta_weights;
|
delta_weights *= learning_rate;
|
||||||
self.biases = &self.biases + learning_rate * &output_error;
|
let delta_biases = output_error * learning_rate;
|
||||||
|
|
||||||
|
self.weights += &delta_weights;
|
||||||
|
self.biases += &delta_biases;
|
||||||
|
|
||||||
input_error
|
input_error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue