JavaNN/src/main/java/Network.java

107 lines
No EOL
3.5 KiB
Java

import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
import java.util.function.BiFunction;
public class Network {
private ArrayList<Layer> layers;
private BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
private BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
public Network() {
layers = new ArrayList<>();
}
public void addLayer(Layer layer) {
layers.add(layer);
}
/**
* Adds n neurons to a specific layer and also updates this and the next layer's weights and biases.
* Only works if there are two successive BlankLayers.
* @param layer index of layer in the ArrayList layers
* @param n amount how many new neurons should be added
*/
public void addNeuron(int layer, int n) {
if (!(this.layers.get(layer) instanceof BlankLayer)) {
System.out.println("This layer is not a BlankLayer");
} else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
System.out.println("The next layer is not a BlankLayer");
}
((BlankLayer) this.layers.get(layer)).addNeuron(n);
((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
}
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
this.loss = loss;
this.lossPrime = lossPrime;
}
public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
SimpleMatrix[] result = new SimpleMatrix[inputs.length];
SimpleMatrix output;
int i = 0;
for (SimpleMatrix input : inputs) {
output = input;
for (Layer l : layers) {
output = l.forwardPropagation(output);
}
result[i] = output;
i++;
}
return result;
}
public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
int samples = X_train.length;
for (int i = 0; i < epochs; i++) {
double err = 0;
for (int j = 0; j < samples; j++) {
// forward propagation
SimpleMatrix output = X_train[j];
for (Layer l : layers) {
output = l.forwardPropagation(output);
}
// compute loss (for display purpose only)
err = loss.apply(y_train[j], output);
// backward propagation
SimpleMatrix error = lossPrime.apply(y_train[j], output);
for (int k = layers.size() - 1; k >= 0; k--) {
error = layers.get(k).backwardPropagation(error, learningRate);
}
}
// calculate average error on all samples
err /= samples;
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
}
}
public ArrayList<Layer> getLayers() {
return layers;
}
public void setLayers(ArrayList<Layer> layers) {
this.layers = layers;
}
public BiFunction<SimpleMatrix, SimpleMatrix, Double> getLoss() {
return loss;
}
public void setLoss(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss) {
this.loss = loss;
}
public BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> getLossPrime() {
return lossPrime;
}
public void setLossPrime(BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
this.lossPrime = lossPrime;
}
}