rust-nn/examples/example_xor.rs
2023-01-21 15:19:55 +01:00

60 lines
1.7 KiB
Rust

extern crate rust_nn;
use ndarray::array;
use rust_nn::functions::{activation_functions, loss_functions};
use rust_nn::layers::activation_layer::ActivationLayer;
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
use rust_nn::Network;
/// Trains a small network on the XOR truth table and prints the rounded
/// prediction for each of the four possible inputs.
///
/// Layer stack: fully connected (3) -> tanh -> fully connected (1) -> tanh,
/// with mean squared error as the loss. Both fully connected layers are built
/// with `Initializer::Gaussian(0.0, 1.0)` for both initializer arguments.
fn main() {
    // Training data: every input pair of XOR together with its expected output.
    let x_train = vec![
        array![0.0, 0.0],
        array![0.0, 1.0],
        array![1.0, 0.0],
        array![1.0, 1.0],
    ];
    let y_train = vec![array![0.0], array![1.0], array![1.0], array![0.0]];

    // Test data: XOR has only four possible inputs, so the network is evaluated
    // on exactly the inputs it was trained on. The copy is taken here because
    // `fit` below takes `x_train` by value.
    let x_test = x_train.clone();

    // Initialize the neural network with mean squared error as the loss.
    let mut network = Network::new(loss_functions::Type::MSE);

    // Hidden layer (3 units) followed by a tanh activation.
    // NOTE(review): the integer argument of `FCLayer::new` is presumably the
    // layer's output size — confirm against its definition.
    network.add_layer(Box::new(FCLayer::new(
        3,
        Initializer::Gaussian(0.0, 1.0),
        Initializer::Gaussian(0.0, 1.0),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Tanh,
    )));
    // Output layer (1 unit) followed by a tanh activation.
    network.add_layer(Box::new(FCLayer::new(
        1,
        Initializer::Gaussian(0.0, 1.0),
        Initializer::Gaussian(0.0, 1.0),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Tanh,
    )));

    // Train the network on the training data.
    // NOTE(review): `1000` and `0.1` presumably are the epoch count and the
    // learning rate; the trailing `false` flag is defined by `Network::fit`.
    // Confirm both against that signature.
    network.fit(x_train, y_train, 1000, 0.1, false);

    // Print one line per test input together with the network's prediction.
    let y_test = network.predict(x_test.clone());
    for (input, mut prediction) in x_test.iter().zip(y_test) {
        // Comment out the following line to see the exact (unrounded) predictions.
        prediction.mapv_inplace(|v| v.round());
        println!("input: {}\t\tprediction: {}", input, prediction);
    }
}