Compare commits
No commits in common. "4766ea0ad96f806b66ca0aae113fc205d44678eb" and "7e80e5bc9485862a814b1e174d422e7ff52e6334" have entirely different histories.
4766ea0ad9...7e80e5bc94
6 changed files with 9 additions and 202 deletions
build.gradle
@@ -11,7 +11,6 @@ repositories {
 dependencies {
     implementation 'org.ejml:ejml-all:0.41'
-    implementation 'com.opencsv:opencsv:5.6'
     testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
     testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
 }
BlankLayer.java (deleted)
@@ -1,88 +0,0 @@
-import org.ejml.simple.SimpleMatrix;
-
-import java.util.Random;
-
-/**
- * Layer initialized with 1 neuron.
- * Assumes that each new neuron is fully connected to every previous neuron (this will be changed in the future).
- */
-public class BlankLayer extends Layer {
-    SimpleMatrix weights;
-    SimpleMatrix biases;
-
-    public BlankLayer(int inputSize) {
-        Random random = new Random();
-        this.weights = new SimpleMatrix(inputSize, 1, true,
-                random.doubles(inputSize, -1, 1).toArray());
-        this.biases = new SimpleMatrix(1, 1, true,
-                random.doubles(1, -1, 1).toArray());
-    }
-
-    /**
-     * Updates input size when previous layer has newly added neurons.
-     * @param n amount of new neurons in previous layer
-     */
-    public void updateInputSize(int n) {
-        Random random = new Random();
-
-        // add new weights
-        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows() + n, this.weights.numCols());
-        for (int i = 0; i < this.weights.numRows(); i++) {
-            for (int j = 0; j < this.weights.numCols(); j++) {
-                newWeights.set(i, j, this.weights.get(i, j));
-            }
-        }
-        for (int i = 0; i < newWeights.getNumElements(); i++) {
-            if (newWeights.get(i) == 0) {
-                newWeights.set(i, random.nextDouble(-1, 1));
-            }
-        }
-        this.weights = newWeights;
-    }
-
-    /**
-     * Adds new neurons at the end of the layer
-     * @param n amount how many new neurons should be added
-     */
-    public void addNeuron(int n) {
-        Random random = new Random();
-
-        // add new weights
-        SimpleMatrix newWeights = new SimpleMatrix(this.weights.numRows(), this.weights.numCols() + n);
-        for (int i = 0; i < this.weights.numRows(); i++) {
-            for (int j = 0; j < this.weights.numCols(); j++) {
-                newWeights.set(i, j, this.weights.get(i, j));
-            }
-        }
-        for (int i = 0; i < newWeights.getNumElements(); i++) {
-            if (newWeights.get(i) == 0) {
-                newWeights.set(i, random.nextDouble(-1, 1));
-            }
-        }
-        this.weights = newWeights;
-
-        // add new biases
-        SimpleMatrix newBiases = new SimpleMatrix(1, this.biases.numCols() + n);
-        double[] newBiasValues = random.doubles(n, -1, 1).toArray();
-        System.arraycopy(this.biases.getDDRM().data, 0, newBiases.getDDRM().data, 0, this.biases.numCols());
-        System.arraycopy(newBiasValues, 0, newBiases.getDDRM().data, this.biases.numCols(), n);
-        this.biases = newBiases;
-    }
-
-    @Override
-    public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
-        this.input = inputs;
-        this.output = this.input.mult(this.weights).plus(this.biases);
-        return this.output;
-    }
-
-    @Override
-    public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
-        SimpleMatrix inputError = outputError.mult(this.weights.transpose());
-        SimpleMatrix weightsError = this.input.transpose().mult(outputError);
-
-        this.weights = this.weights.plus(learningRate, weightsError);
-        this.biases = this.biases.plus(learningRate, outputError);
-        return inputError;
-    }
-}
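Note: the core of the deleted addNeuron/updateInputSize methods is growing an EJML matrix: copy the old weights into a larger SimpleMatrix, then fill the new cells with uniform values in [-1, 1). A minimal standalone sketch of just that step, with a hypothetical class name, assuming only EJML 0.41 on the classpath and Java 17+ (for Random.nextDouble(origin, bound)):

    import org.ejml.simple.SimpleMatrix;
    import java.util.Random;

    // Hypothetical illustration of the matrix-growing step used by BlankLayer.addNeuron.
    public class GrowMatrixSketch {
        public static void main(String[] args) {
            Random random = new Random();
            // 2 inputs, 1 neuron, weights drawn uniformly from [-1, 1)
            SimpleMatrix weights = new SimpleMatrix(2, 1, true,
                    random.doubles(2, -1, 1).toArray());

            int n = 2; // neurons to add
            SimpleMatrix grown = new SimpleMatrix(weights.numRows(), weights.numCols() + n);
            for (int i = 0; i < weights.numRows(); i++) {
                for (int j = 0; j < weights.numCols(); j++) {
                    grown.set(i, j, weights.get(i, j));
                }
            }
            // BlankLayer filled every remaining zero with a random value; note that
            // this also overwrites any existing weight that happens to be exactly 0.
            for (int i = 0; i < grown.getNumElements(); i++) {
                if (grown.get(i) == 0) {
                    grown.set(i, random.nextDouble(-1, 1));
                }
            }
            grown.print();
        }
    }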
ExampleXOR.java
@@ -12,10 +12,10 @@ public class ExampleXOR {
                 new SimpleMatrix(new double[][]{{0}})};
 
         Network network = new Network();
-        network.addLayer(new FCLayer(2, 3));
-        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new FCLayer(3, 1));
-        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
+        network.add(new FCLayer(2, 3));
+        network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
+        network.add(new FCLayer(3, 1));
+        network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
         network.fit(X_train, y_train, 1000, 0.1d);
ExampleXORBlankLayers.java (deleted)
@@ -1,32 +0,0 @@
-import org.ejml.simple.SimpleMatrix;
-
-public class ExampleXORBlankLayers {
-    public static void main(String[] args) {
-        SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
-                new SimpleMatrix(new double[][]{{0, 1}}),
-                new SimpleMatrix(new double[][]{{1, 0}}),
-                new SimpleMatrix(new double[][]{{1, 1}})};
-        SimpleMatrix[] y_train = {new SimpleMatrix(new double[][]{{0}}),
-                new SimpleMatrix(new double[][]{{1}}),
-                new SimpleMatrix(new double[][]{{1}}),
-                new SimpleMatrix(new double[][]{{0}})};
-
-        Network network = new Network();
-        network.addLayer(new BlankLayer(2));
-        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addLayer(new BlankLayer(1));
-        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
-        network.addNeuron(0, 2);
-
-        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
-
-        SimpleMatrix[] output = network.predict(X_train);
-        for (SimpleMatrix entry : output) {
-            System.out.println("Prediction:");
-            for (int i = 0; i < entry.getNumElements(); i++) {
-                System.out.println(Math.round(entry.get(i)));
-            }
-        }
-    }
-}
Network.java
@@ -4,34 +4,19 @@ import java.util.ArrayList;
 import java.util.function.BiFunction;
 
 public class Network {
-    private ArrayList<Layer> layers;
-    private BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
-    private BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
+    ArrayList<Layer> layers;
+    BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
+    BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
 
     public Network() {
         layers = new ArrayList<>();
     }
 
-    public void addLayer(Layer layer) {
+    public void add(Layer layer) {
         layers.add(layer);
     }
 
-    /**
-     * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases.
-     * Only works if there are two successive BlankLayers.
-     * @param layer index of layer in the ArrayList layers
-     * @param n amount how many new neurons should be added
-     */
-    public void addNeuron(int layer, int n) {
-        if (!(this.layers.get(layer) instanceof BlankLayer)) {
-            System.out.println("This layer is not a BlankLayer");
-        } else if (!(this.layers.get(layer + 2) instanceof BlankLayer)) {
-            System.out.println("The next layer is not a BlankLayer");
-        }
-        ((BlankLayer) this.layers.get(layer)).addNeuron(n);
-        ((BlankLayer) this.layers.get(layer + 2)).updateInputSize(n);
-    }
 
     public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
         this.loss = loss;
         this.lossPrime = lossPrime;
@@ -80,28 +65,4 @@ public class Network {
             System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
         }
     }
-
-    public ArrayList<Layer> getLayers() {
-        return layers;
-    }
-
-    public void setLayers(ArrayList<Layer> layers) {
-        this.layers = layers;
-    }
-
-    public BiFunction<SimpleMatrix, SimpleMatrix, Double> getLoss() {
-        return loss;
-    }
-
-    public void setLoss(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss) {
-        this.loss = loss;
-    }
-
-    public BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> getLossPrime() {
-        return lossPrime;
-    }
-
-    public void setLossPrime(BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
-        this.lossPrime = lossPrime;
-    }
 }
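Note: besides renaming addLayer to add and dropping addNeuron, this change makes layers, loss and lossPrime package-private and removes their accessors, so they are only reachable from code in Network's own package. A brief sketch of how call sites shift, assuming the repository's Network and FCLayer classes as shown in this compare (a fragment, not the repository's code):

    // Before (4766ea0ad9): addLayer plus getters/setters.
    Network before = new Network();
    before.addLayer(new FCLayer(2, 3));
    int oldCount = before.getLayers().size();

    // After (7e80e5bc94): method renamed to add, fields read directly; this
    // only compiles from the same package as Network.
    Network after = new Network();
    after.add(new FCLayer(2, 3));
    int newCount = after.layers.size();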
Utilities.java (deleted)
@@ -1,33 +0,0 @@
-import com.opencsv.CSVReader;
-import com.opencsv.exceptions.CsvValidationException;
-import org.ejml.simple.SimpleMatrix;
-
-import java.io.FileReader;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-public class Utilities {
-    public static SimpleMatrix ones(int rows, int columns) {
-        SimpleMatrix mat = new SimpleMatrix(rows, columns);
-        Arrays.fill(mat.getDDRM().data, 1);
-        return mat;
-    }
-
-    public static List<List<String>> readCSV(String filename) {
-        List<List<String>> entries = new ArrayList<>();
-        try (CSVReader csvReader = new CSVReader(new FileReader(filename))) {
-            String[] values;
-            while ((values = csvReader.readNext()) != null) {
-                entries.add(Arrays.asList(values));
-            }
-            return entries;
-        } catch (IOException e) {
-            System.out.println(filename + " does not exist");
-        } catch (CsvValidationException e) {
-            System.out.println("Invalid line in " + filename);
-        }
-        return null;
-    }
-}
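Note: the deleted Utilities.readCSV was the code that used opencsv, the same dependency this compare removes from build.gradle, so the two removals are consistent. If equivalent CSV loading is ever needed without that dependency, one possible dependency-free sketch (hypothetical class and file names; plain String.split, so quoted fields and embedded commas are not handled):

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    // Hypothetical stand-in for the deleted Utilities.readCSV.
    public class SimpleCsv {
        public static List<List<String>> readCSV(String filename) {
            List<List<String>> entries = new ArrayList<>();
            try {
                for (String line : Files.readAllLines(Path.of(filename))) {
                    entries.add(Arrays.asList(line.split(",", -1)));
                }
                return entries;
            } catch (IOException e) {
                System.out.println(filename + " could not be read");
                return null;
            }
        }

        public static void main(String[] args) {
            List<List<String>> rows = readCSV("data.csv"); // hypothetical file name
            if (rows != null) {
                rows.forEach(System.out::println);
            }
        }
    }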