Compare commits
No commits in common. "c7154817eefae4cb4ed94066280b6c9871aadb13" and "b20f62309b8ab693799816bcbcbf242ce3240ffa" have entirely different histories.
c7154817ee...b20f62309b
7 changed files with 15 additions and 67 deletions
de/lluni/javann/Initializer.java
@@ -1,8 +0,0 @@
-package de.lluni.javann;
-
-public enum Initializer {
-    ZEROS,
-    ONES,
-    GAUSSIAN,
-    RANDOM
-}
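With the enum deleted, none of the four strategies survives as an option; FCLayer (further down) hard-codes the former GAUSSIAN/ONES combination in its place.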
de/lluni/javann/Network.java
@@ -79,9 +79,8 @@ public class Network {
      * @param y_train labels
      * @param epochs amount of training iterations
      * @param learningRate step size of gradient descent
-     * @param optimize if step size should decrease for each subsequent epoch
      */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
         int samples = X_train.length;
 
         for (int i = 0; i < epochs; i++) {
@@ -99,11 +98,7 @@ public class Network {
                 // backward propagation
                 SimpleMatrix error = lossPrime.apply(y_train[j], output);
                 for (int k = layers.size() - 1; k >= 0; k--) {
-                    if (optimize) {
-                        error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
-                    } else {
-                        error = layers.get(k).backwardPropagation(error, learningRate);
-                    }
+                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
                 }
             }
             // calculate average error on all samples
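Note on behavior: with the optimize flag gone, fit() applies 1/t learning-rate decay unconditionally, so epoch i trains with learningRate / (i + 1). A minimal sketch of the resulting schedule (illustration only, not part of the diff):

    double learningRate = 0.05d;  // the value ExampleSine passes below
    for (int i = 0; i < 5; i++) {
        // effective step size fit() now hands to backwardPropagation:
        // 0.0500, 0.0250, 0.0167, 0.0125, 0.0100
        System.out.printf("epoch %d: step %.4f%n", i, learningRate / (i + 1));
    }

Call sites that previously passed optimize = false (both XOR examples below) therefore change behavior: they now get a decaying step size instead of a constant one.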
de/lluni/javann/examples/ExampleSine.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;
 
-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -47,17 +46,17 @@ public class ExampleSine {
 
         // create network and add layers
         Network network = new Network();
-        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(8));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(8));
         network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
-        network.addLayer(new FCLayer(1, Initializer.GAUSSIAN, Initializer.ONES));
+        network.addLayer(new FCLayer(1));
 
         // configure loss function for the network
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
 
         // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d, true);
+        network.fit(X_train, y_train, 100, 0.05d);
 
         // predict X_test and output results to console
         SimpleMatrix[] output = network.predict(X_test);
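For context, fit() takes parallel arrays of samples and labels. The data-setup lines of this example fall outside the hunk, so the following shapes are an assumption, not the file's actual code: each sample a 1x1 SimpleMatrix holding x, each label a 1x1 matrix holding sin(x).

    // hypothetical data setup for the sine example; names match the call above
    int n = 64;
    SimpleMatrix[] X_train = new SimpleMatrix[n];
    SimpleMatrix[] y_train = new SimpleMatrix[n];
    for (int i = 0; i < n; i++) {
        double x = 2 * Math.PI * i / n;
        X_train[i] = new SimpleMatrix(new double[][]{{x}});           // input row vector
        y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(x)}}); // target
    }
    network.fit(X_train, y_train, 100, 0.05d);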
de/lluni/javann/examples/ExampleXOR.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;
 
-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -20,13 +19,13 @@ public class ExampleXOR {
                 new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(3, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(3));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d, false);
+        network.fit(X_train, y_train, 1000, 0.1d);
 
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {
de/lluni/javann/examples/ExampleXORAddedNeurons.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.examples;
 
-import de.lluni.javann.Initializer;
 import de.lluni.javann.Network;
 import de.lluni.javann.functions.ActivationFunctions;
 import de.lluni.javann.functions.LossFunctions;
@@ -20,14 +19,14 @@ public class ExampleXORAddedNeurons {
                 new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1, Initializer.RANDOM, Initializer.RANDOM));
+        network.addLayer(new FCLayer(1));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
         network.addNeuron(0, 2);
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d, false);
+        network.fit(X_train, y_train, 1000, 0.1d);
 
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {
de/lluni/javann/layers/FCLayer.java
@@ -1,6 +1,5 @@
 package de.lluni.javann.layers;
 
-import de.lluni.javann.Initializer;
 import de.lluni.javann.util.Utilities;
 import org.ejml.simple.SimpleMatrix;
 
@@ -10,8 +9,6 @@ import java.util.Random;
  * Fully connected layer with n Neurons
  */
 public class FCLayer extends Layer {
-    private final Initializer weightInit;
-    private final Initializer biasInit;
     private SimpleMatrix weights;
     private SimpleMatrix biases;
     private int numNeurons;
@@ -21,26 +18,14 @@ public class FCLayer extends Layer {
      * Creates a fully connected layer with numNeurons neurons
      * @param numNeurons amount of neurons in this layer
      */
-    public FCLayer(int numNeurons, Initializer weightInit, Initializer biasInit) {
+    public FCLayer(int numNeurons) {
         this.numNeurons = numNeurons;
-        this.weightInit = weightInit;
-        this.biasInit = biasInit;
         isInitialized = false;
     }
 
     private void initialize(int inputSize) {
-        switch (weightInit) {
-            case ZEROS -> this.weights = new SimpleMatrix(inputSize, numNeurons);
-            case ONES -> this.weights = Utilities.ones(inputSize, numNeurons);
-            case GAUSSIAN -> this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
-            case RANDOM -> this.weights = Utilities.randomMatrix(inputSize, numNeurons);
-        }
-        switch (biasInit) {
-            case ZEROS -> this.biases = new SimpleMatrix(1, numNeurons);
-            case ONES -> this.biases = Utilities.ones(1, numNeurons);
-            case GAUSSIAN -> this.biases = Utilities.gaussianMatrix(1, numNeurons, 0, 1, 0.1d);
-            case RANDOM -> this.biases = Utilities.randomMatrix(1, numNeurons);
-        }
+        this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
+        this.biases = Utilities.ones(1, numNeurons);
         this.isInitialized = true;
     }
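The hard-coded path is exactly the old GAUSSIAN/ONES branch, so for call sites that used that combination (ExampleSine) the change is behavior-preserving; the XOR examples, which used Initializer.RANDOM (uniform values in [-1, 1) via the removed Utilities.randomMatrix), do get different initialization. A sketch of the equivalence:

    // before (c7154817ee): new FCLayer(8, Initializer.GAUSSIAN, Initializer.ONES)
    // after (b20f62309b): same result, now the only behavior
    FCLayer hidden = new FCLayer(8);  // weights: gaussianMatrix(inputSize, 8, 0, 1, 0.1d); biases: ones(1, 8)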
de/lluni/javann/util/Utilities.java
@@ -14,9 +14,6 @@ import java.util.Random;
 public class Utilities {
     private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d;
 
-    private static final double STANDARD_RANDOM_ORIGIN = -1.0d;
-    private static final double STANDARD_RANDOM_BOUND = 1.0d;
-
     /**
      * Creates a matrix filled with ones
      * @param rows amount of rows
@@ -53,24 +50,6 @@ public class Utilities {
         return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR);
     }
 
-    /**
-     * Creates a matrix with random values
-     * @param rows amount of rows
-     * @param columns amount of columns
-     * @param origin minimum random value
-     * @param bound maximum random value
-     * @return matrix with random values
-     */
-    public static SimpleMatrix randomMatrix(int rows, int columns, double origin, double bound) {
-        Random random = new Random();
-        return new SimpleMatrix(rows, columns, true,
-                random.doubles((long) rows * columns, origin, bound).toArray());
-    }
-
-    public static SimpleMatrix randomMatrix(int rows, int columns) {
-        return randomMatrix(rows, columns, STANDARD_RANDOM_ORIGIN, STANDARD_RANDOM_BOUND);
-    }
-
     /**
      * Creates an array of evenly spaced values from the interval [start, end)
      * @param start start value
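Nothing in this compare replaces the deleted randomMatrix helpers. If uniform initialization is still needed somewhere, the removed body can be reproduced standalone (a sketch assembled from the deleted lines above; the holder class name is hypothetical, and Utilities itself no longer ships the method):

    import java.util.Random;
    import org.ejml.simple.SimpleMatrix;

    class RandomMatrices {
        // uniform values in [origin, bound), row-major fill, as in the deleted helper
        static SimpleMatrix randomMatrix(int rows, int columns, double origin, double bound) {
            Random random = new Random();
            return new SimpleMatrix(rows, columns, true,
                    random.doubles((long) rows * columns, origin, bound).toArray());
        }
    }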