From e24b05f4bc01581bf7129299a24efaa8409a4ca6 Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 15:42:55 +0100 Subject: [PATCH] Don't let Network consume training nor test data --- examples/example_sine.rs | 4 ++-- examples/example_xor.rs | 4 ++-- src/lib.rs | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/example_sine.rs b/examples/example_sine.rs index 18d6945..09de34c 100644 --- a/examples/example_sine.rs +++ b/examples/example_sine.rs @@ -66,10 +66,10 @@ fn main() -> Result<(), Box> { ))); // train network on training data - network.fit(x_train, y_train, 100, 0.05, true); + network.fit(&x_train, &y_train, 100, 0.05, true); // predict test dataset - let y_test_pred = network.predict(x_test.clone()); + let y_test_pred = network.predict(&x_test); // create the chart let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area(); diff --git a/examples/example_xor.rs b/examples/example_xor.rs index 88863e4..6939bb8 100644 --- a/examples/example_xor.rs +++ b/examples/example_xor.rs @@ -45,10 +45,10 @@ fn main() { ))); // train network on training data - network.fit(x_train, y_train, 1000, 0.1, false); + network.fit(&x_train, &y_train, 1000, 0.1, false); // print predictions - let y_test = network.predict(x_test.clone()); + let y_test = network.predict(&x_test); println!("{}", x_test.get(0).unwrap()); for i in 0..y_test.len() { print!("input: {}\t\t", x_test.get(i).unwrap()); diff --git a/src/lib.rs b/src/lib.rs index e9e06b4..bbff9d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ impl Network { self.layers.push(layer); } - pub fn predict(&mut self, inputs: Vec>) -> Vec> { + pub fn predict(&mut self, inputs: &[Array1]) -> Vec> { assert!(!inputs.is_empty()); let mut result = vec![]; @@ -35,7 +35,7 @@ impl Network { for layer in &mut self.layers { output = layer.forward_pass(output.view()); } - result.push(output.to_owned()); + result.push(output); } result @@ -43,8 +43,8 @@ impl Network { pub fn fit( &mut self, - x_train: Vec>, - y_train: Vec>, + x_train: &[Array1], + y_train: &[Array1], epochs: usize, learning_rate: f64, trivial_optimize: bool,