37 lines
1.2 KiB
Java
37 lines
1.2 KiB
Java
import org.ejml.simple.SimpleMatrix;
|
|
|
|
public class LossFunctions {
|
|
public static double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
SimpleMatrix temp = y_true.minus(y_pred);
|
|
temp = temp.elementMult(temp);
|
|
double sum = 0;
|
|
for (int i = 0; i < temp.getNumElements(); i++) {
|
|
sum += temp.get(i);
|
|
}
|
|
return sum / temp.getNumElements();
|
|
}
|
|
|
|
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
|
|
}
|
|
|
|
public static double MAE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
double sum = 0;
|
|
for (int i = 0; i < y_true.getNumElements(); i++) {
|
|
sum += Math.abs(y_true.get(i) - y_pred.get(i));
|
|
}
|
|
return sum / y_true.getNumElements();
|
|
}
|
|
|
|
public static SimpleMatrix MAEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
SimpleMatrix result = new SimpleMatrix(y_true);
|
|
for (int i = 0; i < result.getNumElements(); i++) {
|
|
if (y_true.get(i) < y_pred.get(i)) {
|
|
result.set(i, 1d);
|
|
} else {
|
|
result.set(i, -1d);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
}
|