38 lines
1.3 KiB
Java
38 lines
1.3 KiB
Java
import org.ejml.simple.SimpleMatrix;
|
|
|
|
import java.util.Random;
|
|
|
|
/**
 * A dense layer holding a single neuron that is fully connected to every
 * output of the previous layer (full connectivity is assumed for now and
 * will be made configurable in the future).
 *
 * <p>Intended future goal: allow constructing the layer with zero neurons and
 * growing it dynamically — not yet implemented; currently the layer is always
 * initialized with exactly one neuron.
 */
|
|
public class BlankLayer extends Layer {
|
|
SimpleMatrix weights;
|
|
SimpleMatrix biases;
|
|
|
|
public BlankLayer(int inputSize) {
|
|
Random random = new Random();
|
|
this.weights = new SimpleMatrix(inputSize, 1, true,
|
|
random.doubles(inputSize, -1, 1).toArray());
|
|
this.biases = new SimpleMatrix(1, 1, true,
|
|
random.doubles(1, -1, 1).toArray());
|
|
}
|
|
|
|
@Override
|
|
public SimpleMatrix forwardPropagation(SimpleMatrix inputs) {
|
|
this.input = inputs;
|
|
this.output = this.input.mult(this.weights).plus(this.biases);
|
|
return this.output;
|
|
}
|
|
|
|
@Override
|
|
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
|
SimpleMatrix inputError = outputError.mult(this.weights.transpose());
|
|
SimpleMatrix weightsError = this.input.transpose().mult(outputError);
|
|
|
|
this.weights = this.weights.plus(learningRate, weightsError);
|
|
this.biases = this.biases.plus(learningRate, outputError);
|
|
return inputError;
|
|
}
|
|
}
|