import org.ejml.simple.SimpleMatrix;

import java.util.ArrayList;
import java.util.function.BiFunction;

public class Network {

    private ArrayList<Layer> layers;
    private BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
    private BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;

    public Network() {
        layers = new ArrayList<>();
    }

    public void addLayer(Layer layer) {
        layers.add(layer);
    }

    /**
     * Adds n neurons to the layer at the given index and updates that layer's
     * and the next BlankLayer's weights and biases. Only works if the layers
     * at index layer and layer + 2 are both BlankLayers (the layer in between
     * is skipped).
     * @param layer index of the layer in the ArrayList layers
     * @param n number of neurons to add
     */
    public void addNeuron(int layer, int n) {
        // Return early on failed checks; otherwise the casts below would
        // throw a ClassCastException for non-BlankLayers.
        if (!(this.layers.get(layer) instanceof BlankLayer)) {
            System.out.println("This layer is not a BlankLayer");
            return;
        }
        if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
            System.out.println("The next layer is not a BlankLayer");
            return;
        }
        ((BlankLayer) this.layers.get(layer)).addNeuron(n);
        ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
    }
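
    // Usage sketch for addNeuron, assuming the layout this method implies
    // (BlankLayer -> some intermediate layer -> BlankLayer); indices are
    // illustrative only:
    //
    //   net.addNeuron(0, 2); // grow layer 0 by two neurons, widen layer 2's input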

    public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss,
                    BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
        this.loss = loss;
        this.lossPrime = lossPrime;
    }
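
    // A minimal sketch of a loss pair to hand to use(...), assuming
    // column-vector outputs; mse and msePrime are example names, not part of
    // this class. EJML's SimpleMatrix provides minus(), elementMult(),
    // elementSum(), scale(), and getNumElements().
    //
    //   BiFunction<SimpleMatrix, SimpleMatrix, Double> mse = (yTrue, yPred) -> {
    //       SimpleMatrix diff = yTrue.minus(yPred);
    //       return diff.elementMult(diff).elementSum() / diff.getNumElements();
    //   };
    //   BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> msePrime =
    //       (yTrue, yPred) -> yPred.minus(yTrue).scale(2.0 / yTrue.getNumElements());
    //   network.use(mse, msePrime);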

    /**
     * Runs each input through all layers and collects the outputs.
     */
    public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
        SimpleMatrix[] result = new SimpleMatrix[inputs.length];

        for (int i = 0; i < inputs.length; i++) {
            SimpleMatrix output = inputs[i];
            for (Layer l : layers) {
                output = l.forwardPropagation(output);
            }
            result[i] = output;
        }

        return result;
    }

    /**
     * Trains the network with plain stochastic gradient descent: one
     * forward/backward pass per sample, for the given number of epochs.
     */
    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
        int samples = X_train.length;

        for (int i = 0; i < epochs; i++) {
            double err = 0;
            for (int j = 0; j < samples; j++) {
                // forward propagation
                SimpleMatrix output = X_train[j];
                for (Layer l : layers) {
                    output = l.forwardPropagation(output);
                }

                // accumulate the loss (for display purposes only)
                err += loss.apply(y_train[j], output);

                // backward propagation
                SimpleMatrix error = lossPrime.apply(y_train[j], output);
                for (int k = layers.size() - 1; k >= 0; k--) {
                    error = layers.get(k).backwardPropagation(error, learningRate);
                }
            }

            // report the average error over all samples
            err /= samples;
            System.out.println("epoch " + (i + 1) + "/" + epochs + " error=" + err);
        }
    }

    public ArrayList<Layer> getLayers() {
        return layers;
    }

    public void setLayers(ArrayList<Layer> layers) {
        this.layers = layers;
    }

    public BiFunction<SimpleMatrix, SimpleMatrix, Double> getLoss() {
        return loss;
    }

    public void setLoss(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss) {
        this.loss = loss;
    }

    public BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> getLossPrime() {
        return lossPrime;
    }

    public void setLossPrime(BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
        this.lossPrime = lossPrime;
    }
}
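
// End-to-end usage sketch. BlankLayer is referenced by this class; the
// ActivationLayer name and all constructor signatures below are assumptions
// for illustration, not definitions from this file:
//
//   Network net = new Network();
//   net.addLayer(new BlankLayer(2, 3));      // assumed (inputSize, outputSize)
//   net.addLayer(new ActivationLayer());     // hypothetical activation layer
//   net.addLayer(new BlankLayer(3, 1));
//   net.use(mse, msePrime);                  // see the MSE sketch above
//   net.fit(X_train, y_train, 1000, 0.1);
//   SimpleMatrix[] predictions = net.predict(X_test);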