Improved forward_pass

This commit is contained in:
lluni 2023-02-04 18:14:06 +01:00
parent d130c7cce1
commit f2e54cfac1
Signed by: lluni
GPG key ID: ACEEB468BC325D35
4 changed files with 9 additions and 11 deletions

View file

@@ -23,8 +23,8 @@ impl ActivationLayer {
 }
 impl Layer for ActivationLayer {
-    fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
+    fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
-        self.input = input.to_owned();
+        self.input = input;
         // output isn't needed elsewhere
         // self.output = (self.activation)(&self.input);
         // self.output.clone()

View file

@@ -71,12 +71,12 @@ impl FCLayer {
 }
 impl Layer for FCLayer {
-    fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
+    fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
         if !self.is_initialized {
             self.initialize(input.len());
         }
-        self.input = input.to_owned();
+        self.input = input;
         // output isn't needed elsewhere
         // self.output = self.input.dot(&self.weights) + &self.biases;
         // self.output.clone()

View file

@@ -4,6 +4,6 @@ pub mod activation_layer;
 pub mod fc_layer;
 pub trait Layer {
-    fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
+    fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
     fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
 }

View file

@@ -30,10 +30,9 @@ impl Network {
         let mut result = vec![];
         for input in inputs.iter() {
-            let mut output = Array1::default(inputs[0].raw_dim());
-            output.assign(input);
+            let mut output = input.to_owned();
             for layer in &mut self.layers {
-                output = layer.forward_pass(output.view());
+                output = layer.forward_pass(output);
             }
             result.push(output);
         }
@@ -57,10 +56,9 @@ impl Network {
         let mut err = 0.0;
         for j in 0..num_samples {
             // forward propagation
-            let mut output = Array1::default(x_train[0].raw_dim());
-            output.assign(&x_train[j]);
+            let mut output = x_train[j].to_owned();
             for layer in self.layers.iter_mut() {
-                output = layer.forward_pass(output.view());
+                output = layer.forward_pass(output);
             }
             // compute loss