Initial commit
This commit is contained in:
commit
961626616f
12 changed files with 569 additions and 0 deletions
102
examples/example_sine.rs
Normal file
102
examples/example_sine.rs
Normal file
|
@ -0,0 +1,102 @@
|
|||
extern crate rust_nn;
|
||||
|
||||
use std::error::Error;
|
||||
use std::f64::consts::PI;
|
||||
|
||||
use ndarray_rand::RandomExt;
|
||||
use ndarray_rand::rand_distr::Uniform;
|
||||
use plotters::prelude::*;
|
||||
use rust_nn::Network;
|
||||
use rust_nn::functions::{activation_functions, loss_functions};
|
||||
use rust_nn::layers::activation_layer::ActivationLayer;
|
||||
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
||||
use ndarray::Array1;
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
// training data
|
||||
let training_interval = (0.0f64, 2.0f64 * PI);
|
||||
let steps = 100000;
|
||||
let training_values = Array1::random(steps, Uniform::new(training_interval.0, training_interval.1)).to_vec();
|
||||
let mut x_train = Vec::new();
|
||||
let mut y_train = Vec::new();
|
||||
for x in training_values {
|
||||
x_train.push(Array1::from_elem(1usize, x));
|
||||
y_train.push(Array1::from_elem(1usize, x.sin()));
|
||||
}
|
||||
// test data
|
||||
let test_steps = 1000;
|
||||
let interval_length = training_interval.1 - training_interval.0;
|
||||
let step_size = interval_length / test_steps as f64;
|
||||
let testing_values = Array1::range(training_interval.0, training_interval.1, step_size);
|
||||
let mut x_test = Vec::new();
|
||||
let mut y_test_true = Vec::new();
|
||||
for x in testing_values {
|
||||
x_test.push(Array1::from_elem(1usize, x));
|
||||
y_test_true.push(Array1::from_elem(1usize, x.sin()));
|
||||
}
|
||||
|
||||
// initialize neural network
|
||||
let mut network = Network::new(loss_functions::Type::MSE);
|
||||
|
||||
// add layers
|
||||
network.add_layer(Box::new(FCLayer::new(
|
||||
8,
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
||||
)));
|
||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::LeakyRelu)));
|
||||
network.add_layer(Box::new(FCLayer::new(
|
||||
8,
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
||||
)));
|
||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::LeakyRelu)));
|
||||
network.add_layer(Box::new(FCLayer::new(
|
||||
1,
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
||||
)));
|
||||
|
||||
// train network on training data
|
||||
network.fit(x_train, y_train, 100, 0.05, true);
|
||||
|
||||
// predict test dataset
|
||||
let y_test_pred = network.predict(x_test.clone());
|
||||
|
||||
// create the chart
|
||||
let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area();
|
||||
buf.fill(&WHITE)?;
|
||||
let mut chart = ChartBuilder::on(&buf)
|
||||
//.caption("sin(x)", ("sans-serif", 30))
|
||||
.x_label_area_size(30)
|
||||
.y_label_area_size(30)
|
||||
.build_cartesian_2d(training_interval.0..training_interval.1, -1.0f64..1.0f64)?;
|
||||
|
||||
chart
|
||||
.configure_mesh()
|
||||
.disable_x_mesh()
|
||||
.disable_y_mesh()
|
||||
.draw()?;
|
||||
|
||||
// add the first plot
|
||||
let mut data1: Vec<(f64,f64)> = x_test.iter().zip(y_test_true.iter())
|
||||
.map(|(x, y)| (x[0], y[0]))
|
||||
.collect();
|
||||
data1.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
chart
|
||||
.draw_series(LineSeries::new(data1, &RED)).unwrap()
|
||||
.label("true values")
|
||||
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &RED));
|
||||
|
||||
// add the second plot
|
||||
let mut data2: Vec<(f64,f64)> = x_test.iter().zip(y_test_pred.iter())
|
||||
.map(|(x, y)| (x[0], y[0]))
|
||||
.collect();
|
||||
data2.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
chart
|
||||
.draw_series(LineSeries::new(data2, &BLUE)).unwrap()
|
||||
.label("predicted values")
|
||||
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &BLUE));
|
||||
|
||||
Ok(())
|
||||
}
|
61
examples/example_xor.rs
Normal file
61
examples/example_xor.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
extern crate rust_nn;
|
||||
|
||||
use rust_nn::Network;
|
||||
use rust_nn::functions::{activation_functions, loss_functions};
|
||||
use rust_nn::layers::activation_layer::ActivationLayer;
|
||||
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
||||
use ndarray::array;
|
||||
|
||||
fn main() {
|
||||
// training data
|
||||
let x_train = vec![
|
||||
array![0.0, 0.0],
|
||||
array![0.0, 1.0],
|
||||
array![1.0, 0.0],
|
||||
array![1.0, 1.0]
|
||||
];
|
||||
let y_train = vec![
|
||||
array![0.0],
|
||||
array![1.0],
|
||||
array![1.0],
|
||||
array![0.0]
|
||||
];
|
||||
// test data
|
||||
let x_test= vec![
|
||||
array![0.0, 0.0],
|
||||
array![0.0, 1.0],
|
||||
array![1.0, 0.0],
|
||||
array![1.0, 1.0]
|
||||
];
|
||||
|
||||
// initialize neural network
|
||||
let mut network = Network::new(loss_functions::Type::MSE);
|
||||
|
||||
// add layers
|
||||
network.add_layer(Box::new(FCLayer::new(
|
||||
3,
|
||||
Initializer::Gaussian(0.0, 1.0),
|
||||
Initializer::Gaussian(0.0, 1.0)
|
||||
)));
|
||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::Tanh)));
|
||||
network.add_layer(Box::new(FCLayer::new(
|
||||
1,
|
||||
Initializer::Gaussian(0.0, 1.0),
|
||||
Initializer::Gaussian(0.0, 1.0)
|
||||
)));
|
||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::Tanh)));
|
||||
|
||||
// train network on training data
|
||||
network.fit(x_train, y_train, 1000, 0.1, false);
|
||||
|
||||
// print predictions
|
||||
let y_test = network.predict(x_test.clone());
|
||||
println!("{}", x_test.get(0).unwrap());
|
||||
for i in 0..y_test.len() {
|
||||
print!("input: {}\t\t", x_test.get(i).unwrap());
|
||||
let mut prediction = y_test.get(i).unwrap().to_owned();
|
||||
// comment the following line to see the exact predictions
|
||||
prediction.map_mut(|x| *x = x.round());
|
||||
print!("prediction: {}\n", prediction);
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue