From faa547564c7daf3a9cd3f5e8a37fcfe9fdd6136d Mon Sep 17 00:00:00 2001
From: lluni
Date: Sat, 28 May 2022 03:19:49 +0200
Subject: [PATCH] Added support for choosing weight and bias initializers

---
 .../java/de/lluni/javann/Initializer.java     |  8 +++++++
 .../de/lluni/javann/examples/ExampleSine.java |  7 ++++---
 .../de/lluni/javann/examples/ExampleXOR.java  |  5 +++--
 .../examples/ExampleXORAddedNeurons.java      |  5 +++--
 .../java/de/lluni/javann/layers/FCLayer.java  | 21 ++++++++++++++++---
 5 files changed, 36 insertions(+), 10 deletions(-)
 create mode 100644 src/main/java/de/lluni/javann/Initializer.java

diff --git a/src/main/java/de/lluni/javann/Initializer.java b/src/main/java/de/lluni/javann/Initializer.java
new file mode 100644
index 0000000..2458500
--- /dev/null
+++ b/src/main/java/de/lluni/javann/Initializer.java
@@ -0,0 +1,8 @@
+package de.lluni.javann;
+
+public enum Initializer {
+    ZEROS,
+    ONES,
+    GAUSSIAN,
+    RANDOM
+}
diff --git a/src/main/java/de/lluni/javann/examples/ExampleSine.java b/src/main/java/de/lluni/javann/examples/ExampleSine.java
index 3dbddfb..c0701d2 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleSine.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleSine.java
@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -46,11 +47,11 @@ public class ExampleSine {
 
         // create network and add layers
         Network network = new Network();
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.GAUSSIAN, Initializer.ONES));
 
         // configure loss function for the network
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
diff --git a/src/main/java/de/lluni/javann/examples/ExampleXOR.java b/src/main/java/de/lluni/javann/examples/ExampleXOR.java
index 0e42079..f647673 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleXOR.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleXOR.java
@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -19,9 +20,9 @@ public class ExampleXOR {
             new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(3));
+        network.addLayer(new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
diff --git a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
index 20d1da2..b19ab58 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -19,9 +20,9 @@ public class ExampleXORAddedNeurons {
             new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.addNeuron(0, 2);
diff --git a/src/main/java/de/lluni/javann/layers/FCLayer.java b/src/main/java/de/lluni/javann/layers/FCLayer.java
index 2c78c5c..8d2617f 100644
--- a/src/main/java/de/lluni/javann/layers/FCLayer.java
+++ b/src/main/java/de/lluni/javann/layers/FCLayer.java
@@ -1,5 +1,6 @@
 package de.lluni.javann.layers;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.util.Utilities;
 import org.ejml.simple.SimpleMatrix;
 
@@ -9,6 +10,8 @@ import java.util.Random;
  * Fully connected layer with n Neurons
 */
 public class FCLayer extends Layer {
+    private final Initializer weightInit;
+    private final Initializer biasInit;
     private SimpleMatrix weights;
     private SimpleMatrix biases;
     private int numNeurons;
@@ -18,14 +21,26 @@ public class FCLayer extends Layer {
      * Creates a fully connected layer with numNeurons neurons
      * @param numNeurons amount of neurons in this layer
      */
-    public FCLayer(int numNeurons) {
+    public FCLayer(int numNeurons, Initializer weightInit, Initializer biasInit) {
         this.numNeurons = numNeurons;
+        this.weightInit = weightInit;
+        this.biasInit = biasInit;
         isInitialized = false;
     }
 
     private void initialize(int inputSize) {
-        this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
-        this.biases = Utilities.ones(1, numNeurons);
+        switch (weightInit) {
+            case ZEROS -> this.weights = new SimpleMatrix(inputSize, numNeurons);
+            case ONES -> this.weights = Utilities.ones(inputSize, numNeurons);
+            case GAUSSIAN -> this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.weights = Utilities.randomMatrix(inputSize, numNeurons);
+        }
+        switch (biasInit) {
+            case ZEROS -> this.biases = new SimpleMatrix(1, numNeurons);
+            case ONES -> this.biases = Utilities.ones(1, numNeurons);
+            case GAUSSIAN -> this.biases = Utilities.gaussianMatrix(1, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.biases = Utilities.randomMatrix(1, numNeurons);
+        }
         this.isInitialized = true;
     }
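
For reference, a minimal usage sketch of the new FCLayer constructor introduced by this patch. It assumes ActivationLayer lives in de.lluni.javann.layers alongside FCLayer, and the class name ExampleInitializers plus the particular initializer pairing are illustrative only, not taken from the patch. Per the diff, ZEROS and ONES fill the matrix with constants, GAUSSIAN delegates to Utilities.gaussianMatrix (mean 0, stddev 1, scaled by 0.1), and RANDOM delegates to Utilities.randomMatrix (whatever distribution that helper implements).

    package de.lluni.javann.examples;

    import de.lluni.javann.Initializer;
    import de.lluni.javann.Network;
    import de.lluni.javann.functions.ActivationFunctions;
    import de.lluni.javann.layers.ActivationLayer;
    import de.lluni.javann.layers.FCLayer;

    public class ExampleInitializers {
        public static void main(String[] args) {
            Network network = new Network();
            // hidden layer: small Gaussian weights, zero biases (hypothetical pairing)
            network.addLayer(new FCLayer(4, Initializer.GAUSSIAN, Initializer.ZEROS));
            network.addLayer(new ActivationLayer(ActivationFunctions::tanh,
                    ActivationFunctions::tanhPrime));
            // output layer: weights and biases both from Utilities.randomMatrix
            network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
            network.addLayer(new ActivationLayer(ActivationFunctions::tanh,
                    ActivationFunctions::tanhPrime));
        }
    }

Note that weightInit and biasInit are final fields consumed lazily: the private initialize(inputSize) runs only once isInitialized flips, presumably on the first forward pass when the input size becomes known, so the chosen Initializer takes effect then rather than at construction time.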