diff --git a/src/main/java/de/lluni/javann/Initializer.java b/src/main/java/de/lluni/javann/Initializer.java
deleted file mode 100644
index 2458500..0000000
--- a/src/main/java/de/lluni/javann/Initializer.java
+++ /dev/null
@@ -1,8 +0,0 @@
-package de.lluni.javann;
-
-public enum Initializer {
-    ZEROS,
-    ONES,
-    GAUSSIAN,
-    RANDOM
-}
diff --git a/src/main/java/de/lluni/javann/Network.java b/src/main/java/de/lluni/javann/Network.java
index d62648d..f03555b 100644
--- a/src/main/java/de/lluni/javann/Network.java
+++ b/src/main/java/de/lluni/javann/Network.java
@@ -79,9 +79,8 @@ public class Network {
      * @param y_train labels
      * @param epochs amount of training iterations
      * @param learningRate step size of gradient descent
-     * @param optimize if step size should decrease for each subsequent epoch
      */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
        int samples = X_train.length;

        for (int i = 0; i < epochs; i++) {
@@ -99,11 +98,7 @@ public class Network {
                // backward propagation
                SimpleMatrix error = lossPrime.apply(y_train[j], output);
                for (int k = layers.size() - 1; k >= 0; k--) {
-                    if (optimize) {
-                        error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
-                    } else {
-                        error = layers.get(k).backwardPropagation(error, learningRate);
-                    }
+                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
                }
            }
            // calculate average error on all samples
diff --git a/src/main/java/de/lluni/javann/examples/ExampleSine.java b/src/main/java/de/lluni/javann/examples/ExampleSine.java
index 1572b0a..3dbddfb 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleSine.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleSine.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;

-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -47,17 +46,17 @@ public class ExampleSine {

        // create network and add layers
        Network network = new Network();
-        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(8));
        network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(8));
        network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(1, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(1));

        // configure loss function for the network
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);

        // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d, true);
+        network.fit(X_train, y_train, 100, 0.05d);

        // predict X_test and output results to console
        SimpleMatrix[] output = network.predict(X_test);
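Note for reviewers: with the optimize flag removed, the 1/(i+1) decay that used to be opt-in is now applied unconditionally, so every caller of fit gets a step size that shrinks each epoch. A minimal sketch of the resulting schedule (plain Java; the class and variable names here are illustrative, not part of the library):

// Illustrates the per-epoch step size produced by learningRate / (i+1),
// the expression Network.fit now always uses during backpropagation.
public class DecayScheduleSketch {
    public static void main(String[] args) {
        int epochs = 5;
        double learningRate = 0.1d; // same base rate as ExampleXOR below
        for (int i = 0; i < epochs; i++) {
            double effectiveRate = learningRate / (i + 1);
            System.out.printf("epoch %d: effective learning rate %.4f%n", i, effectiveRate);
        }
    }
}

For ExampleXOR, which previously passed optimize = false, this is a behavior change: its final epoch now trains with 0.1/1000 instead of a constant 0.1.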
diff --git a/src/main/java/de/lluni/javann/examples/ExampleXOR.java b/src/main/java/de/lluni/javann/examples/ExampleXOR.java
index 7f4d56f..0e42079 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleXOR.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleXOR.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;

-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -20,13 +19,13 @@ public class ExampleXOR {
            new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
-        network.addLayer(new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(3));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d, false);
+        network.fit(X_train, y_train, 1000, 0.1d);

        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {
diff --git a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
index af8c31a..20d1da2 100644
--- a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
+++ b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;

-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -20,14 +19,14 @@ public class ExampleXORAddedNeurons {
            new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.addNeuron(0, 2);

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d, false);
+        network.fit(X_train, y_train, 1000, 0.1d);

        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {
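For downstream code, the migration is mechanical: drop the two Initializer arguments from every FCLayer constructor call and the trailing boolean from fit. A self-contained sketch against the updated API, using only calls that appear in this diff (the class name is hypothetical, the training data mirrors ExampleXOR, and the layer import paths are assumed from the repository layout):

package de.lluni.javann.examples;

import de.lluni.javann.Network;
import de.lluni.javann.functions.ActivationFunctions;
import de.lluni.javann.functions.LossFunctions;
import de.lluni.javann.layers.ActivationLayer;
import de.lluni.javann.layers.FCLayer;
import org.ejml.simple.SimpleMatrix;

public class MigrationSketch {
    public static void main(String[] args) {
        // XOR training data, one row vector per sample
        SimpleMatrix[] X_train = {
                new SimpleMatrix(new double[][]{{0, 0}}),
                new SimpleMatrix(new double[][]{{0, 1}}),
                new SimpleMatrix(new double[][]{{1, 0}}),
                new SimpleMatrix(new double[][]{{1, 1}})};
        SimpleMatrix[] y_train = {
                new SimpleMatrix(new double[][]{{0}}),
                new SimpleMatrix(new double[][]{{1}}),
                new SimpleMatrix(new double[][]{{1}}),
                new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
        // was: new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM)
        network.addLayer(new FCLayer(3));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.addLayer(new FCLayer(1));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
        // was: network.fit(X_train, y_train, 1000, 0.1d, false)
        network.fit(X_train, y_train, 1000, 0.1d);
    }
}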
diff --git a/src/main/java/de/lluni/javann/layers/FCLayer.java b/src/main/java/de/lluni/javann/layers/FCLayer.java
index 8d2617f..2c78c5c 100644
--- a/src/main/java/de/lluni/javann/layers/FCLayer.java
+++ b/src/main/java/de/lluni/javann/layers/FCLayer.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.layers;

-import de.lluni.javann.Initializer;
 import de.lluni.javann.util.Utilities;
 import org.ejml.simple.SimpleMatrix;

@@ -10,8 +9,6 @@ import java.util.Random;
  * Fully connected layer with n Neurons
  */
 public class FCLayer extends Layer {
-    private final Initializer weightInit;
-    private final Initializer biasInit;
    private SimpleMatrix weights;
    private SimpleMatrix biases;
    private int numNeurons;
@@ -21,26 +18,14 @@ public class FCLayer extends Layer {
     * Creates a fully connected layer with numNeurons neurons
     * @param numNeurons amount of neurons in this layer
     */
-    public FCLayer(int numNeurons, Initializer weightInit, Initializer biasInit) {
+    public FCLayer(int numNeurons) {
        this.numNeurons = numNeurons;
-        this.weightInit = weightInit;
-        this.biasInit = biasInit;
        isInitialized = false;
    }

    private void initialize(int inputSize) {
-        switch (weightInit) {
-            case ZEROS -> this.weights = new SimpleMatrix(inputSize, numNeurons);
-            case ONES -> this.weights = Utilities.ones(inputSize, numNeurons);
-            case GAUSSIAN -> this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
-            case RANDOM -> this.weights = Utilities.randomMatrix(inputSize, numNeurons);
-        }
-        switch (biasInit) {
-            case ZEROS -> this.biases = new SimpleMatrix(1, numNeurons);
-            case ONES -> this.biases = Utilities.ones(1, numNeurons);
-            case GAUSSIAN -> this.biases = Utilities.gaussianMatrix(1, numNeurons, 0, 1, 0.1d);
-            case RANDOM -> this.biases = Utilities.randomMatrix(1, numNeurons);
-        }
+        this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
+        this.biases = Utilities.ones(1, numNeurons);
        this.isInitialized = true;
    }

diff --git a/src/main/java/de/lluni/javann/util/Utilities.java b/src/main/java/de/lluni/javann/util/Utilities.java
index bd98b08..cadd566 100644
--- a/src/main/java/de/lluni/javann/util/Utilities.java
+++ b/src/main/java/de/lluni/javann/util/Utilities.java
@@ -14,9 +14,6 @@
 public class Utilities {
    private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d;

-    private static final double STANDARD_RANDOM_ORIGIN = -1.0d;
-    private static final double STANDARD_RANDOM_BOUND = 1.0d;
-
    /**
     * Creates a matrix filled with ones
     * @param rows amount of rows
@@ -53,24 +50,6 @@ public class Utilities {
        return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR);
    }

-    /**
-     * Creates a matrix with random values
-     * @param rows amount of rows
-     * @param columns amount of columns
-     * @param origin minimum random value
-     * @param bound maximum random value
-     * @return matrix with random values
-     */
-    public static SimpleMatrix randomMatrix(int rows, int columns, double origin, double bound) {
-        Random random = new Random();
-        return new SimpleMatrix(rows, columns, true,
-                random.doubles((long) rows * columns, origin, bound).toArray());
-    }
-
-    public static SimpleMatrix randomMatrix(int rows, int columns) {
-        return randomMatrix(rows, columns, STANDARD_RANDOM_ORIGIN, STANDARD_RANDOM_BOUND);
-    }
-
    /**
     * Creates an array of evenly spaced values from the interval [start, end)
     * @param start start value
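With the enum gone, FCLayer has exactly one initialization path: Gaussian weights (mean 0, stddev 1, scaled by the 0.1 factor) and all-ones biases, created lazily once the input size is known. A standalone sketch of that default using the Utilities signatures shown above (the class name and example sizes are illustrative):

import de.lluni.javann.util.Utilities;
import org.ejml.simple.SimpleMatrix;

public class DefaultInitSketch {
    public static void main(String[] args) {
        int inputSize = 2;  // inferred from the first sample at training/prediction time
        int numNeurons = 3; // the FCLayer constructor argument
        // Weights ~ 0.1 * N(0, 1), matching FCLayer.initialize()
        SimpleMatrix weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
        // Biases start at 1.0
        SimpleMatrix biases = Utilities.ones(1, numNeurons);
        weights.print();
        biases.print();
    }
}

One small follow-up this diff might warrant: with randomMatrix removed, the java.util.Random import in Utilities.java appears to be unused.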