From d130c7cce11b487b566028b70ec8c3a87a5041aa Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 17:56:30 +0100 Subject: [PATCH] Layers don't need to store the forward pass output --- src/layers/activation_layer.rs | 10 ++++++---- src/layers/fc_layer.rs | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/layers/activation_layer.rs b/src/layers/activation_layer.rs index 8a271ed..47a0c03 100644 --- a/src/layers/activation_layer.rs +++ b/src/layers/activation_layer.rs @@ -5,7 +5,7 @@ use crate::functions::activation_functions::*; pub struct ActivationLayer { input: Array1, - output: Array1, + // output: Array1, activation: fn(&Array1) -> Array1, activation_prime: fn(&Array1) -> Array1, } @@ -15,7 +15,7 @@ impl ActivationLayer { let (activation, activation_prime) = parse_type(activation_fn); ActivationLayer { input: arr1(&[]), - output: arr1(&[]), + // output: arr1(&[]), activation, activation_prime, } @@ -25,8 +25,10 @@ impl ActivationLayer { impl Layer for ActivationLayer { fn forward_pass(&mut self, input: ArrayView1) -> Array1 { self.input = input.to_owned(); - self.output = (self.activation)(&self.input); - self.output.clone() + // output isn't needed elsewhere + // self.output = (self.activation)(&self.input); + // self.output.clone() + (self.activation)(&self.input) } fn backward_pass(&mut self, output_error: ArrayView1, _learning_rate: f64) -> Array1 { diff --git a/src/layers/fc_layer.rs b/src/layers/fc_layer.rs index 1aa0202..af0330d 100644 --- a/src/layers/fc_layer.rs +++ b/src/layers/fc_layer.rs @@ -40,7 +40,7 @@ pub struct FCLayer { weight_initializer: Initializer, bias_initializer: Initializer, input: Array1, - output: Array1, + // output: Array1, weights: Array2, biases: Array1, } @@ -57,7 +57,7 @@ impl FCLayer { weight_initializer, bias_initializer, input: arr1(&[]), - output: arr1(&[]), + // output: arr1(&[]), weights: arr2(&[[]]), biases: arr1(&[]), } @@ -77,8 +77,10 @@ impl Layer for FCLayer { } self.input = input.to_owned(); - self.output = self.input.dot(&self.weights) + &self.biases; - self.output.clone() + // output isn't needed elsewhere + // self.output = self.input.dot(&self.weights) + &self.biases; + // self.output.clone() + self.input.dot(&self.weights) + &self.biases } fn backward_pass(&mut self, output_error: ArrayView1, learning_rate: f64) -> Array1 {