diff --git a/src/main/java/ActivationFunctions.java b/src/main/java/ActivationFunctions.java
new file mode 100644
index 0000000..b073eb0
--- /dev/null
+++ b/src/main/java/ActivationFunctions.java
@@ -0,0 +1,39 @@
+import org.ejml.simple.SimpleMatrix;
+
+public class ActivationFunctions {
+ public static SimpleMatrix tanh(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, Math.tanh(A.get(i)));
+ }
+ return B;
+ }
+
+ public static SimpleMatrix tanhPrime(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
+ }
+ return B;
+ }
+
+ public static SimpleMatrix logistic(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, sigma(A.get(i)));
+ }
+ return B;
+ }
+
+ public static SimpleMatrix logisticPrime(SimpleMatrix A) {
+ SimpleMatrix B = new SimpleMatrix(A);
+ for (int i = 0; i < A.getNumElements(); i++) {
+ B.set(i, sigma(A.get(i)) * (1 - sigma(A.get(i))));
+ }
+ return B;
+ }
+
+ private static double sigma(double value) {
+ return 1 / (1 + Math.exp(-value));
+ }
+}
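
A minimal sanity check for the new helpers (the ActivationFunctionsCheck class is hypothetical, not part of this change; assumes EJML's SimpleMatrix on the classpath and its default toString):

```java
import org.ejml.simple.SimpleMatrix;

// Hypothetical sanity check for ActivationFunctions; not part of this diff.
public class ActivationFunctionsCheck {
    public static void main(String[] args) {
        SimpleMatrix A = new SimpleMatrix(new double[][]{{-1.0, 0.0, 1.0}});

        // tanh is odd and tanh(0) = 0
        System.out.println(ActivationFunctions.tanh(A));          // ~[-0.7616, 0.0, 0.7616]

        // logistic(0) = 0.5; with the sign fix above it is increasing
        System.out.println(ActivationFunctions.logistic(A));      // ~[0.2689, 0.5, 0.7311]

        // the logistic derivative peaks at 0 with value 0.25
        System.out.println(ActivationFunctions.logisticPrime(A)); // ~[0.1966, 0.25, 0.1966]
    }
}
```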
diff --git a/src/main/java/ActivationLayer.java b/src/main/java/ActivationLayer.java
index e9598ca..0e42421 100644
--- a/src/main/java/ActivationLayer.java
+++ b/src/main/java/ActivationLayer.java
@@ -1,7 +1,8 @@
import org.ejml.simple.SimpleMatrix;
+import java.util.function.Function;
+
public class ActivationLayer extends Layer {
- /* custom activation functions not yet supported
Function<SimpleMatrix, SimpleMatrix> activation;
Function<SimpleMatrix, SimpleMatrix> activationPrime;
@@ -9,35 +10,16 @@ public class ActivationLayer extends Layer {
this.activation = activation;
this.activationPrime = activationPrime;
}
- */
-
- public ActivationLayer() {}
@Override
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
this.input = input;
- this.output = tanh(input);
+ this.output = activation.apply(input);
return this.output;
}
@Override
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
- return tanhPrime(this.input).elementMult(outputError);
- }
-
- private SimpleMatrix tanh(SimpleMatrix A) {
- SimpleMatrix B = new SimpleMatrix(A);
- for (int i = 0; i < A.getNumElements(); i++) {
- B.set(i, Math.tanh(A.get(i)));
- }
- return B;
- }
-
- private SimpleMatrix tanhPrime(SimpleMatrix A) {
- SimpleMatrix B = new SimpleMatrix(A);
- for (int i = 0; i < A.getNumElements(); i++) {
- B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
- }
- return B;
+ return activationPrime.apply(this.input).elementMult(outputError);
}
}
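
With the constructor now taking arbitrary function pairs, custom activations no longer require editing the layer. A sketch under that assumption (the ReLU pair and ReluExample class below are hypothetical, not part of this change):

```java
import org.ejml.simple.SimpleMatrix;
import java.util.function.Function;

// Hypothetical: a custom activation wired through the new constructor.
public class ReluExample {
    public static void main(String[] args) {
        Function<SimpleMatrix, SimpleMatrix> relu = A -> {
            SimpleMatrix B = new SimpleMatrix(A);
            for (int i = 0; i < A.getNumElements(); i++) {
                B.set(i, Math.max(0.0, A.get(i)));
            }
            return B;
        };
        Function<SimpleMatrix, SimpleMatrix> reluPrime = A -> {
            SimpleMatrix B = new SimpleMatrix(A);
            for (int i = 0; i < A.getNumElements(); i++) {
                B.set(i, A.get(i) > 0 ? 1.0 : 0.0); // subgradient 0 taken at x == 0
            }
            return B;
        };

        ActivationLayer layer = new ActivationLayer(relu, reluPrime);
        SimpleMatrix out = layer.forwardPropagation(new SimpleMatrix(new double[][]{{-2.0, 3.0}}));
        System.out.println(out); // [0.0, 3.0]
    }
}
```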
diff --git a/src/main/java/ExampleXOR.java b/src/main/java/ExampleXOR.java
index 445e0b9..c08dde3 100644
--- a/src/main/java/ExampleXOR.java
+++ b/src/main/java/ExampleXOR.java
@@ -1,7 +1,5 @@
import org.ejml.simple.SimpleMatrix;
-import java.util.ArrayList;
-
public class ExampleXOR {
public static void main(String[] args) {
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
@@ -15,13 +13,14 @@ public class ExampleXOR {
Network network = new Network();
network.add(new FCLayer(2, 3));
- network.add(new ActivationLayer());
+ network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.add(new FCLayer(3, 1));
- network.add(new ActivationLayer());
+ network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
+ network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
network.fit(X_train, y_train, 1000, 0.1d);
- ArrayList<SimpleMatrix> output = network.predict(X_train);
+ SimpleMatrix[] output = network.predict(X_train);
for (SimpleMatrix entry : output) {
System.out.println("Prediction:");
for (int i = 0; i < entry.getNumElements(); i++) {
diff --git a/src/main/java/LossFunctions.java b/src/main/java/LossFunctions.java
new file mode 100644
index 0000000..c59193d
--- /dev/null
+++ b/src/main/java/LossFunctions.java
@@ -0,0 +1,17 @@
+import org.ejml.simple.SimpleMatrix;
+
+public class LossFunctions {
+ public static double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
+ SimpleMatrix temp = y_true.minus(y_pred);
+ temp = temp.elementMult(temp);
+ double sum = 0;
+ for (int i = 0; i < temp.getNumElements(); i++) {
+ sum += temp.get(i);
+ }
+ return sum / temp.getNumElements();
+ }
+
+ public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
+ return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
+ }
+}
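
These are the same helpers previously private to Network (see the removals below), now exposed for injection. A quick worked check (the LossFunctionsCheck class is hypothetical): for y_true = [1, 0] and y_pred = [0.9, 0.2], MSE = (0.1^2 + 0.2^2) / 2 = 0.025.

```java
import org.ejml.simple.SimpleMatrix;

// Hypothetical check: MSE([1, 0], [0.9, 0.2]) = (0.01 + 0.04) / 2 = 0.025
public class LossFunctionsCheck {
    public static void main(String[] args) {
        SimpleMatrix yTrue = new SimpleMatrix(new double[][]{{1.0, 0.0}});
        SimpleMatrix yPred = new SimpleMatrix(new double[][]{{0.9, 0.2}});
        System.out.println(LossFunctions.MSE(yTrue, yPred)); // 0.025 (up to rounding)
    }
}
```

Note that MSEPrime keeps the old private helper's 2/n * (y_true - y_pred) sign convention, so moving it out of Network does not change training behavior.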
diff --git a/src/main/java/Network.java b/src/main/java/Network.java
index 404483b..e330de2 100644
--- a/src/main/java/Network.java
+++ b/src/main/java/Network.java
@@ -1,20 +1,13 @@
import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
+import java.util.function.BiFunction;
public class Network {
ArrayList<Layer> layers;
-
- /* custom loss functions not yet supported
- Function loss;
- Function lossPrime;
-
- public void use(Function loss, Function lossPrime) {
- this.loss = loss;
- this.lossPrime = lossPrime;
- }
- */
+ BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
+ BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
public Network() {
layers = new ArrayList<>();
@@ -24,16 +17,23 @@ public class Network {
layers.add(layer);
}
- public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
- ArrayList<SimpleMatrix> result = new ArrayList<>();
+ public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
+ this.loss = loss;
+ this.lossPrime = lossPrime;
+ }
+
+ public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
+ SimpleMatrix[] result = new SimpleMatrix[inputs.length];
SimpleMatrix output;
+ int i = 0;
for (SimpleMatrix input : inputs) {
output = input;
for (Layer l : layers) {
output = l.forwardPropagation(output);
}
- result.add(output);
+ result[i] = output;
+ i++;
}
return result;
@@ -52,10 +52,10 @@ public class Network {
}
// compute loss (for display purpose only)
- err = MSE(y_train[j], output);
+ err = loss.apply(y_train[j], output);
// backward propagation
- SimpleMatrix error = MSEPrime(y_train[j], output);
+ SimpleMatrix error = lossPrime.apply(y_train[j], output);
for (int k = layers.size() - 1; k >= 0; k--) {
error = layers.get(k).backwardPropagation(error, learningRate);
}
@@ -65,18 +65,4 @@ public class Network {
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
}
}
-
- private double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
- SimpleMatrix temp = y_true.minus(y_pred);
- temp = temp.elementMult(temp);
- double sum = 0;
- for (int i = 0; i < temp.getNumElements(); i++) {
- sum += temp.get(i);
- }
- return sum / temp.getNumElements();
- }
-
- private SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
- return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
- }
}
\ No newline at end of file
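
The same injection pattern now applies to losses. A sketch of a custom loss wired through use(...) (the MAE pair and MAEExample class below are hypothetical; they mirror MSEPrime's (y_true - y_pred) sign convention so the backward pass stays consistent):

```java
import org.ejml.simple.SimpleMatrix;

// Hypothetical: mean absolute error plugged into Network.use(...).
public class MAEExample {
    public static double MAE(SimpleMatrix y_true, SimpleMatrix y_pred) {
        double sum = 0;
        for (int i = 0; i < y_true.getNumElements(); i++) {
            sum += Math.abs(y_true.get(i) - y_pred.get(i));
        }
        return sum / y_true.getNumElements();
    }

    // Same (y_true - y_pred) sign convention as MSEPrime.
    public static SimpleMatrix MAEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
        SimpleMatrix grad = new SimpleMatrix(y_true);
        for (int i = 0; i < y_true.getNumElements(); i++) {
            grad.set(i, Math.signum(y_true.get(i) - y_pred.get(i)) / y_true.getNumElements());
        }
        return grad;
    }

    public static void main(String[] args) {
        Network network = new Network();
        network.add(new FCLayer(2, 3));
        network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.add(new FCLayer(3, 1));
        network.use(MAEExample::MAE, MAEExample::MAEPrime);
        // network.fit(...) and network.predict(...) as in ExampleXOR
    }
}
```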