Don't let Network consume training nor test data

This commit is contained in:
lluni 2023-02-04 15:42:55 +01:00
parent a8270914e0
commit e24b05f4bc
Signed by: lluni
GPG key ID: ACEEB468BC325D35
3 changed files with 8 additions and 8 deletions

View file

@ -66,10 +66,10 @@ fn main() -> Result<(), Box<dyn Error>> {
))); )));
// train network on training data // train network on training data
network.fit(x_train, y_train, 100, 0.05, true); network.fit(&x_train, &y_train, 100, 0.05, true);
// predict test dataset // predict test dataset
let y_test_pred = network.predict(x_test.clone()); let y_test_pred = network.predict(&x_test);
// create the chart // create the chart
let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area(); let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area();

View file

@ -45,10 +45,10 @@ fn main() {
))); )));
// train network on training data // train network on training data
network.fit(x_train, y_train, 1000, 0.1, false); network.fit(&x_train, &y_train, 1000, 0.1, false);
// print predictions // print predictions
let y_test = network.predict(x_test.clone()); let y_test = network.predict(&x_test);
println!("{}", x_test.get(0).unwrap()); println!("{}", x_test.get(0).unwrap());
for i in 0..y_test.len() { for i in 0..y_test.len() {
print!("input: {}\t\t", x_test.get(i).unwrap()); print!("input: {}\t\t", x_test.get(i).unwrap());

View file

@ -25,7 +25,7 @@ impl Network {
self.layers.push(layer); self.layers.push(layer);
} }
pub fn predict(&mut self, inputs: Vec<Array1<f64>>) -> Vec<Array1<f64>> { pub fn predict(&mut self, inputs: &[Array1<f64>]) -> Vec<Array1<f64>> {
assert!(!inputs.is_empty()); assert!(!inputs.is_empty());
let mut result = vec![]; let mut result = vec![];
@ -35,7 +35,7 @@ impl Network {
for layer in &mut self.layers { for layer in &mut self.layers {
output = layer.forward_pass(output.view()); output = layer.forward_pass(output.view());
} }
result.push(output.to_owned()); result.push(output);
} }
result result
@ -43,8 +43,8 @@ impl Network {
pub fn fit( pub fn fit(
&mut self, &mut self,
x_train: Vec<Array1<f64>>, x_train: &[Array1<f64>],
y_train: Vec<Array1<f64>>, y_train: &[Array1<f64>],
epochs: usize, epochs: usize,
learning_rate: f64, learning_rate: f64,
trivial_optimize: bool, trivial_optimize: bool,