From f2e54cfac124d0af7ff32388b3d9325989d23476 Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 18:14:06 +0100 Subject: [PATCH] Improved forward_pass --- src/layers/activation_layer.rs | 4 ++-- src/layers/fc_layer.rs | 4 ++-- src/layers/mod.rs | 2 +- src/lib.rs | 10 ++++------ 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/layers/activation_layer.rs b/src/layers/activation_layer.rs index 47a0c03..a3c47f6 100644 --- a/src/layers/activation_layer.rs +++ b/src/layers/activation_layer.rs @@ -23,8 +23,8 @@ impl ActivationLayer { } impl Layer for ActivationLayer { - fn forward_pass(&mut self, input: ArrayView1) -> Array1 { - self.input = input.to_owned(); + fn forward_pass(&mut self, input: Array1) -> Array1 { + self.input = input; // output isn't needed elsewhere // self.output = (self.activation)(&self.input); // self.output.clone() diff --git a/src/layers/fc_layer.rs b/src/layers/fc_layer.rs index af0330d..9fb9848 100644 --- a/src/layers/fc_layer.rs +++ b/src/layers/fc_layer.rs @@ -71,12 +71,12 @@ impl FCLayer { } impl Layer for FCLayer { - fn forward_pass(&mut self, input: ArrayView1) -> Array1 { + fn forward_pass(&mut self, input: Array1) -> Array1 { if !self.is_initialized { self.initialize(input.len()); } - self.input = input.to_owned(); + self.input = input; // output isn't needed elsewhere // self.output = self.input.dot(&self.weights) + &self.biases; // self.output.clone() diff --git a/src/layers/mod.rs b/src/layers/mod.rs index 53c3eee..8a3d5c0 100644 --- a/src/layers/mod.rs +++ b/src/layers/mod.rs @@ -4,6 +4,6 @@ pub mod activation_layer; pub mod fc_layer; pub trait Layer { - fn forward_pass(&mut self, input: ArrayView1) -> Array1; + fn forward_pass(&mut self, input: Array1) -> Array1; fn backward_pass(&mut self, output_error: ArrayView1, learning_rate: f64) -> Array1; } diff --git a/src/lib.rs b/src/lib.rs index bbff9d5..26520ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,10 +30,9 @@ impl Network { let mut result = vec![]; for input in inputs.iter() { - let mut output = Array1::default(inputs[0].raw_dim()); - output.assign(input); + let mut output = input.to_owned(); for layer in &mut self.layers { - output = layer.forward_pass(output.view()); + output = layer.forward_pass(output); } result.push(output); } @@ -57,10 +56,9 @@ impl Network { let mut err = 0.0; for j in 0..num_samples { // forward propagation - let mut output = Array1::default(x_train[0].raw_dim()); - output.assign(&x_train[j]); + let mut output = x_train[j].to_owned(); for layer in self.layers.iter_mut() { - output = layer.forward_pass(output.view()); + output = layer.forward_pass(output); } // compute loss