Compare commits
2 commits
4766ea0ad9
...
1c66f1b72f
Author | SHA1 | Date | |
---|---|---|---|
1c66f1b72f | |||
95501bf4b1 |
5 changed files with 86 additions and 66 deletions
|
@ -3,6 +3,7 @@ import org.ejml.simple.SimpleMatrix;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Goal: initialize layer without any neurons. Not yet implemented.
|
||||||
* Layer initialized with 1 neuron.
|
* Layer initialized with 1 neuron.
|
||||||
* Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future).
|
* Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future).
|
||||||
*/
|
*/
|
||||||
|
@ -18,57 +19,6 @@ public class BlankLayer extends Layer {
|
||||||
random.doubles(1, -1, 1).toArray());
|
random.doubles(1, -1, 1).toArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Updates input size when previous layer has newly added neurons.
|
|
||||||
* @param n amount of new neurons in previous layer
|
|
||||||
*/
|
|
||||||
public void updateInputSize(int n) {
|
|
||||||
Random random = new Random();
|
|
||||||
|
|
||||||
// add new weights
|
|
||||||
SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
|
|
||||||
for (int i = 0; i < this.weights.numRows(); i++) {
|
|
||||||
for (int j = 0; j < this.weights.numCols(); j++) {
|
|
||||||
newWeights.set(i, j, this.weights.get(i, j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < newWeights.getNumElements(); i++) {
|
|
||||||
if (newWeights.get(i) == 0) {
|
|
||||||
newWeights.set(i, random.nextDouble(-1, 1));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.weights = newWeights;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds new neurons at the end of the layer
|
|
||||||
* @param n amount how many new neurons should be added
|
|
||||||
*/
|
|
||||||
public void addNeuron(int n) {
|
|
||||||
Random random = new Random();
|
|
||||||
|
|
||||||
// add new weights
|
|
||||||
SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
|
|
||||||
for (int i = 0; i < this.weights.numRows(); i++) {
|
|
||||||
for (int j = 0; j < this.weights.numCols(); j++) {
|
|
||||||
newWeights.set(i, j, this.weights.get(i, j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < newWeights.getNumElements(); i++) {
|
|
||||||
if (newWeights.get(i) == 0) {
|
|
||||||
newWeights.set(i, random.nextDouble(-1, 1));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.weights = newWeights;
|
|
||||||
|
|
||||||
// add new biases
|
|
||||||
SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
|
|
||||||
double[] newBiasValues = random.doubles(n, -1, 1).toArray();
|
|
||||||
System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
|
|
||||||
System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
|
|
||||||
this.biases = newBiases;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
|
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
|
||||||
this.input = inputs;
|
this.input = inputs;
|
||||||
|
|
|
@ -12,9 +12,9 @@ public class ExampleXOR {
|
||||||
new SimpleMatrix(new double[][]{{0}})};
|
new SimpleMatrix(new double[][]{{0}})};
|
||||||
|
|
||||||
Network network = new Network();
|
Network network = new Network();
|
||||||
network.addLayer(new FCLayer(2, 3));
|
network.addLayer(new FCLayer(3));
|
||||||
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
network.addLayer(new FCLayer(3, 1));
|
network.addLayer(new FCLayer(1));
|
||||||
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
|
|
||||||
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||||
|
|
|
@ -12,9 +12,9 @@ public class ExampleXORBlankLayers {
|
||||||
new SimpleMatrix(new double[][]{{0}})};
|
new SimpleMatrix(new double[][]{{0}})};
|
||||||
|
|
||||||
Network network = new Network();
|
Network network = new Network();
|
||||||
network.addLayer(new BlankLayer(2));
|
network.addLayer(new FCLayer(1));
|
||||||
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
network.addLayer(new BlankLayer(1));
|
network.addLayer(new FCLayer(1));
|
||||||
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
network.addNeuron(0, 2);
|
network.addNeuron(0, 2);
|
||||||
|
|
||||||
|
|
|
@ -3,19 +3,89 @@ import org.ejml.simple.SimpleMatrix;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
public class FCLayer extends Layer {
|
public class FCLayer extends Layer {
|
||||||
SimpleMatrix weights;
|
private SimpleMatrix weights;
|
||||||
SimpleMatrix biases;
|
private SimpleMatrix biases;
|
||||||
|
private int numNeurons;
|
||||||
|
private boolean isInitialized;
|
||||||
|
|
||||||
public FCLayer(int inputSize, int outputSize) {
|
public FCLayer(int numNeurons) {
|
||||||
|
this.numNeurons = numNeurons;
|
||||||
|
isInitialized = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initialize(int inputSize) {
|
||||||
Random random = new Random();
|
Random random = new Random();
|
||||||
weights = new SimpleMatrix(inputSize, outputSize, true,
|
this.weights = new SimpleMatrix(inputSize, numNeurons, true,
|
||||||
random.doubles((long) inputSize*outputSize, -1, 1).toArray());
|
random.doubles((long) inputSize*numNeurons, -1, 1).toArray());
|
||||||
biases = new SimpleMatrix(1, outputSize, true,
|
this.biases = new SimpleMatrix(1, numNeurons, true,
|
||||||
random.doubles(outputSize, -1, 1).toArray());
|
random.doubles(numNeurons, -1, 1).toArray());
|
||||||
|
this.isInitialized = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates input size when previous layer has newly added neurons.
|
||||||
|
* @param n amount of new neurons in previous layer
|
||||||
|
*/
|
||||||
|
public void updateInputSize(int n) {
|
||||||
|
if (isInitialized) {
|
||||||
|
Random random = new Random();
|
||||||
|
|
||||||
|
// add new weights
|
||||||
|
SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
|
||||||
|
for (int i = 0; i < this.weights.numRows(); i++) {
|
||||||
|
for (int j = 0; j < this.weights.numCols(); j++) {
|
||||||
|
newWeights.set(i, j, this.weights.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < newWeights.getNumElements(); i++) {
|
||||||
|
if (newWeights.get(i) == 0) {
|
||||||
|
newWeights.set(i, random.nextDouble(-1, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.weights = newWeights;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds new neurons at the end of the layer
|
||||||
|
* @param n amount how many new neurons should be added
|
||||||
|
*/
|
||||||
|
public void addNeuron(int n) {
|
||||||
|
Random random = new Random();
|
||||||
|
|
||||||
|
// update neuron count
|
||||||
|
this.numNeurons += n;
|
||||||
|
|
||||||
|
if (isInitialized) {
|
||||||
|
// add new weights
|
||||||
|
SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
|
||||||
|
for (int i = 0; i < this.weights.numRows(); i++) {
|
||||||
|
for (int j = 0; j < this.weights.numCols(); j++) {
|
||||||
|
newWeights.set(i, j, this.weights.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < newWeights.getNumElements(); i++) {
|
||||||
|
if (newWeights.get(i) == 0) {
|
||||||
|
newWeights.set(i, random.nextDouble(-1, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.weights = newWeights;
|
||||||
|
|
||||||
|
// add new biases
|
||||||
|
SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
|
||||||
|
double[] newBiasValues = random.doubles(n, -1, 1).toArray();
|
||||||
|
System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
|
||||||
|
System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
|
||||||
|
this.biases = newBiases;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
|
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
|
||||||
|
if (!isInitialized) {
|
||||||
|
initialize(inputs.numCols());
|
||||||
|
}
|
||||||
|
|
||||||
this.input = inputs;
|
this.input = inputs;
|
||||||
this.output = this.input.mult(this.weights).plus(this.biases);
|
this.output = this.input.mult(this.weights).plus(this.biases);
|
||||||
return this.output;
|
return this.output;
|
||||||
|
|
|
@ -23,13 +23,13 @@ public class Network {
|
||||||
* @param n amount how many new neurons should be added
|
* @param n amount how many new neurons should be added
|
||||||
*/
|
*/
|
||||||
public void addNeuron(int layer, int n) {
|
public void addNeuron(int layer, int n) {
|
||||||
if (!(this.layers.get(layer) instanceof BlankLayer)) {
|
if (!(this.layers.get(layer) instanceof FCLayer)) {
|
||||||
System.out.println("This layer is not a BlankLayer");
|
System.out.println("This layer is not a BlankLayer");
|
||||||
} else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
|
} else if (!(this.layers.get(layer + 2) instanceof FCLayer)) {
|
||||||
System.out.println("The next layer is not a BlankLayer");
|
System.out.println("The next layer is not a BlankLayer");
|
||||||
}
|
}
|
||||||
((BlankLayer) this.layers.get(layer)).addNeuron(n);
|
((FCLayer) this.layers.get(layer)).addNeuron(n);
|
||||||
((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
|
((FCLayer) this.layers.get(layer + 2)).updateInputSize(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
|
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
|
||||||
|
|
Loading…
Reference in a new issue