Improved loss functions

This commit is contained in:
lluni 2023-02-04 18:31:49 +01:00
parent 98bc599dac
commit 7b12a054d5
Signed by: lluni
GPG key ID: ACEEB468BC325D35
2 changed files with 17 additions and 17 deletions

View file

@ -1,4 +1,4 @@
use ndarray::{Array1, ArrayView1}; use ndarray::Array1;
pub enum Type { pub enum Type {
MSE, MSE,
@ -6,8 +6,8 @@ pub enum Type {
} }
type LossFuncTuple = ( type LossFuncTuple = (
fn(ArrayView1<f64>, ArrayView1<f64>) -> f64, fn(&Array1<f64>, &Array1<f64>) -> f64,
fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>, fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
); );
pub fn parse_type(t: Type) -> LossFuncTuple { pub fn parse_type(t: Type) -> LossFuncTuple {
@ -17,23 +17,23 @@ pub fn parse_type(t: Type) -> LossFuncTuple {
} }
} }
pub fn mse(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 { pub fn mse(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
let mut temp = &y_true - &y_pred; let mut temp = y_true - y_pred;
temp.mapv_inplace(|x| x * x); temp.mapv_inplace(|x| x * x);
let mut sum = 0.0; let mut sum = 0.0;
for i in 0..temp.len() { for entry in temp.iter() {
sum += temp.get(i).unwrap(); sum += entry;
} }
sum / temp.len() as f64 sum / temp.len() as f64
} }
pub fn mse_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> { pub fn mse_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> {
let temp = &y_true - &y_pred; let temp = y_true - y_pred;
temp / (y_true.len() as f64 / 2.0) temp / (y_true.len() as f64 / 2.0)
} }
pub fn mae(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 { pub fn mae(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
let temp = &y_true - &y_pred; let temp = y_true - y_pred;
let mut sum = 0.0; let mut sum = 0.0;
for i in 0..temp.len() { for i in 0..temp.len() {
sum += temp.get(i).unwrap().abs(); sum += temp.get(i).unwrap().abs();
@ -41,7 +41,7 @@ pub fn mae(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
sum / temp.len() as f64 sum / temp.len() as f64
} }
pub fn mae_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> { pub fn mae_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> {
let mut result = Array1::zeros(y_true.raw_dim()); let mut result = Array1::zeros(y_true.raw_dim());
for i in 0..result.len() { for i in 0..result.len() {
if y_true.get(i).unwrap() < y_pred.get(i).unwrap() { if y_true.get(i).unwrap() < y_pred.get(i).unwrap() {

View file

@ -3,12 +3,12 @@ pub mod layers;
use functions::loss_functions::{self, parse_type}; use functions::loss_functions::{self, parse_type};
use layers::*; use layers::*;
use ndarray::{Array1, ArrayView1}; use ndarray::Array1;
pub struct Network { pub struct Network {
layers: Vec<Box<dyn Layer>>, layers: Vec<Box<dyn Layer>>,
loss: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64, loss: fn(&Array1<f64>, &Array1<f64>) -> f64,
loss_prime: fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>, loss_prime: fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
} }
impl Network { impl Network {
@ -62,10 +62,10 @@ impl Network {
} }
// compute loss // compute loss
err += (self.loss)(y_train[j].view(), output.view()); err += (self.loss)(&y_train[j], &output);
// backward propagation // backward propagation
let mut error = (self.loss_prime)(y_train[j].view(), output.view()); let mut error = (self.loss_prime)(&y_train[j], &output);
for layer in self.layers.iter_mut().rev() { for layer in self.layers.iter_mut().rev() {
if trivial_optimize { if trivial_optimize {
error = layer.backward_pass(error, learning_rate / (i + 1) as f64); error = layer.backward_pass(error, learning_rate / (i + 1) as f64);