Added comments to ExampleSine
This commit is contained in:
parent
74e4d05fa1
commit
1b296bb5d0
1 changed files with 14 additions and 2 deletions
|
@ -18,19 +18,26 @@ public class ExampleSine {
|
||||||
private static final int TEST_SIZE = 1000;
|
private static final int TEST_SIZE = 1000;
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
// training set
|
||||||
SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE];
|
SimpleMatrix[] X_train = new SimpleMatrix[TRAINING_SIZE];
|
||||||
SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE];
|
SimpleMatrix[] y_train = new SimpleMatrix[TRAINING_SIZE];
|
||||||
|
// test set
|
||||||
|
SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE];
|
||||||
|
SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE];
|
||||||
|
// identical test set for plotting
|
||||||
double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE);
|
double[] X_test_linspace = Utilities.linspace(0, 2 * Math.PI, TEST_SIZE);
|
||||||
double[] y_test_true = new double[TEST_SIZE];
|
double[] y_test_true = new double[TEST_SIZE];
|
||||||
double[] y_test_pred = new double[TEST_SIZE];
|
double[] y_test_pred = new double[TEST_SIZE];
|
||||||
SimpleMatrix[] X_test = new SimpleMatrix[TEST_SIZE];
|
|
||||||
SimpleMatrix[] y_test = new SimpleMatrix[TEST_SIZE];
|
|
||||||
Random random = new Random();
|
Random random = new Random();
|
||||||
|
|
||||||
|
// generate training set from random data
|
||||||
for (int i = 0; i < TRAINING_SIZE; i++) {
|
for (int i = 0; i < TRAINING_SIZE; i++) {
|
||||||
double temp = random.nextDouble(0, 2 * Math.PI);
|
double temp = random.nextDouble(0, 2 * Math.PI);
|
||||||
X_train[i] = new SimpleMatrix(new double[][]{{temp}});
|
X_train[i] = new SimpleMatrix(new double[][]{{temp}});
|
||||||
y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}});
|
y_train[i] = new SimpleMatrix(new double[][]{{Math.sin(temp)}});
|
||||||
}
|
}
|
||||||
|
// generate test set
|
||||||
for (int i = 0; i < TEST_SIZE; i++) {
|
for (int i = 0; i < TEST_SIZE; i++) {
|
||||||
X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}});
|
X_test[i] = new SimpleMatrix(new double[][]{{X_test_linspace[i]}});
|
||||||
y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}});
|
y_test[i] = new SimpleMatrix(new double[][]{{Math.sin(X_test_linspace[i])}});
|
||||||
|
@ -45,9 +52,13 @@ public class ExampleSine {
|
||||||
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
network.addLayer(new ActivationLayer(ActivationFunctions::LeakyReLu, ActivationFunctions::LeakyReLuPrime));
|
||||||
network.addLayer(new FCLayer(1));
|
network.addLayer(new FCLayer(1));
|
||||||
|
|
||||||
|
// configure loss function for the network
|
||||||
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||||
|
|
||||||
|
// train network on X_train and y_train
|
||||||
network.fit(X_train, y_train, 100, 0.05d);
|
network.fit(X_train, y_train, 100, 0.05d);
|
||||||
|
|
||||||
|
// predict X_test and output results to console
|
||||||
SimpleMatrix[] output = network.predict(X_test);
|
SimpleMatrix[] output = network.predict(X_test);
|
||||||
for (int i = 0; i < output.length; i++) {
|
for (int i = 0; i < output.length; i++) {
|
||||||
y_test_pred[i] = output[i].get(0);
|
y_test_pred[i] = output[i].get(0);
|
||||||
|
@ -58,6 +69,7 @@ public class ExampleSine {
|
||||||
System.out.println();
|
System.out.println();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create and display chart
|
||||||
XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build();
|
XYChart chart = new XYChartBuilder().title("sin(x) predictions").xAxisTitle("x").yAxisTitle("y").build();
|
||||||
chart.addSeries("sin(x) true", X_test_linspace, y_test_true);
|
chart.addSeries("sin(x) true", X_test_linspace, y_test_true);
|
||||||
chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred);
|
chart.addSeries("sin(x) predictions", X_test_linspace, y_test_pred);
|
||||||
|
|
Loading…
Reference in a new issue