From b53328b41cead5b10f51794df1c31314812b1657 Mon Sep 17 00:00:00 2001 From: lluni Date: Wed, 25 May 2022 17:28:29 +0200 Subject: [PATCH] Added sine example --- src/main/java/ExampleSine.java | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 src/main/java/ExampleSine.java diff --git a/src/main/java/ExampleSine.java b/src/main/java/ExampleSine.java new file mode 100644 index 0000000..653c084 --- /dev/null +++ b/src/main/java/ExampleSine.java @@ -0,0 +1,58 @@ +import org.ejml.simple.SimpleMatrix; +import org.knowm.xchart.SwingWrapper; +import org.knowm.xchart.XYChart; +import org.knowm.xchart.XYChartBuilder; + +import java.util.Random; + +public class ExampleSine { + private static final int TRAINING_SIZE = 100000; + private static final int TEST_SIZE = 1000; + + public static void main(String[] args) { + SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE]; + SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE]; + double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE); + double[] y_test_true = new double[TEST_SIZE]; + double[] y_test_pred = new double[TEST_SIZE]; + SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE]; + SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE]; + Random random = new Random(); + for (int i = 0; i < TRAINING_SIZE; i++) { + double temp = random.nextDouble(0, 2 * Math.PI); + X_train[i] = new SimpleMatrix(new double[][]{{temp}}); + y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}}); + } + for (int i = 0; i < TEST_SIZE; i++) { + X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}}); + y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}}); + y_test_true[i] = Math.sin(X_test_linspace[i]); + } + + // create network and add layers + Network network = new Network(); + network.addLayer(new FCLayer(8)); + network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime)); + network.addLayer(new FCLayer(8)); + network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime)); + network.addLayer(new FCLayer(1)); + + network.use(LossFunctions::MSE, LossFunctions::MSEPrime); + network.fit(X_train, y_train, 100, 0.05d); + + SimpleMatrix[] output = network.predict(X_test); + for (int i = 0; i < output.length; i++) { + y_test_pred[i] = output[i].get(0); + System.out.println("Prediction for x=" + X_test[i].get(0) + " (correct value: " + y_test[i].get(0) + "):"); + for (int j = 0; j < output[i].getNumElements(); j++) { + System.out.println(output[i].get(j)); + } + System.out.println(); + } + + XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build(); + chart.addSeries("sin(x) true", X_test_linspace, y_test_true); + chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred); + new SwingWrapper<>(chart).displayChart(); + } +}