The inputSize of each layer does not need to be specified anymore

This commit is contained in:
lluni 2022-05-25 00:41:38 +02:00
parent 95501bf4b1
commit 1c66f1b72f
3 changed files with 60 additions and 41 deletions

View file

@@ -12,9 +12,9 @@ public class ExampleXOR {
new SimpleMatrix(new double[][]{{0}})}; new SimpleMatrix(new double[][]{{0}})};
Network network = new Network(); Network network = new Network();
network.addLayer(new FCLayer(2, 3)); network.addLayer(new FCLayer(3));
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.addLayer(new FCLayer(3, 1)); network.addLayer(new FCLayer(1));
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.use(LossFunctions::MSE, LossFunctions::MSEPrime); network.use(LossFunctions::MSE, LossFunctions::MSEPrime);

View file

@@ -12,9 +12,9 @@ public class ExampleXORBlankLayers {
new SimpleMatrix(new double[][]{{0}})}; new SimpleMatrix(new double[][]{{0}})};
Network network = new Network(); Network network = new Network();
network.addLayer(new FCLayer(2, 1)); network.addLayer(new FCLayer(1));
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.addLayer(new FCLayer(1, 1)); network.addLayer(new FCLayer(1));
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.addNeuron(0, 2); network.addNeuron(0, 2);

View file

@@ -3,15 +3,23 @@ import org.ejml.simple.SimpleMatrix;
import java.util.Random; import java.util.Random;
public class FCLayer extends Layer { public class FCLayer extends Layer {
SimpleMatrix weights; private SimpleMatrix weights;
SimpleMatrix biases; private SimpleMatrix biases;
private int numNeurons;
private boolean isInitialized;
public FCLayer(int inputSize, int outputSize) { public FCLayer(int numNeurons) {
this.numNeurons = numNeurons;
isInitialized = false;
}
private void initialize(int inputSize) {
Random random = new Random(); Random random = new Random();
weights = new SimpleMatrix(inputSize, outputSize, true, this.weights = new SimpleMatrix(inputSize, numNeurons, true,
random.doubles((long) inputSize*outputSize, -1, 1).toArray()); random.doubles((long) inputSize*numNeurons, -1, 1).toArray());
biases = new SimpleMatrix(1, outputSize, true, this.biases = new SimpleMatrix(1, numNeurons, true,
random.doubles(outputSize, -1, 1).toArray()); random.doubles(numNeurons, -1, 1).toArray());
this.isInitialized = true;
} }
/** /**
@@ -19,6 +27,7 @@ public class FCLayer extends Layer {
* @param n amount of new neurons in previous layer * @param n amount of new neurons in previous layer
*/ */
public void updateInputSize(int n) { public void updateInputSize(int n) {
if (isInitialized) {
Random random = new Random(); Random random = new Random();
// add new weights // add new weights
@@ -35,6 +44,7 @@ public class FCLayer extends Layer {
} }
this.weights = newWeights; this.weights = newWeights;
} }
}
/** /**
* Adds new neurons at the end of the layer * Adds new neurons at the end of the layer
@@ -43,6 +53,10 @@ public class FCLayer extends Layer {
public void addNeuron(int n) { public void addNeuron(int n) {
Random random = new Random(); Random random = new Random();
// update neuron count
this.numNeurons += n;
if (isInitialized) {
// add new weights // add new weights
SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n); SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
for (int i = 0; i < this.weights.numRows(); i++) { for (int i = 0; i < this.weights.numRows(); i++) {
@@ -64,9 +78,14 @@ public class FCLayer extends Layer {
System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n); System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
this.biases = newBiases; this.biases = newBiases;
} }
}
@Override @Override
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) { public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
if (!isInitialized) {
initialize(inputs.numCols());
}
this.input = inputs; this.input = inputs;
this.output = this.input.mult(this.weights).plus(this.biases); this.output = this.input.mult(this.weights).plus(this.biases);
return this.output; return this.output;