Don't let Network consume training nor test data
parent a8270914e0
commit e24b05f4bc

3 changed files with 8 additions and 8 deletions
@@ -66,10 +66,10 @@ fn main() -> Result<(), Box<dyn Error>> {
     )));
 
     // train network on training data
-    network.fit(x_train, y_train, 100, 0.05, true);
+    network.fit(&x_train, &y_train, 100, 0.05, true);
 
     // predict test dataset
-    let y_test_pred = network.predict(x_test.clone());
+    let y_test_pred = network.predict(&x_test);
 
     // create the chart
     let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area();

@@ -45,10 +45,10 @@ fn main() {
     )));
 
     // train network on training data
-    network.fit(x_train, y_train, 1000, 0.1, false);
+    network.fit(&x_train, &y_train, 1000, 0.1, false);
 
     // print predictions
-    let y_test = network.predict(x_test.clone());
+    let y_test = network.predict(&x_test);
     println!("{}", x_test.get(0).unwrap());
     for i in 0..y_test.len() {
         print!("input: {}\t\t", x_test.get(i).unwrap());

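Both example call sites compile against the new slice parameters without any explicit conversion, because a &Vec<Array1<f64>> argument deref-coerces to &[Array1<f64>] at the call site. A minimal, self-contained sketch of that coercion follows; the takes_slice helper and the i32 element type are illustrative stand-ins, not code from this repository.

fn takes_slice(xs: &[i32]) -> usize {
    // Only borrows the data; the caller keeps ownership.
    xs.len()
}

fn main() {
    let v: Vec<i32> = vec![1, 2, 3];
    // &v is a &Vec<i32>, which deref-coerces to &[i32] here,
    // just as &x_train coerces to &[Array1<f64>] in the examples above.
    let n = takes_slice(&v);
    // v is still usable afterwards because it was never moved.
    assert_eq!(n, v.len());
}
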
@@ -25,7 +25,7 @@ impl Network {
         self.layers.push(layer);
     }
 
-    pub fn predict(&mut self, inputs: Vec<Array1<f64>>) -> Vec<Array1<f64>> {
+    pub fn predict(&mut self, inputs: &[Array1<f64>]) -> Vec<Array1<f64>> {
         assert!(!inputs.is_empty());
         let mut result = vec![];
 

@@ -35,7 +35,7 @@ impl Network {
             for layer in &mut self.layers {
                 output = layer.forward_pass(output.view());
             }
-            result.push(output.to_owned());
+            result.push(output);
         }
 
         result

@@ -43,8 +43,8 @@ impl Network {
 
     pub fn fit(
         &mut self,
-        x_train: Vec<Array1<f64>>,
-        y_train: Vec<Array1<f64>>,
+        x_train: &[Array1<f64>],
+        y_train: &[Array1<f64>],
         epochs: usize,
         learning_rate: f64,
         trivial_optimize: bool,
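With fit and predict borrowing slices instead of taking Vec by value, callers retain ownership of the training and test data and no longer need x_test.clone() to keep using it afterwards. A rough sketch of that ownership pattern, using a hypothetical stand-in Model type rather than the real Network (the f64 samples and the doubling "prediction" are placeholders, not the library's behaviour):

// Hypothetical stand-in for Network, mirroring only the borrowing pattern
// introduced by this commit (slice parameters instead of owned Vecs).
struct Model;

impl Model {
    fn fit(&mut self, x_train: &[f64], y_train: &[f64], epochs: usize) {
        // Training would happen here; the slices are only borrowed.
        let _ = (x_train.len(), y_train.len(), epochs);
    }

    fn predict(&mut self, inputs: &[f64]) -> Vec<f64> {
        // Placeholder "prediction": double each input.
        inputs.iter().map(|x| *x * 2.0).collect()
    }
}

fn main() {
    let x_train = vec![0.0, 1.0, 2.0];
    let y_train = vec![0.0, 2.0, 4.0];
    let x_test = vec![3.0, 4.0];

    let mut model = Model;
    model.fit(&x_train, &y_train, 100);
    let y_pred = model.predict(&x_test);

    // x_train, y_train and x_test are all still owned by the caller:
    // no .clone() is needed to print or reuse them after training.
    println!("trained on {} samples", x_train.len());
    println!("inputs {:?} -> predictions {:?}", x_test, y_pred);
}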