Compare commits

..

No commits in common. "7b12a054d5a19ca34ed5329e2b7125f8501ff7c0" and "d130c7cce11b487b566028b70ec8c3a87a5041aa" have entirely different histories.

5 changed files with 43 additions and 36 deletions

View file

@ -1,4 +1,4 @@
use ndarray::Array1; use ndarray::{Array1, ArrayView1};
pub enum Type { pub enum Type {
MSE, MSE,
@ -6,8 +6,8 @@ pub enum Type {
} }
type LossFuncTuple = ( type LossFuncTuple = (
fn(&Array1<f64>, &Array1<f64>) -> f64, fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>, fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
); );
pub fn parse_type(t: Type) -> LossFuncTuple { pub fn parse_type(t: Type) -> LossFuncTuple {
@ -17,23 +17,23 @@ pub fn parse_type(t: Type) -> LossFuncTuple {
} }
} }
pub fn mse(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 { pub fn mse(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
let mut temp = y_true - y_pred; let mut temp = &y_true - &y_pred;
temp.mapv_inplace(|x| x * x); temp.mapv_inplace(|x| x * x);
let mut sum = 0.0; let mut sum = 0.0;
for entry in temp.iter() { for i in 0..temp.len() {
sum += entry; sum += temp.get(i).unwrap();
} }
sum / temp.len() as f64 sum / temp.len() as f64
} }
pub fn mse_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> { pub fn mse_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> {
let temp = y_true - y_pred; let temp = &y_true - &y_pred;
temp / (y_true.len() as f64 / 2.0) temp / (y_true.len() as f64 / 2.0)
} }
pub fn mae(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 { pub fn mae(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
let temp = y_true - y_pred; let temp = &y_true - &y_pred;
let mut sum = 0.0; let mut sum = 0.0;
for i in 0..temp.len() { for i in 0..temp.len() {
sum += temp.get(i).unwrap().abs(); sum += temp.get(i).unwrap().abs();
@ -41,7 +41,7 @@ pub fn mae(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
sum / temp.len() as f64 sum / temp.len() as f64
} }
pub fn mae_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> { pub fn mae_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> {
let mut result = Array1::zeros(y_true.raw_dim()); let mut result = Array1::zeros(y_true.raw_dim());
for i in 0..result.len() { for i in 0..result.len() {
if y_true.get(i).unwrap() < y_pred.get(i).unwrap() { if y_true.get(i).unwrap() < y_pred.get(i).unwrap() {

View file

@ -1,4 +1,4 @@
use ndarray::{arr1, Array1}; use ndarray::{arr1, Array1, ArrayView1};
use super::Layer; use super::Layer;
use crate::functions::activation_functions::*; use crate::functions::activation_functions::*;
@ -23,15 +23,15 @@ impl ActivationLayer {
} }
impl Layer for ActivationLayer { impl Layer for ActivationLayer {
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> { fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
self.input = input; self.input = input.to_owned();
// output isn't needed elsewhere // output isn't needed elsewhere
// self.output = (self.activation)(&self.input); // self.output = (self.activation)(&self.input);
// self.output.clone() // self.output.clone()
(self.activation)(&self.input) (self.activation)(&self.input)
} }
fn backward_pass(&mut self, output_error: Array1<f64>, _learning_rate: f64) -> Array1<f64> { fn backward_pass(&mut self, output_error: ArrayView1<f64>, _learning_rate: f64) -> Array1<f64> {
(self.activation_prime)(&self.input) * output_error (self.activation_prime)(&self.input) * output_error
} }
} }

View file

@ -1,6 +1,6 @@
extern crate ndarray; extern crate ndarray;
use ndarray::{arr1, arr2, Array, Array1, Array2, ShapeBuilder}; use ndarray::{arr1, arr2, Array, Array1, Array2, ArrayView1, ShapeBuilder};
use ndarray_rand::rand_distr::{Normal, Uniform}; use ndarray_rand::rand_distr::{Normal, Uniform};
use ndarray_rand::RandomExt; use ndarray_rand::RandomExt;
@ -71,25 +71,30 @@ impl FCLayer {
} }
impl Layer for FCLayer { impl Layer for FCLayer {
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> { fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
if !self.is_initialized { if !self.is_initialized {
self.initialize(input.len()); self.initialize(input.len());
} }
self.input = input; self.input = input.to_owned();
// output isn't needed elsewhere // output isn't needed elsewhere
// self.output = self.input.dot(&self.weights) + &self.biases; // self.output = self.input.dot(&self.weights) + &self.biases;
// self.output.clone() // self.output.clone()
self.input.dot(&self.weights) + &self.biases self.input.dot(&self.weights) + &self.biases
} }
fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64> { fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
let input_error = output_error.dot(&self.weights.t()); let input_error = output_error.dot(&self.weights.t());
let delta_weights = self let delta_weights = self
.input .input
.to_shape((self.input.len(), 1usize)) .to_owned()
.into_shape((self.input.len(), 1usize))
.unwrap() .unwrap()
.dot(&output_error.to_shape((1usize, output_error.len())).unwrap()); .dot(
&output_error
.into_shape((1usize, output_error.len()))
.unwrap(),
);
self.weights = &self.weights + learning_rate * &delta_weights; self.weights = &self.weights + learning_rate * &delta_weights;
self.biases = &self.biases + learning_rate * &output_error; self.biases = &self.biases + learning_rate * &output_error;
input_error input_error

View file

@ -1,9 +1,9 @@
use ndarray::Array1; use ndarray::{Array1, ArrayView1};
pub mod activation_layer; pub mod activation_layer;
pub mod fc_layer; pub mod fc_layer;
pub trait Layer { pub trait Layer {
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>; fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
fn backward_pass(&mut self, output_error: Array1<f64>, learning_rate: f64) -> Array1<f64>; fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
} }

View file

@ -3,12 +3,12 @@ pub mod layers;
use functions::loss_functions::{self, parse_type}; use functions::loss_functions::{self, parse_type};
use layers::*; use layers::*;
use ndarray::Array1; use ndarray::{Array1, ArrayView1};
pub struct Network { pub struct Network {
layers: Vec<Box<dyn Layer>>, layers: Vec<Box<dyn Layer>>,
loss: fn(&Array1<f64>, &Array1<f64>) -> f64, loss: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
loss_prime: fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>, loss_prime: fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
} }
impl Network { impl Network {
@ -30,9 +30,10 @@ impl Network {
let mut result = vec![]; let mut result = vec![];
for input in inputs.iter() { for input in inputs.iter() {
let mut output = input.to_owned(); let mut output = Array1::default(inputs[0].raw_dim());
output.assign(input);
for layer in &mut self.layers { for layer in &mut self.layers {
output = layer.forward_pass(output); output = layer.forward_pass(output.view());
} }
result.push(output); result.push(output);
} }
@ -56,21 +57,22 @@ impl Network {
let mut err = 0.0; let mut err = 0.0;
for j in 0..num_samples { for j in 0..num_samples {
// forward propagation // forward propagation
let mut output = x_train[j].to_owned(); let mut output = Array1::default(x_train[0].raw_dim());
output.assign(&x_train[j]);
for layer in self.layers.iter_mut() { for layer in self.layers.iter_mut() {
output = layer.forward_pass(output); output = layer.forward_pass(output.view());
} }
// compute loss // compute loss
err += (self.loss)(&y_train[j], &output); err += (self.loss)(y_train[j].view(), output.view());
// backward propagation // backward propagation
let mut error = (self.loss_prime)(&y_train[j], &output); let mut error = (self.loss_prime)(y_train[j].view(), output.view());
for layer in self.layers.iter_mut().rev() { for layer in self.layers.iter_mut().rev() {
if trivial_optimize { if trivial_optimize {
error = layer.backward_pass(error, learning_rate / (i + 1) as f64); error = layer.backward_pass(error.view(), learning_rate / (i + 1) as f64);
} else { } else {
error = layer.backward_pass(error, learning_rate); error = layer.backward_pass(error.view(), learning_rate);
} }
} }
} }