Improved backward_pass

lluni committed on 2023-02-04 18:24:19 +01:00
parent f2e54cfac1
commit 98bc599dac
Signed by: lluni (GPG key ID: ACEEB468BC325D35)
4 changed files with 10 additions and 15 deletions

@@ -1,4 +1,4 @@
-use ndarray::{arr1, Array1, ArrayView1};
+use ndarray::{arr1, Array1};
 use super::Layer;
 use crate::functions::activation_functions::*;
@@ -31,7 +31,7 @@ impl Layer for ActivationLayer {
         (self.activation)(&self.input)
     }
 
-    fn backward_pass(&mut self, output_error: ArrayView1<f64>, _learning_rate: f64) -> Array1<f64> {
+    fn backward_pass(&mut self, output_error: Array1<f64>, _learning_rate: f64) -> Array1<f64> {
         (self.activation_prime)(&self.input) * output_error
     }
 }
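
The activation layer's backward pass is just the element-wise chain rule: dE/dx = f'(x) * dE/dy. A minimal standalone sketch of the same computation, assuming tanh as the activation (tanh_prime and the values here are illustrative, not from the repo):

use ndarray::{arr1, Array1};

// Illustrative activation derivative; the real layer stores its own
// `activation_prime` function pointer.
fn tanh_prime(x: &Array1<f64>) -> Array1<f64> {
    x.mapv(|v| 1.0 - v.tanh().powi(2))
}

fn main() {
    let input = arr1(&[0.5, -1.0]);
    let output_error: Array1<f64> = arr1(&[0.1, -0.2]);
    // Element-wise product; taking `output_error` by value lets the
    // multiplication consume it without a clone, as in the new signature.
    let input_error = tanh_prime(&input) * output_error;
    println!("{input_error}");
}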

@@ -1,6 +1,6 @@
 extern crate ndarray;
-use ndarray::{arr1, arr2, Array, Array1, Array2, ArrayView1, ShapeBuilder};
+use ndarray::{arr1, arr2, Array, Array1, Array2, ShapeBuilder};
 use ndarray_rand::rand_distr::{Normal, Uniform};
 use ndarray_rand::RandomExt;
@@ -83,18 +83,13 @@ impl Layer for FCLayer {
         self.input.dot(&self.weights) + &self.biases
     }
 
-    fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
+    fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> {
         let input_error = output_error.dot(&self.weights.t());
         let delta_weights = self
             .input
-            .to_owned()
-            .into_shape((self.input.len(), 1usize))
+            .to_shape((self.input.len(), 1usize))
             .unwrap()
-            .dot(
-                &output_error
-                    .into_shape((1usize, output_error.len()))
-                    .unwrap(),
-            );
+            .dot(&output_error.to_shape((1usize, output_error.len())).unwrap());
         self.weights = &self.weights + learning_rate * &delta_weights;
         self.biases = &self.biases + learning_rate * &output_error;
         input_error
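
The payoff in FCLayer: `to_shape` reshapes a borrowed array into a CowArray that reuses the existing buffer whenever the layout allows, so the old `.to_owned()` clone and the multi-line `.dot(...)` call collapse into one expression. A self-contained sketch of the same outer product, with made-up sizes:

use ndarray::{arr1, Array2};

fn main() {
    let input = arr1(&[1.0, 2.0, 3.0]);    // pretend layer input, shape (3,)
    let output_error = arr1(&[0.5, -0.5]); // pretend output error, shape (2,)

    // (3,) -> (3, 1) and (2,) -> (1, 2), both without copying the data
    let col = input.to_shape((input.len(), 1usize)).unwrap();
    let row = output_error.to_shape((1usize, output_error.len())).unwrap();

    // The (3, 2) outer product, matching delta_weights in backward_pass
    let delta_weights: Array2<f64> = col.dot(&row);
    assert_eq!(delta_weights.shape(), &[3, 2]);
}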

@@ -1,9 +1,9 @@
-use ndarray::{Array1, ArrayView1};
+use ndarray::Array1;
 
 pub mod activation_layer;
 pub mod fc_layer;
 
 pub trait Layer {
     fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
-    fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
+    fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64>;
 }
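
With the trait taking `output_error` by value, an implementor can hand the error on without `.view()` calls or clones. A minimal stand-in implementor (IdentityLayer is hypothetical, only to show the new signature in use):

use ndarray::{arr1, Array1};

pub trait Layer {
    fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
    fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64>;
}

// Hypothetical no-op layer.
struct IdentityLayer;

impl Layer for IdentityLayer {
    fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
        input
    }
    fn backward_pass(&mut self, output_error: Array1<f64>, _learning_rate: f64) -> Array1<f64> {
        // Ownership passes straight through.
        output_error
    }
}

fn main() {
    let mut layer = IdentityLayer;
    let error = layer.backward_pass(arr1(&[0.1, -0.2]), 0.01);
    println!("{error}");
}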

@@ -68,9 +68,9 @@ impl Network {
                 let mut error = (self.loss_prime)(y_train[j].view(), output.view());
                 for layer in self.layers.iter_mut().rev() {
                     if trivial_optimize {
-                        error = layer.backward_pass(error.view(), learning_rate / (i + 1) as f64);
+                        error = layer.backward_pass(error, learning_rate / (i + 1) as f64);
                     } else {
-                        error = layer.backward_pass(error.view(), learning_rate);
+                        error = layer.backward_pass(error, learning_rate);
                     }
                 }
             }
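
In Network's training loop the change makes the backward sweep a clean ownership hand-off: each backward_pass consumes the current error and returns the next one, so no view has to outlive the array it borrows. A sketch of just that flow, with a hypothetical stand-in trait and layer:

use ndarray::{arr1, Array1};

trait Layer {
    fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64>;
}

// Hypothetical layer that hands the error straight back.
struct Passthrough;
impl Layer for Passthrough {
    fn backward_pass(&mut self, output_error: Array1<f64>, _lr: f64) -> Array1<f64> {
        output_error
    }
}

fn main() {
    let mut layers: Vec<Box<dyn Layer>> = vec![Box::new(Passthrough), Box::new(Passthrough)];
    let mut error: Array1<f64> = arr1(&[0.1, -0.2]);
    let learning_rate = 0.01;
    for layer in layers.iter_mut().rev() {
        // `error` is moved in and rebound to the returned input error; the
        // diff's trivial_optimize branch additionally divides learning_rate
        // by (i + 1), with `i` coming from the surrounding epoch loop.
        error = layer.backward_pass(error, learning_rate);
    }
    println!("{error}");
}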