diff --git a/src/main/java/ExampleXORBlankLayers.java b/src/main/java/ExampleXORBlankLayers.java new file mode 100644 index 0000000..2dca672 --- /dev/null +++ b/src/main/java/ExampleXORBlankLayers.java @@ -0,0 +1,32 @@ +import org.ejml.simple.SimpleMatrix; + +public class ExampleXORBlankLayers { + public static void main(String[] args) { + SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}), + new SimpleMatrix(new double[][]{{0, 1}}), + new SimpleMatrix(new double[][]{{1, 0}}), + new SimpleMatrix(new double[][]{{1, 1}})}; + SimpleMatrix[] y_train = {new SimpleMatrix(new double[][]{{0}}), + new SimpleMatrix(new double[][]{{1}}), + new SimpleMatrix(new double[][]{{1}}), + new SimpleMatrix(new double[][]{{0}})}; + + Network network = new Network(); + network.addLayer(new BlankLayer(2)); + network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); + network.addLayer(new BlankLayer(1)); + network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); + network.addNeuron(0, 2); + + network.use(LossFunctions::MSE, LossFunctions::MSEPrime); + network.fit(X_train, y_train, 1000, 0.1d); + + SimpleMatrix[] output = network.predict(X_train); + for (SimpleMatrix entry : output) { + System.out.println("Prediction:"); + for (int i = 0; i < entry.getNumElements(); i++) { + System.out.println(Math.round(entry.get(i))); + } + } + } +}