rust-nn/examples/example_xor.rs

61 lines
1.7 KiB
Rust
Raw Permalink Normal View History

2023-01-15 23:18:58 +01:00
extern crate rust_nn;
2023-01-21 15:19:55 +01:00
use ndarray::array;
2023-01-15 23:18:58 +01:00
use rust_nn::functions::{activation_functions, loss_functions};
use rust_nn::layers::activation_layer::ActivationLayer;
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
2023-01-21 15:19:55 +01:00
use rust_nn::Network;
2023-01-15 23:18:58 +01:00
fn main() {
// training data
let x_train = vec![
array![0.0, 0.0],
array![0.0, 1.0],
array![1.0, 0.0],
2023-01-21 15:19:55 +01:00
array![1.0, 1.0],
2023-01-15 23:18:58 +01:00
];
2023-01-21 15:19:55 +01:00
let y_train = vec![array![0.0], array![1.0], array![1.0], array![0.0]];
2023-01-15 23:18:58 +01:00
// test data
2023-01-21 15:19:55 +01:00
let x_test = vec![
2023-01-15 23:18:58 +01:00
array![0.0, 0.0],
array![0.0, 1.0],
array![1.0, 0.0],
2023-01-21 15:19:55 +01:00
array![1.0, 1.0],
2023-01-15 23:18:58 +01:00
];
// initialize neural network
let mut network = Network::new(loss_functions::Type::MSE);
// add layers
network.add_layer(Box::new(FCLayer::new(
3,
Initializer::Gaussian(0.0, 1.0),
2023-01-21 15:19:55 +01:00
Initializer::Gaussian(0.0, 1.0),
)));
network.add_layer(Box::new(ActivationLayer::new(
activation_functions::Type::Tanh,
2023-01-15 23:18:58 +01:00
)));
network.add_layer(Box::new(FCLayer::new(
1,
Initializer::Gaussian(0.0, 1.0),
2023-01-21 15:19:55 +01:00
Initializer::Gaussian(0.0, 1.0),
)));
network.add_layer(Box::new(ActivationLayer::new(
activation_functions::Type::Tanh,
2023-01-15 23:18:58 +01:00
)));
// train network on training data
network.fit(&x_train, &y_train, 1000, 0.1, false);
2023-01-15 23:18:58 +01:00
// print predictions
let y_test = network.predict(&x_test);
2023-01-15 23:18:58 +01:00
println!("{}", x_test.get(0).unwrap());
for i in 0..y_test.len() {
print!("input: {}\t\t", x_test.get(i).unwrap());
let mut prediction = y_test.get(i).unwrap().to_owned();
// comment the following line to see the exact predictions
prediction.map_mut(|x| *x = x.round());
2023-02-01 16:10:56 +01:00
println!("prediction: {prediction}");
2023-01-15 23:18:58 +01:00
}
2023-01-21 15:19:55 +01:00
}