rust-nn/examples/example_xor.rs

61 lines
1.7 KiB
Rust
Raw Normal View History

2023-01-15 23:18:58 +01:00
extern crate rust_nn;
use rust_nn::Network;
use rust_nn::functions::{activation_functions, loss_functions};
use rust_nn::layers::activation_layer::ActivationLayer;
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
use ndarray::array;
/// Trains a small network on the XOR truth table and prints the rounded
/// prediction for each of the four input pairs.
///
/// The network is assembled as: fully connected layer -> tanh -> fully
/// connected layer -> tanh, with mean squared error as the loss.
fn main() {
    // Training data: the four XOR input pairs and their expected outputs.
    let x_train = vec![
        array![0.0, 0.0],
        array![0.0, 1.0],
        array![1.0, 0.0],
        array![1.0, 1.0],
    ];
    let y_train = vec![
        array![0.0],
        array![1.0],
        array![1.0],
        array![0.0],
    ];

    // Test data: the same four input pairs as the training set.
    let x_test = vec![
        array![0.0, 0.0],
        array![0.0, 1.0],
        array![1.0, 0.0],
        array![1.0, 1.0],
    ];

    // Initialize the neural network with MSE loss.
    let mut network = Network::new(loss_functions::Type::MSE);

    // Add layers.
    // NOTE(review): the first argument of `FCLayer::new` is presumably the
    // layer's output size, and the two initializers presumably apply to the
    // weights and the biases respectively — confirm against `FCLayer::new`.
    network.add_layer(Box::new(FCLayer::new(
        3,
        Initializer::Gaussian(0.0, 1.0),
        Initializer::Gaussian(0.0, 1.0),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Tanh,
    )));
    network.add_layer(Box::new(FCLayer::new(
        1,
        Initializer::Gaussian(0.0, 1.0),
        Initializer::Gaussian(0.0, 1.0),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Tanh,
    )));

    // Train the network on the training data.
    // NOTE(review): `1000` and `0.1` are presumably the epoch count and the
    // learning rate, and `false` presumably a verbosity/trivial-option flag —
    // confirm against `Network::fit`.
    network.fit(x_train, y_train, 1000, 0.1, false);

    // `x_test` is cloned so it stays available for printing next to the results.
    let y_test = network.predict(x_test.clone());

    // Print one line per test sample, pairing each input with its prediction.
    for (input, output) in x_test.iter().zip(y_test.iter()) {
        let mut prediction = output.to_owned();
        // Comment out the following line to see the exact predictions.
        prediction.mapv_inplace(|value| value.round());
        println!("input: {}\t\tprediction: {}", input, prediction);
    }
}