JavaNN/src/main/java/ExampleXOR.java

31 lines
1.3 KiB
Java

import org.ejml.simple.SimpleMatrix;
/**
 * Example that trains a small network on the four XOR input/target pairs and
 * prints the rounded prediction produced for each training input.
 */
public class ExampleXOR {

    /**
     * Builds the network, fits it to the XOR truth table and prints the results.
     *
     * <p>The network stacks {@code FCLayer(3)}, a tanh activation, {@code FCLayer(1)}
     * and a second tanh activation. It uses {@code LossFunctions::MSE} /
     * {@code LossFunctions::MSEPrime} as the loss pair and trains for 1000 epochs
     * with a learning rate of 0.1. Each prediction is then printed under a
     * "Prediction:" header, one rounded element per line.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // XOR inputs and their expected outputs; index i of one matches index i of the other.
        SimpleMatrix[] inputs = {
            rowOf(0, 0),
            rowOf(0, 1),
            rowOf(1, 0),
            rowOf(1, 1)
        };
        SimpleMatrix[] targets = {
            rowOf(0),
            rowOf(1),
            rowOf(1),
            rowOf(0)
        };

        final int epochs = 1000;
        final double learningRate = 0.1d;

        Network net = new Network();
        net.addLayer(new FCLayer(3));
        net.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        net.addLayer(new FCLayer(1));
        net.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        net.use(LossFunctions::MSE, LossFunctions::MSEPrime);
        net.fit(inputs, targets, epochs, learningRate);

        SimpleMatrix[] predictions = net.predict(inputs);
        for (SimpleMatrix prediction : predictions) {
            System.out.println("Prediction:");
            int count = prediction.getNumElements();
            for (int k = 0; k < count; k++) {
                // Math.round(double) yields a long, so the integer value is printed.
                long rounded = Math.round(prediction.get(k));
                System.out.println(rounded);
            }
        }
    }

    /**
     * Wraps the given values as a single-row (1 x n) matrix.
     *
     * @param values the entries of the row, left to right
     * @return a new one-row {@link SimpleMatrix} holding {@code values}
     */
    private static SimpleMatrix rowOf(double... values) {
        return new SimpleMatrix(new double[][] {values});
    }
}