Added support for adding new neurons

lluni 2022-05-23 19:09:26 +02:00
parent 7738781bb5
commit 8c82838c54
3 changed files with 136 additions and 9 deletions

View file

@@ -0,0 +1,88 @@
import org.ejml.simple.SimpleMatrix;

import java.util.Random;

/**
 * Layer initialized with 1 neuron.
 * Assumes that each new neuron is fully connected to every neuron of the previous layer (this will be changed in the future).
 */
public class BlankLayer extends Layer {
    SimpleMatrix weights;
    SimpleMatrix biases;

    public BlankLayer(int inputSize) {
        Random random = new Random();
        // one row per input, one column per neuron; start with a single neuron
        this.weights = new SimpleMatrix(inputSize, 1, true,
                random.doubles(inputSize, -1, 1).toArray());
        this.biases = new SimpleMatrix(1, 1, true,
                random.doubles(1, -1, 1).toArray());
    }

    /**
     * Updates the input size when the previous layer has newly added neurons.
     * @param n number of new neurons in the previous layer
     */
    public void updateInputSize(int n) {
        Random random = new Random();
        // add new weights: copy the existing ones into a matrix with n extra input rows
        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
        for (int i = 0; i < this.weights.numRows(); i++) {
            for (int j = 0; j < this.weights.numCols(); j++) {
                newWeights.set(i, j, this.weights.get(i, j));
            }
        }
        // entries that are still exactly 0.0 belong to the new rows; initialize them randomly
        // (this relies on none of the copied weights being exactly 0)
        for (int i = 0; i < newWeights.getNumElements(); i++) {
            if (newWeights.get(i) == 0) {
                newWeights.set(i, random.nextDouble(-1, 1));
            }
        }
        this.weights = newWeights;
    }

    /**
     * Adds new neurons at the end of the layer.
     * @param n number of new neurons to add
     */
    public void addNeuron(int n) {
        Random random = new Random();
        // add new weights: copy the existing ones into a matrix with n extra neuron columns
        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
        for (int i = 0; i < this.weights.numRows(); i++) {
            for (int j = 0; j < this.weights.numCols(); j++) {
                newWeights.set(i, j, this.weights.get(i, j));
            }
        }
        // entries that are still exactly 0.0 belong to the new columns; initialize them randomly
        for (int i = 0; i < newWeights.getNumElements(); i++) {
            if (newWeights.get(i) == 0) {
                newWeights.set(i, random.nextDouble(-1, 1));
            }
        }
        this.weights = newWeights;

        // add new biases: keep the existing row vector and append n random values
        SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
        double[] newBiasValues = random.doubles(n, -1, 1).toArray();
        System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
        System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
        this.biases = newBiases;
    }

    @Override
    public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
        this.input = inputs;
        // output = input * W + b
        this.output = this.input.mult(this.weights).plus(this.biases);
        return this.output;
    }

    @Override
    public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
        SimpleMatrix inputError = outputError.mult(this.weights.transpose());
        SimpleMatrix weightsError = this.input.transpose().mult(outputError);
        this.weights = this.weights.plus(learningRate, weightsError);
        this.biases = this.biases.plus(learningRate, outputError);
        return inputError;
    }
}
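
The two growth operations are designed to be used together: addNeuron widens a layer, and updateInputSize lets the following processing layer accept the wider output. Below is a minimal usage sketch of that interaction; it is not part of this commit, the class name BlankLayerGrowthSketch, the sizes, and the sample input are invented for illustration, and the abstract Layer base class is assumed to exist in the repository as used above.

import org.ejml.simple.SimpleMatrix;

// Hypothetical sketch, not part of this commit: growing BlankLayers by hand.
public class BlankLayerGrowthSketch {
    public static void main(String[] args) {
        BlankLayer hidden = new BlankLayer(2);  // 2 inputs -> 1 neuron, weights 2x1
        BlankLayer output = new BlankLayer(1);  // consumes that single neuron, weights 1x1

        hidden.addNeuron(2);        // hidden layer now has 3 neurons: weights 2x3, biases 1x3
        output.updateInputSize(2);  // output layer now expects 3 inputs: weights 3x1

        // one sample as a row vector; the shapes still line up after growing
        SimpleMatrix x = new SimpleMatrix(new double[][]{{0.5, -0.25}});
        SimpleMatrix y = output.forwardPropagation(hidden.forwardPropagation(x));
        System.out.println("output: " + y.get(0, 0));
    }
}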

View file

@@ -12,10 +12,10 @@ public class ExampleXOR {
                new SimpleMatrix(new double[][]{{0}})};

        Network network = new Network();
-       network.add(new FCLayer(2, 3));
-       network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-       network.add(new FCLayer(3, 1));
-       network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
+       network.addLayer(new FCLayer(2, 3));
+       network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
+       network.addLayer(new FCLayer(3, 1));
+       network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));

        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
        network.fit(X_train, y_train, 1000, 0.1d);

View file

@@ -4,19 +4,34 @@ import java.util.ArrayList;
import java.util.function.BiFunction;

public class Network {
-   ArrayList<Layer> layers;
-   BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
-   BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
+   private ArrayList<Layer> layers;
+   private BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
+   private BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;

    public Network() {
        layers = new ArrayList<>();
    }

-   public void add(Layer layer) {
+   public void addLayer(Layer layer) {
        layers.add(layer);
    }

+   /**
+    * Adds n neurons to a specific layer and also updates that layer's and the next layer's weights and biases.
+    * Only works if the layer at index layer and the one at index layer + 2 (the next processing layer, past the activation layer in between) are both BlankLayers.
+    * @param layer index of the layer in the ArrayList layers
+    * @param n number of new neurons to add
+    */
+   public void addNeuron(int layer, int n) {
+       if (!(this.layers.get(layer) instanceof BlankLayer)) {
+           System.out.println("This layer is not a BlankLayer");
+           return;
+       } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
+           System.out.println("The next layer is not a BlankLayer");
+           return;
+       }
+       ((BlankLayer) this.layers.get(layer)).addNeuron(n);
+       ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
+   }

    public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
        this.loss = loss;
        this.lossPrime = lossPrime;
@@ -65,4 +80,28 @@ public class Network {
            System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
        }
    }
+   public ArrayList<Layer> getLayers() {
+       return layers;
+   }
+   public void setLayers(ArrayList<Layer> layers) {
+       this.layers = layers;
+   }
+   public BiFunction<SimpleMatrix, SimpleMatrix, Double> getLoss() {
+       return loss;
+   }
+   public void setLoss(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss) {
+       this.loss = loss;
+   }
+   public BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> getLossPrime() {
+       return lossPrime;
+   }
+   public void setLossPrime(BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
+       this.lossPrime = lossPrime;
+   }
}
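
To show how the pieces added in this commit are meant to be used end to end, here is a small hypothetical sketch (not part of the commit). It assumes the ActivationLayer, ActivationFunctions, and LossFunctions classes already present in the repository, as referenced in ExampleXOR above; the class name AddNeuronSketch and all concrete numbers are invented for illustration.

import org.ejml.simple.SimpleMatrix;

// Hypothetical sketch, not part of this commit: growing a network built from BlankLayers.
public class AddNeuronSketch {
    public static void main(String[] args) {
        Network network = new Network();
        network.addLayer(new BlankLayer(2));  // index 0: 2 inputs, starts with a single neuron
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.addLayer(new BlankLayer(1));  // index 2: consumes that single neuron

        // Add 3 neurons to layer 0; layer 2 is updated so it accepts the 3 extra inputs.
        network.addNeuron(0, 3);

        // After growing, the network can be trained exactly as in ExampleXOR.
        SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
                                  new SimpleMatrix(new double[][]{{1, 1}})};
        SimpleMatrix[] y_train = {new SimpleMatrix(new double[][]{{0}}),
                                  new SimpleMatrix(new double[][]{{1}})};
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
        network.fit(X_train, y_train, 100, 0.1d);
    }
}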