diff --git a/src/main/java/de/lluni/javann/examples/ExampleSine.java b/src/main/java/de/lluni/javann/examples/ExampleSine.java index 52edae9..3dbddfb 100644 --- a/src/main/java/de/lluni/javann/examples/ExampleSine.java +++ b/src/main/java/de/lluni/javann/examples/ExampleSine.java @@ -18,19 +18,26 @@ public class ExampleSine { private static final int TEST_SIZE = 1000; public static void main(String[] args) { + // training set SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE]; SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE]; + // test set + SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE]; + SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE]; + // identical test set for plotting double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE); double[] y_test_true = new double[TEST_SIZE]; double[] y_test_pred = new double[TEST_SIZE]; - SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE]; - SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE]; + Random random = new Random(); + + // generate training set from random data for (int i = 0; i < TRAINING_SIZE; i++) { double temp = random.nextDouble(0, 2 * Math.PI); X_train[i] = new SimpleMatrix(new double[][]{{temp}}); y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}}); } + // generate test set for (int i = 0; i < TEST_SIZE; i++) { X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}}); y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}}); @@ -45,9 +52,13 @@ public class ExampleSine { network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime)); network.addLayer(new FCLayer(1)); + // configure loss function for the network network.use(LossFunctions::MSE, LossFunctions::MSEPrime); + + // train network on X_train and y_train network.fit(X_train, y_train, 100, 0.05d); + // predict X_test and output results to console SimpleMatrix[] output = network.predict(X_test); for (int i = 0; i < output.length; i++) { y_test_pred[i] = output[i].get(0); @@ -58,6 +69,7 @@ public class ExampleSine { System.out.println(); } + // create and display chart XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build(); chart.addSeries("sin(x) true", X_test_linspace, y_test_true); chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred);