From 7b12a054d5a19ca34ed5329e2b7125f8501ff7c0 Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 4 Feb 2023 18:31:49 +0100 Subject: [PATCH] Improved loss functions --- src/functions/loss_functions.rs | 24 ++++++++++++------------ src/lib.rs | 10 +++++----- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/functions/loss_functions.rs b/src/functions/loss_functions.rs index a5d8a94..383ae21 100644 --- a/src/functions/loss_functions.rs +++ b/src/functions/loss_functions.rs @@ -1,4 +1,4 @@ -use ndarray::{Array1, ArrayView1}; +use ndarray::Array1; pub enum Type { MSE, @@ -6,8 +6,8 @@ pub enum Type { } type LossFuncTuple = ( - fn(ArrayView1, ArrayView1) -> f64, - fn(ArrayView1, ArrayView1) -> Array1, + fn(&Array1, &Array1) -> f64, + fn(&Array1, &Array1) -> Array1, ); pub fn parse_type(t: Type) -> LossFuncTuple { @@ -17,23 +17,23 @@ pub fn parse_type(t: Type) -> LossFuncTuple { } } -pub fn mse(y_true: ArrayView1, y_pred: ArrayView1) -> f64 { - let mut temp = &y_true - &y_pred; +pub fn mse(y_true: &Array1, y_pred: &Array1) -> f64 { + let mut temp = y_true - y_pred; temp.mapv_inplace(|x| x * x); let mut sum = 0.0; - for i in 0..temp.len() { - sum += temp.get(i).unwrap(); + for entry in temp.iter() { + sum += entry; } sum / temp.len() as f64 } -pub fn mse_prime(y_true: ArrayView1, y_pred: ArrayView1) -> Array1 { - let temp = &y_true - &y_pred; +pub fn mse_prime(y_true: &Array1, y_pred: &Array1) -> Array1 { + let temp = y_true - y_pred; temp / (y_true.len() as f64 / 2.0) } -pub fn mae(y_true: ArrayView1, y_pred: ArrayView1) -> f64 { - let temp = &y_true - &y_pred; +pub fn mae(y_true: &Array1, y_pred: &Array1) -> f64 { + let temp = y_true - y_pred; let mut sum = 0.0; for i in 0..temp.len() { sum += temp.get(i).unwrap().abs(); @@ -41,7 +41,7 @@ pub fn mae(y_true: ArrayView1, y_pred: ArrayView1) -> f64 { sum / temp.len() as f64 } -pub fn mae_prime(y_true: ArrayView1, y_pred: ArrayView1) -> Array1 { +pub fn mae_prime(y_true: &Array1, y_pred: &Array1) -> Array1 { let mut result = Array1::zeros(y_true.raw_dim()); for i in 0..result.len() { if y_true.get(i).unwrap() < y_pred.get(i).unwrap() { diff --git a/src/lib.rs b/src/lib.rs index aecc97a..2b3f2af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,12 +3,12 @@ pub mod layers; use functions::loss_functions::{self, parse_type}; use layers::*; -use ndarray::{Array1, ArrayView1}; +use ndarray::Array1; pub struct Network { layers: Vec>, - loss: fn(ArrayView1, ArrayView1) -> f64, - loss_prime: fn(ArrayView1, ArrayView1) -> Array1, + loss: fn(&Array1, &Array1) -> f64, + loss_prime: fn(&Array1, &Array1) -> Array1, } impl Network { @@ -62,10 +62,10 @@ impl Network { } // compute loss - err += (self.loss)(y_train[j].view(), output.view()); + err += (self.loss)(&y_train[j], &output); // backward propagation - let mut error = (self.loss_prime)(y_train[j].view(), output.view()); + let mut error = (self.loss_prime)(&y_train[j], &output); for layer in self.layers.iter_mut().rev() { if trivial_optimize { error = layer.backward_pass(error, learning_rate / (i + 1) as f64);