Added support for choosing weight and bias initializers

parent e7de925373
commit faa547564c

5 changed files with 36 additions and 10 deletions
src/main/java/de/lluni/javann/Initializer.java (new file, 8 additions)

@@ -0,0 +1,8 @@
+package de.lluni.javann;
+
+public enum Initializer {
+    ZEROS,
+    ONES,
+    GAUSSIAN,
+    RANDOM
+}
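Each constant selects one branch of the new switch in FCLayer.initialize() (last file in this diff): ZEROS and ONES produce constant matrices, GAUSSIAN draws from Utilities.gaussianMatrix, and RANDOM from Utilities.randomMatrix. A minimal hypothetical call site, mirroring the updated examples below:

    // Hypothetical usage of the constructor introduced in this commit:
    // Gaussian-initialized weights, biases initialized to ones.
    network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));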
src/main/java/de/lluni/javann/examples/ExampleSine.java

@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -46,11 +47,11 @@ public class ExampleSine {
 
         // create network and add layers
         Network network = new Network();
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.GAUSSIAN, Initializer.ONES));
 
         // configure loss function for the network
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
src/main/java/de/lluni/javann/examples/ExampleXOR.java

@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -19,9 +20,9 @@ public class ExampleXOR {
             new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(3));
+        network.addLayer(new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java

@@ -1,5 +1,6 @@
 package de.lluni.javann.examples;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -19,9 +20,9 @@ public class ExampleXORAddedNeurons {
             new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
         network.addNeuron(0, 2);
 
src/main/java/de/lluni/javann/layers/FCLayer.java

@@ -1,5 +1,6 @@
 package de.lluni.javann.layers;
 
+import de.lluni.javann.Initializer;
 import de.lluni.javann.util.Utilities;
 import org.ejml.simple.SimpleMatrix;
 
@@ -9,6 +10,8 @@ import java.util.Random;
  * Fully connected layer with n Neurons
  */
 public class FCLayer extends Layer {
+    private final Initializer weightInit;
+    private final Initializer biasInit;
     private SimpleMatrix weights;
     private SimpleMatrix biases;
     private int numNeurons;
@@ -18,14 +21,26 @@ public class FCLayer extends Layer {
      * Creates a fully connected layer with numNeurons neurons
      * @param numNeurons amount of neurons in this layer
      */
-    public FCLayer(int numNeurons) {
+    public FCLayer(int numNeurons, Initializer weightInit, Initializer biasInit) {
         this.numNeurons = numNeurons;
+        this.weightInit = weightInit;
+        this.biasInit = biasInit;
         isInitialized = false;
     }
 
     private void initialize(int inputSize) {
-        this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
-        this.biases = Utilities.ones(1, numNeurons);
+        switch (weightInit) {
+            case ZEROS -> this.weights = new SimpleMatrix(inputSize, numNeurons);
+            case ONES -> this.weights = Utilities.ones(inputSize, numNeurons);
+            case GAUSSIAN -> this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.weights = Utilities.randomMatrix(inputSize, numNeurons);
+        }
+        switch (biasInit) {
+            case ZEROS -> this.biases = new SimpleMatrix(1, numNeurons);
+            case ONES -> this.biases = Utilities.ones(1, numNeurons);
+            case GAUSSIAN -> this.biases = Utilities.gaussianMatrix(1, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.biases = Utilities.randomMatrix(1, numNeurons);
+        }
         this.isInitialized = true;
     }
 
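Because the new constructor replaces FCLayer(int), every existing call site now has to pass both initializers, which is why all three examples were updated in the same commit. A hypothetical convenience overload (not part of this commit) could keep old call sites compiling, and would also sidestep the fact that the switches above have no default branch:

    /**
     * Hypothetical overload, not in this commit: restores the previous
     * defaults (Gaussian weights, all-ones biases).
     */
    public FCLayer(int numNeurons) {
        this(numNeurons, Initializer.GAUSSIAN, Initializer.ONES);
    }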