Added MAE loss function

This commit is contained in:
lluni 2022-05-22 01:42:41 +02:00
parent 8376c00470
commit 69a6c8df29
2 changed files with 21 additions and 1 deletions

View file

@ -6,7 +6,7 @@
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Commented out not yet working code"> <list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Commented out not yet working code">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/FCLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" /> <change beforePath="$PROJECT_DIR$/src/main/java/LossFunctions.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/LossFunctions.java" afterDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />

View file

@ -14,4 +14,24 @@ public class LossFunctions {
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) { public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2); return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
} }
public static double MAE(SimpleMatrix y_true, SimpleMatrix y_pred) {
double sum = 0;
for (int i = 0; i < y_true.getNumElements(); i++) {
sum += Math.abs(y_true.get(i) - y_pred.get(i));
}
return sum / y_true.getNumElements();
}
public static SimpleMatrix MAEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
SimpleMatrix result = new SimpleMatrix(y_true);
for (int i = 0; i < result.getNumElements(); i++) {
if (y_true.get(i) < y_pred.get(i)) {
result.set(i, 1d);
} else {
result.set(i, -1d);
}
}
return result;
}
} }