diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 5d270db..7688800 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -5,24 +5,11 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
@@ -53,14 +40,14 @@
- {
+ "keyToString": {
+ "SHARE_PROJECT_CONFIGURATION_FILES": "true",
+ "project.structure.last.edited": "Project",
+ "project.structure.proportion": "0.15",
+ "project.structure.side.proportion": "0.2"
}
-}]]>
+}
diff --git a/src/main/java/ActivationLayer.java b/src/main/java/ActivationLayer.java
index 1e106d0..e9598ca 100644
--- a/src/main/java/ActivationLayer.java
+++ b/src/main/java/ActivationLayer.java
@@ -1,17 +1,17 @@
import org.ejml.simple.SimpleMatrix;
-import java.util.function.Function;
-
public class ActivationLayer extends Layer {
+ /* custom activation functions not yet supported
Function activation;
Function activationPrime;
- public ActivationLayer() {}
-
public ActivationLayer(Function activation, Function activationPrime) {
this.activation = activation;
this.activationPrime = activationPrime;
}
+ */
+
+ public ActivationLayer() {}
@Override
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
@@ -21,7 +21,7 @@ public class ActivationLayer extends Layer {
}
@Override
- public SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate) {
+ public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
return tanhPrime(this.input).elementMult(outputError);
}
diff --git a/src/main/java/FCLayer.java b/src/main/java/FCLayer.java
index 6e07284..e25a8c6 100644
--- a/src/main/java/FCLayer.java
+++ b/src/main/java/FCLayer.java
@@ -22,7 +22,7 @@ public class FCLayer extends Layer {
}
@Override
- public SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate) {
+ public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
SimpleMatrix inputError = outputError.mult(this.weights.transpose());
SimpleMatrix weightsError = this.input.transpose().mult(outputError);
diff --git a/src/main/java/Layer.java b/src/main/java/Layer.java
index 16f5d7e..710c293 100644
--- a/src/main/java/Layer.java
+++ b/src/main/java/Layer.java
@@ -6,5 +6,5 @@ public abstract class Layer {
public abstract SimpleMatrix forwardPropagation(SimpleMatrix inputs);
- public abstract SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate);
+ public abstract SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate);
}
diff --git a/src/main/java/Network.java b/src/main/java/Network.java
index b8df682..404483b 100644
--- a/src/main/java/Network.java
+++ b/src/main/java/Network.java
@@ -1,14 +1,21 @@
import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
-import java.util.function.Function;
public class Network {
ArrayList layers;
+
+ /* custom loss functions not yet supported
Function loss;
Function lossPrime;
+ public void use(Function loss, Function lossPrime) {
+ this.loss = loss;
+ this.lossPrime = lossPrime;
+ }
+ */
+
public Network() {
layers = new ArrayList<>();
}
@@ -17,13 +24,7 @@ public class Network {
layers.add(layer);
}
- public void use(Function loss, Function lossPrime) {
- this.loss = loss;
- this.lossPrime = lossPrime;
- }
-
public ArrayList predict(SimpleMatrix[] inputs) {
- int samples = inputs.length;
ArrayList result = new ArrayList<>();
SimpleMatrix output;