From 1c66f1b72feba721e1f34bf9fcfaaa6d9f6db038 Mon Sep 17 00:00:00 2001 From: lluni Date: Wed, 25 May 2022 00:41:38 +0200 Subject: [PATCH] The inputSize of each layer does not need to be specified anymore --- src/main/java/ExampleXOR.java | 4 +- src/main/java/ExampleXORBlankLayers.java | 4 +- src/main/java/FCLayer.java | 93 ++++++++++++++---------- 3 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/main/java/ExampleXOR.java b/src/main/java/ExampleXOR.java index 8cb612d..b6f9ad2 100644 --- a/src/main/java/ExampleXOR.java +++ b/src/main/java/ExampleXOR.java @@ -12,9 +12,9 @@ public class ExampleXOR { new SimpleMatrix(new double[][]{{0}})}; Network network = new Network(); - network.addLayer(new FCLayer(2, 3)); + network.addLayer(new FCLayer(3)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); - network.addLayer(new FCLayer(3, 1)); + network.addLayer(new FCLayer(1)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.use(LossFunctions::MSE, LossFunctions::MSEPrime); diff --git a/src/main/java/ExampleXORBlankLayers.java b/src/main/java/ExampleXORBlankLayers.java index 253e712..5cc9c6c 100644 --- a/src/main/java/ExampleXORBlankLayers.java +++ b/src/main/java/ExampleXORBlankLayers.java @@ -12,9 +12,9 @@ public class ExampleXORBlankLayers { new SimpleMatrix(new double[][]{{0}})}; Network network = new Network(); - network.addLayer(new FCLayer(2, 1)); + network.addLayer(new FCLayer(1)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); - network.addLayer(new FCLayer(1, 1)); + network.addLayer(new FCLayer(1)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addNeuron(0, 2); diff --git a/src/main/java/FCLayer.java b/src/main/java/FCLayer.java index e517ced..038d7a9 100644 --- a/src/main/java/FCLayer.java +++ b/src/main/java/FCLayer.java @@ -3,15 +3,23 @@ import org.ejml.simple.SimpleMatrix; import java.util.Random; public class FCLayer extends Layer { - SimpleMatrix weights; - SimpleMatrix biases; + private SimpleMatrix weights; + private SimpleMatrix biases; + private int numNeurons; + private boolean isInitialized; - public FCLayer(int inputSize, int outputSize) { + public FCLayer(int numNeurons) { + this.numNeurons = numNeurons; + isInitialized = false; + } + + private void initialize(int inputSize) { Random random = new Random(); - weights = new SimpleMatrix(inputSize, outputSize, true, - random.doubles((long) inputSize*outputSize, -1, 1).toArray()); - biases = new SimpleMatrix(1, outputSize, true, - random.doubles(outputSize, -1, 1).toArray()); + this.weights = new SimpleMatrix(inputSize, numNeurons, true, + random.doubles((long) inputSize*numNeurons, -1, 1).toArray()); + this.biases = new SimpleMatrix(1, numNeurons, true, + random.doubles(numNeurons, -1, 1).toArray()); + this.isInitialized = true; } /** @@ -19,21 +27,23 @@ public class FCLayer extends Layer { * @param n amount of new neurons in previous layer */ public void updateInputSize(int n) { - Random random = new Random(); + if (isInitialized) { + Random random = new Random(); - // add new weights - SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols()); - for (int i = 0; i < this.weights.numRows(); i++) { - for (int j = 0; j < this.weights.numCols(); j++) { - newWeights.set(i, j, this.weights.get(i, j)); + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols()); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } } - } - for (int i = 0; i < newWeights.getNumElements(); i++) { - if (newWeights.get(i) == 0) { - newWeights.set(i, random.nextDouble(-1, 1)); + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } } + this.weights = newWeights; } - this.weights = newWeights; } /** @@ -43,30 +53,39 @@ public class FCLayer extends Layer { public void addNeuron(int n) { Random random = new Random(); - // add new weights - SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); - for (int i = 0; i < this.weights.numRows(); i++) { - for (int j = 0; j < this.weights.numCols(); j++) { - newWeights.set(i, j, this.weights.get(i, j)); - } - } - for (int i = 0; i < newWeights.getNumElements(); i++) { - if (newWeights.get(i) == 0) { - newWeights.set(i, random.nextDouble(-1, 1)); - } - } - this.weights = newWeights; + // update neuron count + this.numNeurons += n; - // add new biases - SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n); - double[] newBiasValues = random.doubles(n, -1, 1).toArray(); - System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols()); - System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); - this.biases = newBiases; + if (isInitialized) { + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } + } + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } + } + this.weights = newWeights; + + // add new biases + SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n); + double[] newBiasValues = random.doubles(n, -1, 1).toArray(); + System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols()); + System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); + this.biases = newBiases; + } } @Override public SimpleMatrix forwardPropagation(SimpleMatrix inputs) { + if (!isInitialized) { + initialize(inputs.numCols()); + } + this.input = inputs; this.output = this.input.mult(this.weights).plus(this.biases); return this.output;