diff --git a/src/main/java/BlankLayer.java b/src/main/java/BlankLayer.java new file mode 100644 index 0000000..2dcfd00 --- /dev/null +++ b/src/main/java/BlankLayer.java @@ -0,0 +1,88 @@ +import org.ejml.simple.SimpleMatrix; + +import java.util.Random; + +/** + * Layer initialized with 1 neuron. + * Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future). + */ +public class BlankLayer extends Layer { + SimpleMatrix weights; + SimpleMatrix biases; + + public BlankLayer(int inputSize) { + Random random = new Random(); + this.weights = new SimpleMatrix(inputSize, 1, true, + random.doubles(inputSize, -1, 1).toArray()); + this.biases = new SimpleMatrix(1, 1, true, + random.doubles(1, -1, 1).toArray()); + } + + /** + * Updates input size when previous layer has newly added neurons. + * @param n amount of new neurons in previous layer + */ + public void updateInputSize(int n) { + Random random = new Random(); + + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols()); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } + } + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } + } + this.weights = newWeights; + } + + /** + * Adds new neurons at the end of the layer + * @param n amount how many new neurons should be added + */ + public void addNeuron(int n) { + Random random = new Random(); + + // add new weights + SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); + for (int i = 0; i < this.weights.numRows(); i++) { + for (int j = 0; j < this.weights.numCols(); j++) { + newWeights.set(i, j, this.weights.get(i, j)); + } + } + for (int i = 0; i < newWeights.getNumElements(); i++) { + if (newWeights.get(i) == 0) { + newWeights.set(i, random.nextDouble(-1, 1)); + } + } + this.weights = newWeights; + + // add new biases + SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n); + double[] newBiasValues = random.doubles(n, -1, 1).toArray(); + System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols()); + System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); + this.biases = newBiases; + } + + @Override + public SimpleMatrix forwardPropagation(SimpleMatrix inputs) { + this.input = inputs; + this.output = this.input.mult(this.weights).plus(this.biases); + return this.output; + } + + @Override + public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) { + SimpleMatrix inputError = outputError.mult(this.weights.transpose()); + SimpleMatrix weightsError = this.input.transpose().mult(outputError); + + this.weights = this.weights.plus(learningRate, weightsError); + this.biases = this.biases.plus(learningRate, outputError); + return inputError; + } +} diff --git a/src/main/java/ExampleXOR.java b/src/main/java/ExampleXOR.java index c08dde3..8cb612d 100644 --- a/src/main/java/ExampleXOR.java +++ b/src/main/java/ExampleXOR.java @@ -12,10 +12,10 @@ public class ExampleXOR { new SimpleMatrix(new double[][]{{0}})}; Network network = new Network(); - network.add(new FCLayer(2, 3)); - network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); - network.add(new FCLayer(3, 1)); - network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); + network.addLayer(new FCLayer(2, 3)); + network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); + network.addLayer(new FCLayer(3, 1)); + network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.use(LossFunctions::MSE, LossFunctions::MSEPrime); network.fit(X_train, y_train, 1000, 0.1d); diff --git a/src/main/java/Network.java b/src/main/java/Network.java index e330de2..565a1c0 100644 --- a/src/main/java/Network.java +++ b/src/main/java/Network.java @@ -4,19 +4,34 @@ import java.util.ArrayList; import java.util.function.BiFunction; public class Network { - - ArrayList layers; - BiFunction loss; - BiFunction lossPrime; + private ArrayList layers; + private BiFunction loss; + private BiFunction lossPrime; public Network() { layers = new ArrayList<>(); } - public void add(Layer layer) { + public void addLayer(Layer layer) { layers.add(layer); } + /** + * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases. + * Only works if there are two successive BlankLayers. + * @param layer index of layer in the ArrayList layers + * @param n amount how many new neurons should be added + */ + public void addNeuron(int layer, int n) { + if (!(this.layers.get(layer) instanceof BlankLayer)) { + System.out.println("This layer is not a BlankLayer"); + } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) { + System.out.println("The next layer is not a BlankLayer"); + } + ((BlankLayer) this.layers.get(layer)).addNeuron(n); + ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n); + } + public void use(BiFunction loss, BiFunction lossPrime) { this.loss = loss; this.lossPrime = lossPrime; @@ -65,4 +80,28 @@ public class Network { System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err); } } + + public ArrayList getLayers() { + return layers; + } + + public void setLayers(ArrayList layers) { + this.layers = layers; + } + + public BiFunction getLoss() { + return loss; + } + + public void setLoss(BiFunction loss) { + this.loss = loss; + } + + public BiFunction getLossPrime() { + return lossPrime; + } + + public void setLossPrime(BiFunction lossPrime) { + this.lossPrime = lossPrime; + } } \ No newline at end of file