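//! Approximate sin(x) on [0, 2π] with a small fully connected network and
//! plot the true curve against the network's predictions; the chart is
//! written to ./examples/sine.png.
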
extern crate rust_nn;

use std::error::Error;
use std::f64::consts::PI;

use ndarray::Array1;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use plotters::prelude::*;

use rust_nn::functions::{activation_functions, loss_functions};
use rust_nn::layers::activation_layer::ActivationLayer;
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
use rust_nn::Network;

fn main() {
    // training data
    let training_interval = (0.0f64, 2.0f64 * PI);
    let steps = 100000;
    let training_values = Array1::random(
        steps,
        Uniform::new(training_interval.0, training_interval.1),
    )
    .to_vec();
    let mut x_train = Vec::with_capacity(steps);
    let mut y_train = Vec::with_capacity(steps);
    for x in training_values {
        x_train.push(Array1::from_elem(1usize, x));
        y_train.push(Array1::from_elem(1usize, x.sin()));
    }

    // test data
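    // (evenly spaced over the same interval, so the plotted curves come out smooth)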
    let test_steps = 1000;
    let interval_length = training_interval.1 - training_interval.0;
    let step_size = interval_length / test_steps as f64;
    let testing_values = Array1::range(training_interval.0, training_interval.1, step_size);
    let mut x_test = Vec::with_capacity(test_steps);
    let mut y_test_true = Vec::with_capacity(test_steps);
    for x in testing_values {
        x_test.push(Array1::from_elem(1usize, x));
        y_test_true.push(Array1::from_elem(1usize, x.sin()));
    }

    // initialize neural network
    let mut network = Network::new(loss_functions::Type::MSE);

    // add layers
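    // The stack below forms a 1 -> 8 -> 8 -> 1 network with GELU activations
    // between the fully connected layers. FCLayer::new is assumed to take the
    // layer's output size plus a weight and a bias initializer, and
    // GaussianWFactor(0.0, 1.0, 0.1) is assumed to draw from N(0, 1) and scale
    // by 0.1; see rust_nn's fc_layer module for the exact semantics.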
    network.add_layer(Box::new(FCLayer::new(
        8,
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Gelu,
    )));
    network.add_layer(Box::new(FCLayer::new(
        8,
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
    )));
    network.add_layer(Box::new(ActivationLayer::new(
        activation_functions::Type::Gelu,
    )));
    network.add_layer(Box::new(FCLayer::new(
        1,
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
        Initializer::GaussianWFactor(0.0, 1.0, 0.1),
    )));

    // train network on training data
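    // The positional arguments to fit are assumed to be the number of epochs
    // (100), the learning rate (0.05) and a verbosity flag; see
    // rust_nn::Network::fit for the exact signature.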
    network.fit(&x_train, &y_train, 100, 0.05, true);

    // predict test dataset
    let y_test_pred = network.predict(&x_test);

    // show results
    if let Ok(()) = draw_results(&training_interval, &x_test, &y_test_true, &y_test_pred) {
        println!("results can be seen in ./examples/sine.png");
    } else {
        println!("failed to draw results");
    }
}

fn draw_results(
    training_interval: &(f64, f64),
    x_test: &[Array1<f64>],
    y_test_true: &[Array1<f64>],
    y_test_pred: &[Array1<f64>],
) -> Result<(), Box<dyn Error>> {
    // create the chart
    let buf = BitMapBackend::new("./examples/sine.png", (800, 600)).into_drawing_area();
    buf.fill(&WHITE)?;
    let mut chart = ChartBuilder::on(&buf)
        //.caption("sin(x)", ("sans-serif", 30))
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(training_interval.0..training_interval.1, -1.0f64..1.0f64)?;

    chart
        .configure_mesh()
        .disable_x_mesh()
        .disable_y_mesh()
        .draw()?;

    // first plot: the true values, drawn in red
    let data1: Vec<(f64, f64)> = x_test
        .iter()
        .zip(y_test_true.iter())
        .map(|(x, y)| (x[0], y[0]))
        .collect();
    chart
        .draw_series(LineSeries::new(data1, &RED))?
        .label("true values")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], RED));

    // second plot: the network's predictions, drawn in blue
    let data2: Vec<(f64, f64)> = x_test
        .iter()
        .zip(y_test_pred.iter())
        .map(|(x, y)| (x[0], y[0]))
        .collect();
    chart
        .draw_series(LineSeries::new(data2, &BLUE))?
        .label("predicted values")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], BLUE));
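
    // Render the legend for the two labelled series. This call is an addition
    // to the original listing: without configure_series_labels the labels set
    // above are never drawn by plotters.
    chart
        .configure_series_labels()
        .background_style(&WHITE.mix(0.8))
        .border_style(&BLACK)
        .draw()?;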

    Ok(())
}