rustfmt
This commit is contained in:
parent
d80bd3c5e5
commit
2f3745a31c
7 changed files with 105 additions and 61 deletions
|
@ -3,20 +3,24 @@ extern crate rust_nn;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
|
||||||
use ndarray_rand::RandomExt;
|
use ndarray::Array1;
|
||||||
use ndarray_rand::rand_distr::Uniform;
|
use ndarray_rand::rand_distr::Uniform;
|
||||||
|
use ndarray_rand::RandomExt;
|
||||||
use plotters::prelude::*;
|
use plotters::prelude::*;
|
||||||
use rust_nn::Network;
|
|
||||||
use rust_nn::functions::{activation_functions, loss_functions};
|
use rust_nn::functions::{activation_functions, loss_functions};
|
||||||
use rust_nn::layers::activation_layer::ActivationLayer;
|
use rust_nn::layers::activation_layer::ActivationLayer;
|
||||||
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
||||||
use ndarray::Array1;
|
use rust_nn::Network;
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
// training data
|
// training data
|
||||||
let training_interval = (0.0f64, 2.0f64 * PI);
|
let training_interval = (0.0f64, 2.0f64 * PI);
|
||||||
let steps = 100000;
|
let steps = 100000;
|
||||||
let training_values = Array1::random(steps, Uniform::new(training_interval.0, training_interval.1)).to_vec();
|
let training_values = Array1::random(
|
||||||
|
steps,
|
||||||
|
Uniform::new(training_interval.0, training_interval.1),
|
||||||
|
)
|
||||||
|
.to_vec();
|
||||||
let mut x_train = Vec::new();
|
let mut x_train = Vec::new();
|
||||||
let mut y_train = Vec::new();
|
let mut y_train = Vec::new();
|
||||||
for x in training_values {
|
for x in training_values {
|
||||||
|
@ -42,19 +46,23 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
network.add_layer(Box::new(FCLayer::new(
|
network.add_layer(Box::new(FCLayer::new(
|
||||||
8,
|
8,
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
|
)));
|
||||||
|
network.add_layer(Box::new(ActivationLayer::new(
|
||||||
|
activation_functions::Type::LeakyRelu,
|
||||||
)));
|
)));
|
||||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::LeakyRelu)));
|
|
||||||
network.add_layer(Box::new(FCLayer::new(
|
network.add_layer(Box::new(FCLayer::new(
|
||||||
8,
|
8,
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
|
)));
|
||||||
|
network.add_layer(Box::new(ActivationLayer::new(
|
||||||
|
activation_functions::Type::LeakyRelu,
|
||||||
)));
|
)));
|
||||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::LeakyRelu)));
|
|
||||||
network.add_layer(Box::new(FCLayer::new(
|
network.add_layer(Box::new(FCLayer::new(
|
||||||
1,
|
1,
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
Initializer::GaussianWFactor(0.0, 1.0, 0.1)
|
Initializer::GaussianWFactor(0.0, 1.0, 0.1),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// train network on training data
|
// train network on training data
|
||||||
|
@ -79,20 +87,26 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
.draw()?;
|
.draw()?;
|
||||||
|
|
||||||
// add the first plot
|
// add the first plot
|
||||||
let data1: Vec<(f64,f64)> = x_test.iter().zip(y_test_true.iter())
|
let data1: Vec<(f64, f64)> = x_test
|
||||||
|
.iter()
|
||||||
|
.zip(y_test_true.iter())
|
||||||
.map(|(x, y)| (x[0], y[0]))
|
.map(|(x, y)| (x[0], y[0]))
|
||||||
.collect();
|
.collect();
|
||||||
chart
|
chart
|
||||||
.draw_series(LineSeries::new(data1, &RED)).unwrap()
|
.draw_series(LineSeries::new(data1, &RED))
|
||||||
|
.unwrap()
|
||||||
.label("true values")
|
.label("true values")
|
||||||
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &RED));
|
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &RED));
|
||||||
|
|
||||||
// add the second plot
|
// add the second plot
|
||||||
let data2: Vec<(f64,f64)> = x_test.iter().zip(y_test_pred.iter())
|
let data2: Vec<(f64, f64)> = x_test
|
||||||
|
.iter()
|
||||||
|
.zip(y_test_pred.iter())
|
||||||
.map(|(x, y)| (x[0], y[0]))
|
.map(|(x, y)| (x[0], y[0]))
|
||||||
.collect();
|
.collect();
|
||||||
chart
|
chart
|
||||||
.draw_series(LineSeries::new(data2, &BLUE)).unwrap()
|
.draw_series(LineSeries::new(data2, &BLUE))
|
||||||
|
.unwrap()
|
||||||
.label("predicted values")
|
.label("predicted values")
|
||||||
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &BLUE));
|
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 1, y)], &BLUE));
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
extern crate rust_nn;
|
extern crate rust_nn;
|
||||||
|
|
||||||
use rust_nn::Network;
|
use ndarray::array;
|
||||||
use rust_nn::functions::{activation_functions, loss_functions};
|
use rust_nn::functions::{activation_functions, loss_functions};
|
||||||
use rust_nn::layers::activation_layer::ActivationLayer;
|
use rust_nn::layers::activation_layer::ActivationLayer;
|
||||||
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
use rust_nn::layers::fc_layer::{FCLayer, Initializer};
|
||||||
use ndarray::array;
|
use rust_nn::Network;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// training data
|
// training data
|
||||||
|
@ -12,20 +12,15 @@ fn main() {
|
||||||
array![0.0, 0.0],
|
array![0.0, 0.0],
|
||||||
array![0.0, 1.0],
|
array![0.0, 1.0],
|
||||||
array![1.0, 0.0],
|
array![1.0, 0.0],
|
||||||
array![1.0, 1.0]
|
array![1.0, 1.0],
|
||||||
];
|
|
||||||
let y_train = vec![
|
|
||||||
array![0.0],
|
|
||||||
array![1.0],
|
|
||||||
array![1.0],
|
|
||||||
array![0.0]
|
|
||||||
];
|
];
|
||||||
|
let y_train = vec![array![0.0], array![1.0], array![1.0], array![0.0]];
|
||||||
// test data
|
// test data
|
||||||
let x_test= vec![
|
let x_test = vec![
|
||||||
array![0.0, 0.0],
|
array![0.0, 0.0],
|
||||||
array![0.0, 1.0],
|
array![0.0, 1.0],
|
||||||
array![1.0, 0.0],
|
array![1.0, 0.0],
|
||||||
array![1.0, 1.0]
|
array![1.0, 1.0],
|
||||||
];
|
];
|
||||||
|
|
||||||
// initialize neural network
|
// initialize neural network
|
||||||
|
@ -35,15 +30,19 @@ fn main() {
|
||||||
network.add_layer(Box::new(FCLayer::new(
|
network.add_layer(Box::new(FCLayer::new(
|
||||||
3,
|
3,
|
||||||
Initializer::Gaussian(0.0, 1.0),
|
Initializer::Gaussian(0.0, 1.0),
|
||||||
Initializer::Gaussian(0.0, 1.0)
|
Initializer::Gaussian(0.0, 1.0),
|
||||||
|
)));
|
||||||
|
network.add_layer(Box::new(ActivationLayer::new(
|
||||||
|
activation_functions::Type::Tanh,
|
||||||
)));
|
)));
|
||||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::Tanh)));
|
|
||||||
network.add_layer(Box::new(FCLayer::new(
|
network.add_layer(Box::new(FCLayer::new(
|
||||||
1,
|
1,
|
||||||
Initializer::Gaussian(0.0, 1.0),
|
Initializer::Gaussian(0.0, 1.0),
|
||||||
Initializer::Gaussian(0.0, 1.0)
|
Initializer::Gaussian(0.0, 1.0),
|
||||||
|
)));
|
||||||
|
network.add_layer(Box::new(ActivationLayer::new(
|
||||||
|
activation_functions::Type::Tanh,
|
||||||
)));
|
)));
|
||||||
network.add_layer(Box::new(ActivationLayer::new(activation_functions::Type::Tanh)));
|
|
||||||
|
|
||||||
// train network on training data
|
// train network on training data
|
||||||
network.fit(x_train, y_train, 1000, 0.1, false);
|
network.fit(x_train, y_train, 1000, 0.1, false);
|
||||||
|
@ -58,4 +57,4 @@ fn main() {
|
||||||
prediction.map_mut(|x| *x = x.round());
|
prediction.map_mut(|x| *x = x.round());
|
||||||
print!("prediction: {}\n", prediction);
|
print!("prediction: {}\n", prediction);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,16 +6,21 @@ pub enum Type {
|
||||||
Logistic,
|
Logistic,
|
||||||
Tanh,
|
Tanh,
|
||||||
Relu,
|
Relu,
|
||||||
LeakyRelu
|
LeakyRelu,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_type(t: Type) -> (fn(&Array1<f64>) -> Array1<f64>, fn(&Array1<f64>) -> Array1<f64>) {
|
pub fn parse_type(
|
||||||
|
t: Type,
|
||||||
|
) -> (
|
||||||
|
fn(&Array1<f64>) -> Array1<f64>,
|
||||||
|
fn(&Array1<f64>) -> Array1<f64>,
|
||||||
|
) {
|
||||||
match t {
|
match t {
|
||||||
Type::Identity => (identity, identity_prime),
|
Type::Identity => (identity, identity_prime),
|
||||||
Type::Logistic => (logistic, logistic_prime),
|
Type::Logistic => (logistic, logistic_prime),
|
||||||
Type::Tanh => (tanh, tanh_prime),
|
Type::Tanh => (tanh, tanh_prime),
|
||||||
Type::Relu => (relu, relu_prime),
|
Type::Relu => (relu, relu_prime),
|
||||||
Type::LeakyRelu => (leaky_relu, leaky_relu_prime)
|
Type::LeakyRelu => (leaky_relu, leaky_relu_prime),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +83,7 @@ pub fn relu(matrix: &Array1<f64>) -> Array1<f64> {
|
||||||
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
pub fn relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
||||||
let mut result = matrix.clone();
|
let mut result = matrix.clone();
|
||||||
for x in result.iter_mut() {
|
for x in result.iter_mut() {
|
||||||
*x = if (*x) <= 0.0 {0.0} else {1.0};
|
*x = if (*x) <= 0.0 { 0.0 } else { 1.0 };
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
@ -94,7 +99,7 @@ pub fn leaky_relu(matrix: &Array1<f64>) -> Array1<f64> {
|
||||||
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
pub fn leaky_relu_prime(matrix: &Array1<f64>) -> Array1<f64> {
|
||||||
let mut result = matrix.clone();
|
let mut result = matrix.clone();
|
||||||
for x in result.iter_mut() {
|
for x in result.iter_mut() {
|
||||||
*x = if (*x) <= 0.0 {0.001} else {1.0};
|
*x = if (*x) <= 0.0 { 0.001 } else { 1.0 };
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,13 +2,18 @@ use ndarray::{Array1, ArrayView1};
|
||||||
|
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
MSE,
|
MSE,
|
||||||
MAE
|
MAE,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_type(t: Type) -> (fn(ArrayView1<f64>, ArrayView1<f64>) -> f64, fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>) {
|
pub fn parse_type(
|
||||||
|
t: Type,
|
||||||
|
) -> (
|
||||||
|
fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
|
||||||
|
fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
|
||||||
|
) {
|
||||||
match t {
|
match t {
|
||||||
Type::MSE => (mse, mse_prime),
|
Type::MSE => (mse, mse_prime),
|
||||||
Type::MAE => (mae, mae_prime)
|
Type::MAE => (mae, mae_prime),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
use ndarray::{Array1, arr1, ArrayView1};
|
use ndarray::{arr1, Array1, ArrayView1};
|
||||||
|
|
||||||
use crate::functions::activation_functions::*;
|
|
||||||
use super::Layer;
|
use super::Layer;
|
||||||
|
use crate::functions::activation_functions::*;
|
||||||
|
|
||||||
pub struct ActivationLayer {
|
pub struct ActivationLayer {
|
||||||
input: Array1<f64>,
|
input: Array1<f64>,
|
||||||
output: Array1<f64>,
|
output: Array1<f64>,
|
||||||
activation: fn(&Array1<f64>) -> Array1<f64>,
|
activation: fn(&Array1<f64>) -> Array1<f64>,
|
||||||
activation_prime: fn(&Array1<f64>) -> Array1<f64>
|
activation_prime: fn(&Array1<f64>) -> Array1<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ActivationLayer {
|
impl ActivationLayer {
|
||||||
|
@ -17,7 +17,7 @@ impl ActivationLayer {
|
||||||
input: arr1(&[]),
|
input: arr1(&[]),
|
||||||
output: arr1(&[]),
|
output: arr1(&[]),
|
||||||
activation,
|
activation,
|
||||||
activation_prime
|
activation_prime,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,5 +36,4 @@ impl Layer for ActivationLayer {
|
||||||
temp.zip_mut_with(&output_error, |x, y| *x *= y);
|
temp.zip_mut_with(&output_error, |x, y| *x *= y);
|
||||||
temp
|
temp
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
extern crate ndarray;
|
extern crate ndarray;
|
||||||
|
|
||||||
use ndarray::{Array1, Array2, arr1, arr2, Array, ArrayView1, ShapeBuilder};
|
use ndarray::{arr1, arr2, Array, Array1, Array2, ArrayView1, ShapeBuilder};
|
||||||
use ndarray_rand::RandomExt;
|
|
||||||
use ndarray_rand::rand_distr::{Normal, Uniform};
|
use ndarray_rand::rand_distr::{Normal, Uniform};
|
||||||
|
use ndarray_rand::RandomExt;
|
||||||
|
|
||||||
use super::Layer;
|
use super::Layer;
|
||||||
|
|
||||||
|
@ -11,21 +11,25 @@ pub enum Initializer {
|
||||||
Ones,
|
Ones,
|
||||||
Gaussian(f64, f64),
|
Gaussian(f64, f64),
|
||||||
GaussianWFactor(f64, f64, f64),
|
GaussianWFactor(f64, f64, f64),
|
||||||
Uniform(f64, f64)
|
Uniform(f64, f64),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Initializer {
|
impl Initializer {
|
||||||
pub fn init<Sh, D>(&self, shape: Sh) -> Array<f64, D>
|
pub fn init<Sh, D>(&self, shape: Sh) -> Array<f64, D>
|
||||||
where
|
where
|
||||||
Sh: ShapeBuilder<Dim = D>, D: ndarray::Dimension
|
Sh: ShapeBuilder<Dim = D>,
|
||||||
|
D: ndarray::Dimension,
|
||||||
{
|
{
|
||||||
match self {
|
match self {
|
||||||
Self::Zeros => Array::zeros(shape),
|
Self::Zeros => Array::zeros(shape),
|
||||||
Self::Ones => Array::ones(shape),
|
Self::Ones => Array::ones(shape),
|
||||||
Self::Gaussian(mean, stddev) => Array::random(shape, Normal::new(*mean, *stddev).unwrap()),
|
Self::Gaussian(mean, stddev) => {
|
||||||
Self::GaussianWFactor(mean, stddev, factor)
|
Array::random(shape, Normal::new(*mean, *stddev).unwrap())
|
||||||
=> Array::random(shape, Normal::new(*mean, *stddev).unwrap()) * *factor,
|
}
|
||||||
Self::Uniform(low, high) => Array::random(shape, Uniform::new(low, high))
|
Self::GaussianWFactor(mean, stddev, factor) => {
|
||||||
|
Array::random(shape, Normal::new(*mean, *stddev).unwrap()) * *factor
|
||||||
|
}
|
||||||
|
Self::Uniform(low, high) => Array::random(shape, Uniform::new(low, high)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -42,7 +46,11 @@ pub struct FCLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FCLayer {
|
impl FCLayer {
|
||||||
pub fn new(num_neurons: usize, weight_initializer: Initializer, bias_initializer: Initializer) -> Self {
|
pub fn new(
|
||||||
|
num_neurons: usize,
|
||||||
|
weight_initializer: Initializer,
|
||||||
|
bias_initializer: Initializer,
|
||||||
|
) -> Self {
|
||||||
FCLayer {
|
FCLayer {
|
||||||
num_neurons,
|
num_neurons,
|
||||||
is_initialized: false,
|
is_initialized: false,
|
||||||
|
@ -51,7 +59,7 @@ impl FCLayer {
|
||||||
input: arr1(&[]),
|
input: arr1(&[]),
|
||||||
output: arr1(&[]),
|
output: arr1(&[]),
|
||||||
weights: arr2(&[[]]),
|
weights: arr2(&[[]]),
|
||||||
biases: arr1(&[])
|
biases: arr1(&[]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,11 +83,18 @@ impl Layer for FCLayer {
|
||||||
|
|
||||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
|
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
|
||||||
let input_error = output_error.dot(&self.weights.t());
|
let input_error = output_error.dot(&self.weights.t());
|
||||||
let delta_weights =
|
let delta_weights = self
|
||||||
self.input.to_owned().into_shape((self.input.len(), 1usize)).unwrap()
|
.input
|
||||||
.dot(&output_error.into_shape((1usize, output_error.len())).unwrap());
|
.to_owned()
|
||||||
|
.into_shape((self.input.len(), 1usize))
|
||||||
|
.unwrap()
|
||||||
|
.dot(
|
||||||
|
&output_error
|
||||||
|
.into_shape((1usize, output_error.len()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
self.weights = &self.weights + learning_rate * &delta_weights;
|
self.weights = &self.weights + learning_rate * &delta_weights;
|
||||||
self.biases = &self.biases + learning_rate * &output_error;
|
self.biases = &self.biases + learning_rate * &output_error;
|
||||||
input_error
|
input_error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
17
src/lib.rs
17
src/lib.rs
|
@ -8,7 +8,7 @@ use ndarray::{Array1, ArrayView1};
|
||||||
pub struct Network {
|
pub struct Network {
|
||||||
layers: Vec<Box<dyn Layer>>,
|
layers: Vec<Box<dyn Layer>>,
|
||||||
loss: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
|
loss: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
|
||||||
loss_prime: fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>
|
loss_prime: fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Network {
|
impl Network {
|
||||||
|
@ -17,7 +17,7 @@ impl Network {
|
||||||
Network {
|
Network {
|
||||||
layers: vec![],
|
layers: vec![],
|
||||||
loss,
|
loss,
|
||||||
loss_prime
|
loss_prime,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +41,14 @@ impl Network {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit(&mut self, x_train: Vec<Array1<f64>>, y_train: Vec<Array1<f64>>, epochs: usize, learning_rate: f64, trivial_optimize: bool) {
|
pub fn fit(
|
||||||
|
&mut self,
|
||||||
|
x_train: Vec<Array1<f64>>,
|
||||||
|
y_train: Vec<Array1<f64>>,
|
||||||
|
epochs: usize,
|
||||||
|
learning_rate: f64,
|
||||||
|
trivial_optimize: bool,
|
||||||
|
) {
|
||||||
assert!(x_train.len() > 0);
|
assert!(x_train.len() > 0);
|
||||||
assert!(x_train.len() == y_train.len());
|
assert!(x_train.len() == y_train.len());
|
||||||
let num_samples = x_train.len();
|
let num_samples = x_train.len();
|
||||||
|
@ -63,7 +70,7 @@ impl Network {
|
||||||
let mut error = (self.loss_prime)(y_train[j].view(), output.view());
|
let mut error = (self.loss_prime)(y_train[j].view(), output.view());
|
||||||
for layer in self.layers.iter_mut().rev() {
|
for layer in self.layers.iter_mut().rev() {
|
||||||
if trivial_optimize {
|
if trivial_optimize {
|
||||||
error = layer.backward_pass(error.view(), learning_rate / (i+1) as f64);
|
error = layer.backward_pass(error.view(), learning_rate / (i + 1) as f64);
|
||||||
} else {
|
} else {
|
||||||
error = layer.backward_pass(error.view(), learning_rate);
|
error = layer.backward_pass(error.view(), learning_rate);
|
||||||
}
|
}
|
||||||
|
@ -71,7 +78,7 @@ impl Network {
|
||||||
}
|
}
|
||||||
// calculate average error on all samples
|
// calculate average error on all samples
|
||||||
err /= num_samples as f64;
|
err /= num_samples as f64;
|
||||||
println!("epoch {}/{} error={}", i+1, epochs, err);
|
println!("epoch {}/{} error={}", i + 1, epochs, err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue