diff --git a/src/main/java/BlankLayer.java b/src/main/java/BlankLayer.java index 2dcfd00..e444ecb 100644 --- a/src/main/java/BlankLayer.java +++ b/src/main/java/BlankLayer.java @@ -3,6 +3,7 @@ import org.ejml.simple.SimpleMatrix; import java.util.Random; /** + * Goal: initialize layer without any neurons. Not yet implemented. * Layer initialized with 1 neuron. * Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future). */ @@ -18,57 +19,6 @@ public class BlankLayer extends Layer { random.doubles(1, -1, 1).toArray()); } - /** - * Updates input size when previous layer has newly added neurons. - * @param n amount of new neurons in previous layer - */ - public void updateInputSize(int n) { - Random random = new Random(); - - // add new weights - SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols()); - for (int i = 0; i < this.weights.numRows(); i++) { - for (int j = 0; j < this.weights.numCols(); j++) { - newWeights.set(i, j, this.weights.get(i, j)); - } - } - for (int i = 0; i < newWeights.getNumElements(); i++) { - if (newWeights.get(i) == 0) { - newWeights.set(i, random.nextDouble(-1, 1)); - } - } - this.weights = newWeights; - } - - /** - * Adds new neurons at the end of the layer - * @param n amount how many new neurons should be added - */ - public void addNeuron(int n) { - Random random = new Random(); - - // add new weights - SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); - for (int i = 0; i < this.weights.numRows(); i++) { - for (int j = 0; j < this.weights.numCols(); j++) { - newWeights.set(i, j, this.weights.get(i, j)); - } - } - for (int i = 0; i < newWeights.getNumElements(); i++) { - if (newWeights.get(i) == 0) { - newWeights.set(i, random.nextDouble(-1, 1)); - } - } - this.weights = newWeights; - - // add new biases - SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n); - double[] newBiasValues = random.doubles(n, -1, 1).toArray(); - System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols()); - System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); - this.biases = newBiases; - } - @Override public SimpleMatrix forwardPropagation(SimpleMatrix inputs) { this.input = inputs; diff --git a/src/main/java/ExampleXORBlankLayers.java b/src/main/java/ExampleXORBlankLayers.java index 2dca672..253e712 100644 --- a/src/main/java/ExampleXORBlankLayers.java +++ b/src/main/java/ExampleXORBlankLayers.java @@ -12,9 +12,9 @@ public class ExampleXORBlankLayers { new SimpleMatrix(new double[][]{{0}})}; Network network = new Network(); - network.addLayer(new BlankLayer(2)); + network.addLayer(new FCLayer(2, 1)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); - network.addLayer(new BlankLayer(1)); + network.addLayer(new FCLayer(1, 1)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addNeuron(0, 2); diff --git a/src/main/java/FCLayer.java b/src/main/java/FCLayer.java index 952d3f6..e517ced 100644 --- a/src/main/java/FCLayer.java +++ b/src/main/java/FCLayer.java @@ -14,6 +14,57 @@ public class FCLayer extends Layer { random.doubles(outputSize, -1, 1).toArray()); } + /** + * Updates input size when previous layer has newly added neurons. + * @param n amount of new neurons in previous layer + */ + public void updateInputSize(int n) { + Random random = new Random(); + + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols()); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } + } + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } + } + this.weights = newWeights; + } + + /** + * Adds new neurons at the end of the layer + * @param n amount how many new neurons should be added + */ + public void addNeuron(int n) { + Random random = new Random(); + + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } + } + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } + } + this.weights = newWeights; + + // add new biases + SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n); + double[] newBiasValues = random.doubles(n, -1, 1).toArray(); + System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols()); + System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); + this.biases = newBiases; + } + @Override public SimpleMatrix forwardPropagation(SimpleMatrix inputs) { this.input = inputs; diff --git a/src/main/java/Network.java b/src/main/java/Network.java index 565a1c0..090a6ed 100644 --- a/src/main/java/Network.java +++ b/src/main/java/Network.java @@ -23,13 +23,13 @@ public class Network { * @param n amount how many new neurons should be added */ public void addNeuron(int layer, int n) { - if (!(this.layers.get(layer) instanceof BlankLayer)) { + if (!(this.layers.get(layer) instanceof FCLayer)) { System.out.println("This layer is not a BlankLayer"); - } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) { + } else if (!(this.layers.get(layer + 2) instanceof FCLayer)) { System.out.println("The next layer is not a BlankLayer"); } - ((BlankLayer) this.layers.get(layer)).addNeuron(n); - ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n); + ((FCLayer) this.layers.get(layer)).addNeuron(n); + ((FCLayer) this.layers.get(layer + 2)).updateInputSize(n); } public void use(BiFunction loss, BiFunction lossPrime) {