Compare commits

...

3 commits

7 changed files with 67 additions and 15 deletions

de/lluni/javann/Initializer.java (new file)

@@ -0,0 +1,8 @@
+package de.lluni.javann;
+
+public enum Initializer {
+    ZEROS,
+    ONES,
+    GAUSSIAN,
+    RANDOM
+}
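
The enum itself carries no logic; the FCLayer changes further down in this diff switch on it to pick a weight or bias initialization. A minimal usage sketch, assuming the new FCLayer constructor introduced in this PR:

    // Gaussian-initialized weights, all-ones biases, 8 neurons
    FCLayer layer = new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES);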

de/lluni/javann/Network.java

@@ -79,8 +79,9 @@ public class Network {
     * @param y_train labels
     * @param epochs number of training iterations
     * @param learningRate step size of gradient descent
+    * @param optimize whether the step size should decrease with each subsequent epoch
     */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
        int samples = X_train.length;

        for (int i = 0; i < epochs; i++) {
@@ -98,7 +99,11 @@ public class Network {
                // backward propagation
                SimpleMatrix error = lossPrime.apply(y_train[j], output);
                for (int k = layers.size() - 1; k >= 0; k--) {
-                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                    if (optimize) {
+                        error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                    } else {
+                        error = layers.get(k).backwardPropagation(error, learningRate);
+                    }
                }
            }
            // calculate average error on all samples
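
With optimize set, the step size follows a 1/t-style decay: epoch i trains with learningRate / (i+1), so a base rate of 0.05 becomes 0.05, 0.025, 0.0167, ... over successive epochs, while optimize=false keeps it constant. A standalone sketch of that schedule (class name invented for illustration; the formula mirrors the optimize branch above):

    public class DecaySchedule {
        public static void main(String[] args) {
            double learningRate = 0.05d;
            for (int i = 0; i < 5; i++) {
                // same expression as the optimize branch in Network.fit
                System.out.printf("epoch %d: step size %.4f%n", i, learningRate / (i + 1));
            }
        }
    }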

de/lluni/javann/examples/ExampleSine.java

@@ -1,5 +1,6 @@
package de.lluni.javann.examples;

+import de.lluni.javann.Initializer;
import de.lluni.javann.Network;
import de.lluni.javann.functions.ActivationFunctions;
import de.lluni.javann.functions.LossFunctions;
@@ -46,17 +47,17 @@ public class ExampleSine {

        // create network and add layers
        Network network = new Network();
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
        network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(8));
+        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
        network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.GAUSSIAN, Initializer.ONES));

        // configure loss function for the network
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);

        // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d);
+        network.fit(X_train, y_train, 100, 0.05d, true);

        // predict X_test and output results to console
        SimpleMatrix[] output = network.predict(X_test);

de/lluni/javann/examples/ExampleXOR.java

@@ -1,5 +1,6 @@
package de.lluni.javann.examples;

+import de.lluni.javann.Initializer;
import de.lluni.javann.Network;
import de.lluni.javann.functions.ActivationFunctions;
import de.lluni.javann.functions.LossFunctions;
@@ -19,13 +20,13 @@ public class ExampleXOR {
                new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
-        network.addLayer(new FCLayer(3));
+        network.addLayer(new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);

        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {

de/lluni/javann/examples/ExampleXORAddedNeurons.java

@@ -1,5 +1,6 @@
package de.lluni.javann.examples;

+import de.lluni.javann.Initializer;
import de.lluni.javann.Network;
import de.lluni.javann.functions.ActivationFunctions;
import de.lluni.javann.functions.LossFunctions;
@@ -19,14 +20,14 @@ public class ExampleXORAddedNeurons {
                new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.addNeuron(0, 2);

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);

        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {

de/lluni/javann/layers/FCLayer.java

@@ -1,5 +1,6 @@
package de.lluni.javann.layers;

+import de.lluni.javann.Initializer;
import de.lluni.javann.util.Utilities;
import org.ejml.simple.SimpleMatrix;
@@ -9,6 +10,8 @@ import java.util.Random;
 * Fully connected layer with n neurons
 */
public class FCLayer extends Layer {
+    private final Initializer weightInit;
+    private final Initializer biasInit;
    private SimpleMatrix weights;
    private SimpleMatrix biases;
    private int numNeurons;
@@ -18,14 +21,26 @@ public class FCLayer extends Layer {
    /**
     * Creates a fully connected layer with numNeurons neurons
     * @param numNeurons number of neurons in this layer
     */
-    public FCLayer(int numNeurons) {
+    public FCLayer(int numNeurons, Initializer weightInit, Initializer biasInit) {
        this.numNeurons = numNeurons;
+        this.weightInit = weightInit;
+        this.biasInit = biasInit;
        isInitialized = false;
    }

    private void initialize(int inputSize) {
-        this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
-        this.biases = Utilities.ones(1, numNeurons);
+        switch (weightInit) {
+            case ZEROS -> this.weights = new SimpleMatrix(inputSize, numNeurons);
+            case ONES -> this.weights = Utilities.ones(inputSize, numNeurons);
+            case GAUSSIAN -> this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.weights = Utilities.randomMatrix(inputSize, numNeurons);
+        }
+        switch (biasInit) {
+            case ZEROS -> this.biases = new SimpleMatrix(1, numNeurons);
+            case ONES -> this.biases = Utilities.ones(1, numNeurons);
+            case GAUSSIAN -> this.biases = Utilities.gaussianMatrix(1, numNeurons, 0, 1, 0.1d);
+            case RANDOM -> this.biases = Utilities.randomMatrix(1, numNeurons);
+        }
        this.isInitialized = true;
    }
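
Each switch arm maps an enum constant to one factory call: new SimpleMatrix(r, c) yields an all-zero matrix in EJML, Utilities.ones fills with ones, Utilities.gaussianMatrix draws Gaussian values (here mean 0, stddev 1, with a 0.1 factor, judging by its signature), and Utilities.randomMatrix (added below) draws uniformly from [-1, 1). As a sketch, a hypothetical layer built as new FCLayer(4, Initializer.ZEROS, Initializer.GAUSSIAN) with inputSize 2 would end up with:

    // assumed equivalents of the two switch statements above
    SimpleMatrix weights = new SimpleMatrix(2, 4);                     // ZEROS: all entries 0.0
    SimpleMatrix biases  = Utilities.gaussianMatrix(1, 4, 0, 1, 0.1d); // GAUSSIAN: mean 0, stddev 1, factor 0.1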

de/lluni/javann/util/Utilities.java

@@ -14,6 +14,9 @@ import java.util.Random;
public class Utilities {
    private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d;
+    private static final double STANDARD_RANDOM_ORIGIN = -1.0d;
+    private static final double STANDARD_RANDOM_BOUND = 1.0d;

    /**
     * Creates a matrix filled with ones
     * @param rows number of rows
@@ -50,6 +53,24 @@ public class Utilities {
        return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR);
    }

+    /**
+     * Creates a matrix with uniformly distributed random values
+     * @param rows number of rows
+     * @param columns number of columns
+     * @param origin minimum random value (inclusive)
+     * @param bound upper bound for random values (exclusive)
+     * @return matrix with random values
+     */
+    public static SimpleMatrix randomMatrix(int rows, int columns, double origin, double bound) {
+        Random random = new Random();
+        return new SimpleMatrix(rows, columns, true,
+                random.doubles((long) rows * columns, origin, bound).toArray());
+    }
+
+    public static SimpleMatrix randomMatrix(int rows, int columns) {
+        return randomMatrix(rows, columns, STANDARD_RANDOM_ORIGIN, STANDARD_RANDOM_BOUND);
+    }
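
A quick usage sketch for the new helpers; per Random.doubles, origin is inclusive and bound is exclusive:

    SimpleMatrix a = Utilities.randomMatrix(2, 3);               // entries uniform in [-1.0, 1.0)
    SimpleMatrix b = Utilities.randomMatrix(2, 3, 0.0d, 0.5d);   // entries uniform in [0.0, 0.5)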

    /**
     * Creates an array of evenly spaced values from the interval [start, end)
     * @param start start value