Improved loss functions

commit 7b12a054d5 (parent 98bc599dac)
2 changed files with 17 additions and 17 deletions
src/functions/loss_functions.rs (24 changed lines)

@@ -1,4 +1,4 @@
-use ndarray::{Array1, ArrayView1};
+use ndarray::Array1;
 
 pub enum Type {
     MSE,
@@ -6,8 +6,8 @@ pub enum Type {
 }
 
 type LossFuncTuple = (
-    fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
-    fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
+    fn(&Array1<f64>, &Array1<f64>) -> f64,
+    fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
 );
 
 pub fn parse_type(t: Type) -> LossFuncTuple {
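The body of parse_type falls between these hunks and is not shown. A minimal sketch of the shape it presumably has, assuming each Type variant maps directly to its (loss, derivative) pair; the MAE variant name is an assumption, since only MSE is visible in the enum above:

// Sketch only: the real match body is elided from this diff.
pub fn parse_type(t: Type) -> LossFuncTuple {
    match t {
        Type::MSE => (mse, mse_prime),
        Type::MAE => (mae, mae_prime), // assumed variant, not visible above
    }
}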
@@ -17,23 +17,23 @@ pub fn parse_type(t: Type) -> LossFuncTuple {
     }
 }
 
-pub fn mse(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
-    let mut temp = &y_true - &y_pred;
+pub fn mse(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
+    let mut temp = y_true - y_pred;
     temp.mapv_inplace(|x| x * x);
     let mut sum = 0.0;
-    for i in 0..temp.len() {
-        sum += temp.get(i).unwrap();
+    for entry in temp.iter() {
+        sum += entry;
     }
     sum / temp.len() as f64
 }
 
-pub fn mse_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> {
-    let temp = &y_true - &y_pred;
+pub fn mse_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> {
+    let temp = y_true - y_pred;
     temp / (y_true.len() as f64 / 2.0)
 }
 
-pub fn mae(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
-    let temp = &y_true - &y_pred;
+pub fn mae(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
+    let temp = y_true - y_pred;
     let mut sum = 0.0;
     for i in 0..temp.len() {
         sum += temp.get(i).unwrap().abs();
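The switch from ArrayView1 to &Array1 works because ndarray implements the arithmetic operators for array references: `y_true - y_pred` with two &Array1<f64> operands allocates a fresh owned Array1<f64>, so no .view() conversion is needed anywhere. A minimal standalone sketch of the same computation the new mse performs, written with ndarray's mapv and sum reductions instead of the explicit accumulation loop:

use ndarray::{arr1, Array1};

// Equivalent, loop-free way to compute the same mean squared error.
fn mse_compact(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
    let diff = y_true - y_pred; // owned Array1<f64>
    diff.mapv(|x| x * x).sum() / diff.len() as f64
}

fn main() {
    let y_true = arr1(&[1.0, 0.0, 1.0]);
    let y_pred = arr1(&[0.9, 0.2, 0.8]);
    // (0.01 + 0.04 + 0.04) / 3 = 0.03
    assert!((mse_compact(&y_true, &y_pred) - 0.03).abs() < 1e-12);
}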
@@ -41,7 +41,7 @@ pub fn mae(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> f64 {
     sum / temp.len() as f64
 }
 
-pub fn mae_prime(y_true: ArrayView1<f64>, y_pred: ArrayView1<f64>) -> Array1<f64> {
+pub fn mae_prime(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> {
     let mut result = Array1::zeros(y_true.raw_dim());
     for i in 0..result.len() {
         if y_true.get(i).unwrap() < y_pred.get(i).unwrap() {
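The tail of mae_prime is cut off by the end of this hunk. For reference, a self-contained sketch of a sign-based MAE gradient consistent with the visible comparison; the ±1.0 entries (and whether the author scales by the length, as mse_prime does) are assumptions:

use ndarray::Array1;

// Assumed completion of mae_prime, not the author's code: each entry
// follows the sign of (y_true - y_pred).
fn mae_prime_sketch(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> Array1<f64> {
    let mut result = Array1::zeros(y_true.raw_dim());
    for i in 0..result.len() {
        // Mirrors the `y_true < y_pred` branch shown above.
        result[i] = if y_true[i] < y_pred[i] { -1.0 } else { 1.0 };
    }
    result
}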
src/lib.rs (10 changed lines)
@@ -3,12 +3,12 @@ pub mod layers;
 
 use functions::loss_functions::{self, parse_type};
 use layers::*;
-use ndarray::{Array1, ArrayView1};
+use ndarray::Array1;
 
 pub struct Network {
     layers: Vec<Box<dyn Layer>>,
-    loss: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
-    loss_prime: fn(ArrayView1<f64>, ArrayView1<f64>) -> Array1<f64>,
+    loss: fn(&Array1<f64>, &Array1<f64>) -> f64,
+    loss_prime: fn(&Array1<f64>, &Array1<f64>) -> Array1<f64>,
 }
 
 impl Network {
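Because loss and loss_prime are plain fn pointers, any function with the matching &Array1 signature plugs in, including the pair returned by parse_type. A hypothetical constructor sketch; the diff only shows the struct fields, so the name with_loss and its parameters are assumptions:

impl Network {
    // Hypothetical helper, not part of this commit.
    pub fn with_loss(layers: Vec<Box<dyn Layer>>, t: loss_functions::Type) -> Self {
        let (loss, loss_prime) = parse_type(t);
        Network { layers, loss, loss_prime }
    }
}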
@@ -62,10 +62,10 @@ impl Network {
             }
 
             // compute loss
-            err += (self.loss)(y_train[j].view(), output.view());
+            err += (self.loss)(&y_train[j], &output);
 
             // backward propagation
-            let mut error = (self.loss_prime)(y_train[j].view(), output.view());
+            let mut error = (self.loss_prime)(&y_train[j], &output);
             for layer in self.layers.iter_mut().rev() {
                 if trivial_optimize {
                     error = layer.backward_pass(error, learning_rate / (i + 1) as f64);
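At the call sites this removes the .view() conversions: assuming y_train[j] is an owned Array1<f64> (its exact type is outside this diff), a plain borrow now matches the function-pointer signature. A standalone sketch of the new calling convention, assuming the mse pair from loss_functions.rs above is in scope:

use ndarray::arr1;

fn main() {
    let target = arr1(&[1.0, 0.0]);
    let output = arr1(&[0.9, 0.1]);
    // Reference-based calls, no .view() needed:
    let err = mse(&target, &output);        // scalar loss
    let grad = mse_prime(&target, &output); // gradient array
    println!("loss = {err}, gradient = {grad}");
}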