2023-01-15 23:18:58 +01:00
|
|
|
extern crate rust_nn;
|
|
|
|
|
2023-01-21 15:19:55 +01:00
|
|
|
use ndarray::array;
|
2023-01-15 23:18:58 +01:00
|
|
|
use rust_nn::functions::{activation_functions, loss_functions};
|
|
|
|
use rust_nn::layers::activation_layer::ActivationLayer;
|
|
|
|
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
2023-01-21 15:19:55 +01:00
|
|
|
use rust_nn::Network;
|
2023-01-15 23:18:58 +01:00
|
|
|
|
|
|
|
fn main() {
|
|
|
|
// training data
|
|
|
|
let x_train = vec![
|
|
|
|
array![0.0, 0.0],
|
|
|
|
array![0.0, 1.0],
|
|
|
|
array![1.0, 0.0],
|
2023-01-21 15:19:55 +01:00
|
|
|
array![1.0, 1.0],
|
2023-01-15 23:18:58 +01:00
|
|
|
];
|
2023-01-21 15:19:55 +01:00
|
|
|
let y_train = vec![array![0.0], array![1.0], array![1.0], array![0.0]];
|
2023-01-15 23:18:58 +01:00
|
|
|
// test data
|
2023-01-21 15:19:55 +01:00
|
|
|
let x_test = vec![
|
2023-01-15 23:18:58 +01:00
|
|
|
array![0.0, 0.0],
|
|
|
|
array![0.0, 1.0],
|
|
|
|
array![1.0, 0.0],
|
2023-01-21 15:19:55 +01:00
|
|
|
array![1.0, 1.0],
|
2023-01-15 23:18:58 +01:00
|
|
|
];
|
|
|
|
|
|
|
|
// initialize neural network
|
|
|
|
let mut network = Network::new(loss_functions::Type::MSE);
|
|
|
|
|
|
|
|
// add layers
|
|
|
|
network.add_layer(Box::new(FCLayer::new(
|
|
|
|
3,
|
|
|
|
Initializer::Gaussian(0.0, 1.0),
|
2023-01-21 15:19:55 +01:00
|
|
|
Initializer::Gaussian(0.0, 1.0),
|
|
|
|
)));
|
|
|
|
network.add_layer(Box::new(ActivationLayer::new(
|
|
|
|
activation_functions::Type::Tanh,
|
2023-01-15 23:18:58 +01:00
|
|
|
)));
|
|
|
|
network.add_layer(Box::new(FCLayer::new(
|
|
|
|
1,
|
|
|
|
Initializer::Gaussian(0.0, 1.0),
|
2023-01-21 15:19:55 +01:00
|
|
|
Initializer::Gaussian(0.0, 1.0),
|
|
|
|
)));
|
|
|
|
network.add_layer(Box::new(ActivationLayer::new(
|
|
|
|
activation_functions::Type::Tanh,
|
2023-01-15 23:18:58 +01:00
|
|
|
)));
|
|
|
|
|
|
|
|
// train network on training data
|
|
|
|
network.fit(x_train, y_train, 1000, 0.1, false);
|
|
|
|
|
|
|
|
// print predictions
|
|
|
|
let y_test = network.predict(x_test.clone());
|
|
|
|
println!("{}", x_test.get(0).unwrap());
|
|
|
|
for i in 0..y_test.len() {
|
|
|
|
print!("input: {}\t\t", x_test.get(i).unwrap());
|
|
|
|
let mut prediction = y_test.get(i).unwrap().to_owned();
|
|
|
|
// comment the following line to see the exact predictions
|
|
|
|
prediction.map_mut(|x| *x = x.round());
|
|
|
|
print!("prediction: {}\n", prediction);
|
|
|
|
}
|
2023-01-21 15:19:55 +01:00
|
|
|
}
|