diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 79fa19a..c8d666d 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -4,14 +4,8 @@
-
-
-
-
-
-
-
-
+
+
@@ -103,7 +97,14 @@
1653168727186
-
+
+ 1653176561735
+
+
+
+ 1653176561735
+
+
@@ -120,7 +121,8 @@
-
+
+
diff --git a/src/main/java/ActivationFunctions.java b/src/main/java/ActivationFunctions.java
index b073eb0..ae5e7c8 100644
--- a/src/main/java/ActivationFunctions.java
+++ b/src/main/java/ActivationFunctions.java
@@ -36,4 +36,20 @@ public class ActivationFunctions {
private static double sigma(double value) {
return 1 / (1 + Math.exp(-value));
}
+
+ public static SimpleMatrix ReLu(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, Math.max(0, A.get(i)));
+ }
+ return B;
+ }
+
+ public static SimpleMatrix ReLuPrime(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, A.get(i) < 0 ? 0 : 1);
+ }
+ return B;
+ }
}
diff --git a/src/main/java/FCLayer.java b/src/main/java/FCLayer.java
index e25a8c6..952d3f6 100644
--- a/src/main/java/FCLayer.java
+++ b/src/main/java/FCLayer.java
@@ -9,9 +9,9 @@ public class FCLayer extends Layer {
public FCLayer(int inputSize, int outputSize) {
Random random = new Random();
weights = new SimpleMatrix(inputSize, outputSize, true,
- random.doubles((long) inputSize*outputSize, -0.5, 0.5).toArray());
+ random.doubles((long) inputSize*outputSize, -1, 1).toArray());
biases = new SimpleMatrix(1, outputSize, true,
- random.doubles(outputSize, -0.5, 0.5).toArray());
+ random.doubles(outputSize, -1, 1).toArray());
}
@Override
diff --git a/src/main/java/LossFunctions.java b/src/main/java/LossFunctions.java
index c59193d..26d735a 100644
--- a/src/main/java/LossFunctions.java
+++ b/src/main/java/LossFunctions.java
@@ -14,4 +14,24 @@ public class LossFunctions {
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
}
+
+ public static double MAE(SimpleMatrix y_true, SimpleMatrix y_pred) {
+ double sum = 0;
+ for (int i = 0; i < y_true.getNumElements(); i++) {
+ sum += Math.abs(y_true.get(i) - y_pred.get(i));
+ }
+ return sum / y_true.getNumElements();
+ }
+
+ public static SimpleMatrix MAEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
+ SimpleMatrix result = new SimpleMatrix(y_true);
+ for (int i = 0; i < result.getNumElements(); i++) {
+ if (y_true.get(i) < y_pred.get(i)) {
+ result.set(i, 1d);
+ } else {
+ result.set(i, -1d);
+ }
+ }
+ return result;
+ }
}