Compare commits
No commits in common. "74e4d05fa118bb248be9dfa7b1f9c1ba96165357" and "281b42b0fb288643b1f646313659b3b98b0250b0" have entirely different histories.
74e4d05fa1
...
281b42b0fb
14 changed files with 9 additions and 126 deletions
2
gradlew
vendored
2
gradlew
vendored
|
@ -31,7 +31,7 @@
|
|||
#
|
||||
# Busybox and similar reduced shells will NOT work, because this script
|
||||
# requires all of these POSIX shell features:
|
||||
# * de.lluni.javann.functions;
|
||||
# * functions;
|
||||
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
|
||||
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
|
||||
# * compound commands having a testable exit status, especially «case»;
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.functions;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class ActivationFunctions {
|
||||
|
@ -54,20 +52,4 @@ public class ActivationFunctions {
|
|||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
public static SimpleMatrix LeakyReLu(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, Math.max(0.001 * A.get(i), A.get(i)));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
public static SimpleMatrix LeakyReLuPrime(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, A.get(i) < 0 ? 0.001 : 1);
|
||||
}
|
||||
return B;
|
||||
}
|
||||
}
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.layers;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.function.Function;
|
|
@ -1,12 +1,10 @@
|
|||
package de.lluni.javann.layers;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Goal: initialize layer without any neurons. Not yet implemented.
|
||||
* de.lluni.javann.layers.Layer initialized with 1 neuron.
|
||||
* Layer initialized with 1 neuron.
|
||||
* Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future).
|
||||
*/
|
||||
public class BlankLayer extends Layer {
|
|
@ -1,6 +1,3 @@
|
|||
package de.lluni.javann.examples;
|
||||
|
||||
import de.lluni.javann.util.GradientDescent;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.function.Function;
|
|
@ -1,10 +1,3 @@
|
|||
package de.lluni.javann.examples;
|
||||
|
||||
import de.lluni.javann.Network;
|
||||
import de.lluni.javann.functions.ActivationFunctions;
|
||||
import de.lluni.javann.functions.LossFunctions;
|
||||
import de.lluni.javann.layers.ActivationLayer;
|
||||
import de.lluni.javann.layers.FCLayer;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class ExampleXOR {
|
|
@ -1,10 +1,3 @@
|
|||
package de.lluni.javann.examples;
|
||||
|
||||
import de.lluni.javann.Network;
|
||||
import de.lluni.javann.functions.ActivationFunctions;
|
||||
import de.lluni.javann.functions.LossFunctions;
|
||||
import de.lluni.javann.layers.ActivationLayer;
|
||||
import de.lluni.javann.layers.FCLayer;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class ExampleXORBlankLayers {
|
|
@ -1,6 +1,3 @@
|
|||
package de.lluni.javann.layers;
|
||||
|
||||
import de.lluni.javann.util.Utilities;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.Random;
|
||||
|
@ -17,8 +14,11 @@ public class FCLayer extends Layer {
|
|||
}
|
||||
|
||||
private void initialize(int inputSize) {
|
||||
this.weights = Utilities.gaussianMatrix(inputSize, numNeurons, 0, 1, 0.1d);
|
||||
this.biases = Utilities.ones(1, numNeurons);
|
||||
Random random = new Random();
|
||||
this.weights = new SimpleMatrix(inputSize, numNeurons, true,
|
||||
random.doubles((long) inputSize*numNeurons, -1, 1).toArray());
|
||||
this.biases = new SimpleMatrix(1, numNeurons, true,
|
||||
random.doubles(numNeurons, -1, 1).toArray());
|
||||
this.isInitialized = true;
|
||||
}
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.util;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.function.Function;
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.layers;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public abstract class Layer {
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.functions;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class LossFunctions {
|
|
@ -1,7 +1,3 @@
|
|||
package de.lluni.javann;
|
||||
|
||||
import de.lluni.javann.layers.FCLayer;
|
||||
import de.lluni.javann.layers.Layer;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -28,9 +24,9 @@ public class Network {
|
|||
*/
|
||||
public void addNeuron(int layer, int n) {
|
||||
if (!(this.layers.get(layer) instanceof FCLayer)) {
|
||||
System.out.println("This layer is not a de.lluni.javann.layers.BlankLayer");
|
||||
System.out.println("This layer is not a BlankLayer");
|
||||
} else if (!(this.layers.get(layer + 2) instanceof FCLayer)) {
|
||||
System.out.println("The next layer is not a de.lluni.javann.layers.BlankLayer");
|
||||
System.out.println("The next layer is not a BlankLayer");
|
||||
}
|
||||
((FCLayer) this.layers.get(layer)).addNeuron(n);
|
||||
((FCLayer) this.layers.get(layer + 2)).updateInputSize(n);
|
|
@ -1,5 +1,3 @@
|
|||
package de.lluni.javann.util;
|
||||
|
||||
import com.opencsv.CSVReader;
|
||||
import com.opencsv.exceptions.CsvValidationException;
|
||||
import org.ejml.simple.SimpleMatrix;
|
|
@ -1,66 +0,0 @@
|
|||
package de.lluni.javann.examples;
|
||||
|
||||
import de.lluni.javann.Network;
|
||||
import de.lluni.javann.functions.ActivationFunctions;
|
||||
import de.lluni.javann.functions.LossFunctions;
|
||||
import de.lluni.javann.layers.ActivationLayer;
|
||||
import de.lluni.javann.layers.FCLayer;
|
||||
import de.lluni.javann.util.Utilities;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.knowm.xchart.SwingWrapper;
|
||||
import org.knowm.xchart.XYChart;
|
||||
import org.knowm.xchart.XYChartBuilder;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
public class ExampleSine {
|
||||
private static final int TRAINING_SIZE = 100000;
|
||||
private static final int TEST_SIZE = 1000;
|
||||
|
||||
public static void main(String[] args) {
|
||||
SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE];
|
||||
SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE];
|
||||
double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE);
|
||||
double[] y_test_true = new double[TEST_SIZE];
|
||||
double[] y_test_pred = new double[TEST_SIZE];
|
||||
SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE];
|
||||
SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE];
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < TRAINING_SIZE; i++) {
|
||||
double temp = random.nextDouble(0, 2 * Math.PI);
|
||||
X_train[i] = new SimpleMatrix(new double[][]{{temp}});
|
||||
y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}});
|
||||
}
|
||||
for (int i = 0; i < TEST_SIZE; i++) {
|
||||
X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}});
|
||||
y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}});
|
||||
y_test_true[i] = Math.sin(X_test_linspace[i]);
|
||||
}
|
||||
|
||||
// create network and add layers
|
||||
Network network = new Network();
|
||||
network.addLayer(new FCLayer(8));
|
||||
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
||||
network.addLayer(new FCLayer(8));
|
||||
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
||||
network.addLayer(new FCLayer(1));
|
||||
|
||||
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||
network.fit(X_train, y_train, 100, 0.05d);
|
||||
|
||||
SimpleMatrix[] output = network.predict(X_test);
|
||||
for (int i = 0; i < output.length; i++) {
|
||||
y_test_pred[i] = output[i].get(0);
|
||||
System.out.println("Prediction for x=" + X_test[i].get(0) + " (correct value: " + y_test[i].get(0) + "):");
|
||||
for (int j = 0; j < output[i].getNumElements(); j++) {
|
||||
System.out.println(output[i].get(j));
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
|
||||
XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build();
|
||||
chart.addSeries("sin(x) true", X_test_linspace, y_test_true);
|
||||
chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred);
|
||||
new SwingWrapper<>(chart).displayChart();
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue