Compare commits


No commits in common. "1c66f1b72feba721e1f34bf9fcfaaa6d9f6db038" and "4766ea0ad96f806b66ca0aae113fc205d44678eb" have entirely different histories.

5 changed files with 66 additions and 86 deletions

BlankLayer.java

@@ -3,7 +3,6 @@ import org.ejml.simple.SimpleMatrix;
 import java.util.Random;
 
 /**
- * Goal: initialize layer without any neurons. Not yet implemented.
  * Layer initialized with 1 neuron.
  * Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future).
  */
@@ -19,6 +18,57 @@ public class BlankLayer extends Layer {
                 random.doubles(1, -1, 1).toArray());
     }
+
+    /**
+     * Updates the input size when the previous layer has newly added neurons.
+     * @param n number of new neurons in the previous layer
+     */
+    public void updateInputSize(int n) {
+        Random random = new Random();
+        // add new weights
+        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
+        for (int i = 0; i < this.weights.numRows(); i++) {
+            for (int j = 0; j < this.weights.numCols(); j++) {
+                newWeights.set(i, j, this.weights.get(i, j));
+            }
+        }
+        for (int i = 0; i < newWeights.getNumElements(); i++) {
+            if (newWeights.get(i) == 0) {
+                newWeights.set(i, random.nextDouble(-1, 1));
+            }
+        }
+        this.weights = newWeights;
+    }
+
+    /**
+     * Adds new neurons at the end of the layer.
+     * @param n number of new neurons to add
+     */
+    public void addNeuron(int n) {
+        Random random = new Random();
+        // add new weights
+        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
+        for (int i = 0; i < this.weights.numRows(); i++) {
+            for (int j = 0; j < this.weights.numCols(); j++) {
+                newWeights.set(i, j, this.weights.get(i, j));
+            }
+        }
+        for (int i = 0; i < newWeights.getNumElements(); i++) {
+            if (newWeights.get(i) == 0) {
+                newWeights.set(i, random.nextDouble(-1, 1));
+            }
+        }
+        this.weights = newWeights;
+        // add new biases
+        SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
+        double[] newBiasValues = random.doubles(n, -1, 1).toArray();
+        System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
+        System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
+        this.biases = newBiases;
+    }
 
     @Override
     public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
         this.input = inputs;

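For reference, the grow-and-copy pattern behind addNeuron and updateInputSize can be sketched in isolation. The snippet below is illustrative only (class and method names are made up, not project code): it widens a SimpleMatrix by n columns, keeps the existing entries, and marks the new cells by position rather than by testing for zeros as the methods above do. updateInputSize applies the same idea to rows, since the row count tracks the size of the previous layer.

import org.ejml.simple.SimpleMatrix;
import java.util.Random;

// Illustrative sketch, not part of either commit: widen a weight matrix by n
// columns, copying old entries and randomizing only the newly added cells.
public class GrowSketch {
    static SimpleMatrix addColumns(SimpleMatrix weights, int n, Random random) {
        SimpleMatrix grown = new SimpleMatrix(weights.numRows(), weights.numCols() + n);
        for (int i = 0; i < grown.numRows(); i++) {
            for (int j = 0; j < grown.numCols(); j++) {
                if (j < weights.numCols()) {
                    grown.set(i, j, weights.get(i, j));        // keep the existing weight
                } else {
                    grown.set(i, j, random.nextDouble(-1, 1)); // initialize the new cell
                }
            }
        }
        return grown;
    }

    public static void main(String[] args) {
        Random random = new Random();
        SimpleMatrix weights = new SimpleMatrix(2, 1, true, random.doubles(2, -1, 1).toArray());
        System.out.println(addColumns(weights, 2, random).numCols()); // prints 3
    }
}
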
ExampleXOR.java

@@ -12,9 +12,9 @@ public class ExampleXOR {
                 new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(3));
+        network.addLayer(new FCLayer(2, 3));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new FCLayer(3, 1));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);

ExampleXORBlankLayers.java

@@ -12,9 +12,9 @@ public class ExampleXORBlankLayers {
                 new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new BlankLayer(2));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(1));
+        network.addLayer(new BlankLayer(1));
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.addNeuron(0, 2);

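As a rough shape check (an illustration, not part of either commit, and assuming BlankLayer(n) builds an n-input, single-neuron layer as its class comment describes): layer 0 starts with a 2x1 weight matrix, network.addNeuron(0, 2) grows it to 2x3, and the matching updateInputSize(2) on layer 2 grows that layer's weights from 1x1 to 3x1, so the XOR forward pass still lines up.

import org.ejml.simple.SimpleMatrix;

// Standalone shape check, not part of either commit. The class and variable
// names are made up; only the dimensions mirror the example above.
public class XorShapeCheck {
    public static void main(String[] args) {
        SimpleMatrix sample = new SimpleMatrix(1, 2);         // one XOR input row
        SimpleMatrix hiddenWeights = new SimpleMatrix(2, 3);  // BlankLayer(2): 2x1 grown to 2x3 by addNeuron(0, 2)
        SimpleMatrix hiddenBiases = new SimpleMatrix(1, 3);
        SimpleMatrix outWeights = new SimpleMatrix(3, 1);     // BlankLayer(1): 1x1 grown to 3x1 by updateInputSize(2)
        SimpleMatrix outBiases = new SimpleMatrix(1, 1);

        SimpleMatrix hidden = sample.mult(hiddenWeights).plus(hiddenBiases); // 1x3
        SimpleMatrix out = hidden.mult(outWeights).plus(outBiases);          // 1x1
        System.out.println(out.numRows() + "x" + out.numCols());             // prints 1x1
    }
}
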
FCLayer.java

@@ -3,89 +3,19 @@ import org.ejml.simple.SimpleMatrix;
 import java.util.Random;
 
 public class FCLayer extends Layer {
-    private SimpleMatrix weights;
-    private SimpleMatrix biases;
-    private int numNeurons;
-    private boolean isInitialized;
+    SimpleMatrix weights;
+    SimpleMatrix biases;
 
-    public FCLayer(int numNeurons) {
-        this.numNeurons = numNeurons;
-        isInitialized = false;
-    }
-
-    private void initialize(int inputSize) {
+    public FCLayer(int inputSize, int outputSize) {
         Random random = new Random();
-        this.weights = new SimpleMatrix(inputSize, numNeurons, true,
-                random.doubles((long) inputSize*numNeurons, -1, 1).toArray());
-        this.biases = new SimpleMatrix(1, numNeurons, true,
-                random.doubles(numNeurons, -1, 1).toArray());
+        weights = new SimpleMatrix(inputSize, outputSize, true,
+                random.doubles((long) inputSize*outputSize, -1, 1).toArray());
+        biases = new SimpleMatrix(1, outputSize, true,
+                random.doubles(outputSize, -1, 1).toArray());
-        this.isInitialized = true;
-    }
-
-    /**
-     * Updates input size when previous layer has newly added neurons.
-     * @param n amount of new neurons in previous layer
-     */
-    public void updateInputSize(int n) {
-        if (isInitialized) {
-            Random random = new Random();
-            // add new weights
-            SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
-            for (int i = 0; i < this.weights.numRows(); i++) {
-                for (int j = 0; j < this.weights.numCols(); j++) {
-                    newWeights.set(i, j, this.weights.get(i, j));
-                }
-            }
-            for (int i = 0; i < newWeights.getNumElements(); i++) {
-                if (newWeights.get(i) == 0) {
-                    newWeights.set(i, random.nextDouble(-1, 1));
-                }
-            }
-            this.weights = newWeights;
-        }
-    }
-
-    /**
-     * Adds new neurons at the end of the layer
-     * @param n amount how many new neurons should be added
-     */
-    public void addNeuron(int n) {
-        Random random = new Random();
-        // update neuron count
-        this.numNeurons += n;
-        if (isInitialized) {
-            // add new weights
-            SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
-            for (int i = 0; i < this.weights.numRows(); i++) {
-                for (int j = 0; j < this.weights.numCols(); j++) {
-                    newWeights.set(i, j, this.weights.get(i, j));
-                }
-            }
-            for (int i = 0; i < newWeights.getNumElements(); i++) {
-                if (newWeights.get(i) == 0) {
-                    newWeights.set(i, random.nextDouble(-1, 1));
-                }
-            }
-            this.weights = newWeights;
-            // add new biases
-            SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
-            double[] newBiasValues = random.doubles(n, -1, 1).toArray();
-            System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
-            System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
-            this.biases = newBiases;
-        }
     }
 
     @Override
     public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
-        if (!isInitialized) {
-            initialize(inputs.numCols());
-        }
         this.input = inputs;
         this.output = this.input.mult(this.weights).plus(this.biases);
         return this.output;

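The FCLayer rewrite drops the lazy initialize(inputSize) path: weights and biases are now created eagerly, which is why ExampleXOR above passes the input size to the constructor. Below is a standalone sketch (made-up class name, not project code) of the row-major SimpleMatrix constructor both versions rely on; it expects exactly rows * cols values.

import org.ejml.simple.SimpleMatrix;
import java.util.Random;

// Standalone sketch, not part of either commit: the 4-argument SimpleMatrix
// constructor takes a row-major array whose length must equal rows * cols.
public class RowMajorInitSketch {
    public static void main(String[] args) {
        Random random = new Random();
        int inputSize = 2, outputSize = 3;
        SimpleMatrix weights = new SimpleMatrix(inputSize, outputSize, true,
                random.doubles((long) inputSize * outputSize, -1, 1).toArray());
        System.out.println(weights.numRows() + "x" + weights.numCols()); // prints 2x3
    }
}
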
Network.java

@@ -23,13 +23,13 @@ public class Network {
      * @param n amount how many new neurons should be added
      */
     public void addNeuron(int layer, int n) {
-        if (!(this.layers.get(layer) instanceof FCLayer)) {
+        if (!(this.layers.get(layer) instanceof BlankLayer)) {
             System.out.println("This layer is not a BlankLayer");
-        } else if (!(this.layers.get(layer + 2) instanceof FCLayer)) {
+        } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
             System.out.println("The next layer is not a BlankLayer");
         }
-        ((FCLayer) this.layers.get(layer)).addNeuron(n);
-        ((FCLayer) this.layers.get(layer + 2)).updateInputSize(n);
+        ((BlankLayer) this.layers.get(layer)).addNeuron(n);
+        ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
     }
 
     public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
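One reading note on the hunk above: the instanceof checks print a warning but still fall through to the casts. A minimal early-return variant is sketched below purely for illustration; it assumes the Network class shown here (this.layers holding the Layer list) and is not part of either commit.

// Hypothetical variant of Network.addNeuron, not part of either commit:
// return after the warning so the casts below cannot throw ClassCastException.
public void addNeuron(int layer, int n) {
    if (!(this.layers.get(layer) instanceof BlankLayer)) {
        System.out.println("This layer is not a BlankLayer");
        return;
    }
    if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
        System.out.println("The next layer is not a BlankLayer");
        return;
    }
    ((BlankLayer) this.layers.get(layer)).addNeuron(n);
    ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
}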