Added sine example
This commit is contained in:
parent
a3be9daf02
commit
b53328b41c
1 changed files with 58 additions and 0 deletions
58
src/main/java/ExampleSine.java
Normal file
58
src/main/java/ExampleSine.java
Normal file
|
@ -0,0 +1,58 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.knowm.xchart.SwingWrapper;
|
||||
import org.knowm.xchart.XYChart;
|
||||
import org.knowm.xchart.XYChartBuilder;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
public class ExampleSine {
|
||||
private static final int TRAINING_SIZE = 100000;
|
||||
private static final int TEST_SIZE = 1000;
|
||||
|
||||
public static void main(String[] args) {
|
||||
SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE];
|
||||
SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE];
|
||||
double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE);
|
||||
double[] y_test_true = new double[TEST_SIZE];
|
||||
double[] y_test_pred = new double[TEST_SIZE];
|
||||
SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE];
|
||||
SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE];
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < TRAINING_SIZE; i++) {
|
||||
double temp = random.nextDouble(0, 2 * Math.PI);
|
||||
X_train[i] = new SimpleMatrix(new double[][]{{temp}});
|
||||
y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}});
|
||||
}
|
||||
for (int i = 0; i < TEST_SIZE; i++) {
|
||||
X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}});
|
||||
y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}});
|
||||
y_test_true[i] = Math.sin(X_test_linspace[i]);
|
||||
}
|
||||
|
||||
// create network and add layers
|
||||
Network network = new Network();
|
||||
network.addLayer(new FCLayer(8));
|
||||
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
||||
network.addLayer(new FCLayer(8));
|
||||
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
||||
network.addLayer(new FCLayer(1));
|
||||
|
||||
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||
network.fit(X_train, y_train, 100, 0.05d);
|
||||
|
||||
SimpleMatrix[] output = network.predict(X_test);
|
||||
for (int i = 0; i < output.length; i++) {
|
||||
y_test_pred[i] = output[i].get(0);
|
||||
System.out.println("Prediction for x=" + X_test[i].get(0) + " (correct value: " + y_test[i].get(0) + "):");
|
||||
for (int j = 0; j < output[i].getNumElements(); j++) {
|
||||
System.out.println(output[i].get(j));
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
|
||||
XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build();
|
||||
chart.addSeries("sin(x) true", X_test_linspace, y_test_true);
|
||||
chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred);
|
||||
new SwingWrapper<>(chart).displayChart();
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue