import org.ejml.simple.SimpleMatrix; import java.util.ArrayList; import java.util.function.BiFunction; public class Network { private ArrayList layers; private BiFunction loss; private BiFunction lossPrime; public Network() { layers = new ArrayList<>(); } public void addLayer(Layer layer) { layers.add(layer); } /** * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases. * Only works if there are two successive BlankLayers. * @param layer index of layer in the ArrayList layers * @param n amount how many new neurons should be added */ public void addNeuron(int layer, int n) { if (!(this.layers.get(layer) instanceof BlankLayer)) { System.out.println("This layer is not a BlankLayer"); } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) { System.out.println("The next layer is not a BlankLayer"); } ((BlankLayer) this.layers.get(layer)).addNeuron(n); ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n); } public void use(BiFunction loss, BiFunction lossPrime) { this.loss = loss; this.lossPrime = lossPrime; } public SimpleMatrix[] predict(SimpleMatrix[] inputs) { SimpleMatrix[] result = new SimpleMatrix[inputs.length]; SimpleMatrix output; int i = 0; for (SimpleMatrix input : inputs) { output = input; for (Layer l : layers) { output = l.forwardPropagation(output); } result[i] = output; i++; } return result; } public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) { int samples = X_train.length; for (int i = 0; i < epochs; i++) { double err = 0; for (int j = 0; j < samples; j++) { // forward propagation SimpleMatrix output = X_train[j]; for (Layer l : layers) { output = l.forwardPropagation(output); } // compute loss (for display purpose only) err = loss.apply(y_train[j], output); // backward propagation SimpleMatrix error = lossPrime.apply(y_train[j], output); for (int k = layers.size() - 1; k >= 0; k--) { error = layers.get(k).backwardPropagation(error, learningRate); } } // calculate average error on all samples err /= samples; System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err); } } public ArrayList getLayers() { return layers; } public void setLayers(ArrayList layers) { this.layers = layers; } public BiFunction getLoss() { return loss; } public void setLoss(BiFunction loss) { this.loss = loss; } public BiFunction getLossPrime() { return lossPrime; } public void setLossPrime(BiFunction lossPrime) { this.lossPrime = lossPrime; } }