Don't let Network consume training nor test data
This commit is contained in:
parent
a8270914e0
commit
e24b05f4bc
3 changed files with 8 additions and 8 deletions
|
@ -66,10 +66,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
)));
|
||||
|
||||
// train network on training data
|
||||
network.fit(x_train, y_train, 100, 0.05, true);
|
||||
network.fit(&x_train, &y_train, 100, 0.05, true);
|
||||
|
||||
// predict test dataset
|
||||
let y_test_pred = network.predict(x_test.clone());
|
||||
let y_test_pred = network.predict(&x_test);
|
||||
|
||||
// create the chart
|
||||
let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area();
|
||||
|
|
|
@ -45,10 +45,10 @@ fn main() {
|
|||
)));
|
||||
|
||||
// train network on training data
|
||||
network.fit(x_train, y_train, 1000, 0.1, false);
|
||||
network.fit(&x_train, &y_train, 1000, 0.1, false);
|
||||
|
||||
// print predictions
|
||||
let y_test = network.predict(x_test.clone());
|
||||
let y_test = network.predict(&x_test);
|
||||
println!("{}", x_test.get(0).unwrap());
|
||||
for i in 0..y_test.len() {
|
||||
print!("input: {}\t\t", x_test.get(i).unwrap());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue