31 lines
1.3 KiB
Java
31 lines
1.3 KiB
Java
import org.ejml.simple.SimpleMatrix;
|
|
|
|
public class ExampleXOR {
|
|
public static void main(String[] args) {
|
|
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
|
|
new SimpleMatrix(new double[][]{{0, 1}}),
|
|
new SimpleMatrix(new double[][]{{1, 0}}),
|
|
new SimpleMatrix(new double[][]{{1, 1}})};
|
|
SimpleMatrix[] y_train = {new SimpleMatrix(new double[][]{{0}}),
|
|
new SimpleMatrix(new double[][]{{1}}),
|
|
new SimpleMatrix(new double[][]{{1}}),
|
|
new SimpleMatrix(new double[][]{{0}})};
|
|
|
|
Network network = new Network();
|
|
network.addLayer(new FCLayer(3));
|
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
|
network.addLayer(new FCLayer(1));
|
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
|
|
|
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
|
network.fit(X_train, y_train, 1000, 0.1d);
|
|
|
|
SimpleMatrix[] output = network.predict(X_train);
|
|
for (SimpleMatrix entry : output) {
|
|
System.out.println("Prediction:");
|
|
for (int i = 0; i < entry.getNumElements(); i++) {
|
|
System.out.println(Math.round(entry.get(i)));
|
|
}
|
|
}
|
|
}
|
|
}
|