Added JavaDoc comments

This commit is contained in:
lluni 2022-05-28 02:10:02 +02:00
parent e19dec4af9
commit 67d94efc39
4 changed files with 56 additions and 2 deletions

View file

@ -21,8 +21,7 @@ public class Network {
} }
/** /**
* Adds n neurons to a specific layer and also updates this and the next layer's weights and biases. * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases
* Only works if there are two successive BlankLayers.
* @param layer index of layer in the ArrayList layers * @param layer index of layer in the ArrayList layers
* @param n amount how many new neurons should be added * @param n amount how many new neurons should be added
*/ */
@ -36,11 +35,22 @@ public class Network {
((FCLayer) this.layers.get(layer + 2)).updateInputSize(n); ((FCLayer) this.layers.get(layer + 2)).updateInputSize(n);
} }
/**
* Sets loss function used to evaluate the network
* Chosen from {@link de.lluni.javann.functions.LossFunctions}
* @param loss loss function
* @param lossPrime derivative of the loss function
*/
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) { public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
this.loss = loss; this.loss = loss;
this.lossPrime = lossPrime; this.lossPrime = lossPrime;
} }
/**
* Predicts labels corresponding to the given inputs
* @param inputs inputs
* @return array of predictions
*/
public SimpleMatrix[] predict(SimpleMatrix[] inputs) { public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
SimpleMatrix[] result = new SimpleMatrix[inputs.length]; SimpleMatrix[] result = new SimpleMatrix[inputs.length];
SimpleMatrix output; SimpleMatrix output;
@ -58,6 +68,13 @@ public class Network {
return result; return result;
} }
/**
* Trains the network on a set of training data
* @param X_train inputs
* @param y_train labels
* @param epochs amount of training iterations
* @param learningRate step size of gradient descent
*/
public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) { public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
int samples = X_train.length; int samples = X_train.length;

View file

@ -4,10 +4,18 @@ import org.ejml.simple.SimpleMatrix;
import java.util.function.Function; import java.util.function.Function;
/**
* Layer to pass each output of every neuron from the previous layer through an activation function
*/
public class ActivationLayer extends Layer { public class ActivationLayer extends Layer {
Function<SimpleMatrix, SimpleMatrix> activation; Function<SimpleMatrix, SimpleMatrix> activation;
Function<SimpleMatrix, SimpleMatrix> activationPrime; Function<SimpleMatrix, SimpleMatrix> activationPrime;
/**
* Creates new activation layer with an activation function and its derivative from {@link de.lluni.javann.functions.ActivationFunctions}
* @param activation activation function
* @param activationPrime derivative of the activation function to compute gradient for backpropagation
*/
public ActivationLayer(Function<SimpleMatrix, SimpleMatrix> activation, Function<SimpleMatrix, SimpleMatrix> activationPrime) { public ActivationLayer(Function<SimpleMatrix, SimpleMatrix> activation, Function<SimpleMatrix, SimpleMatrix> activationPrime) {
this.activation = activation; this.activation = activation;
this.activationPrime = activationPrime; this.activationPrime = activationPrime;

View file

@ -5,12 +5,19 @@ import org.ejml.simple.SimpleMatrix;
import java.util.Random; import java.util.Random;
/**
* Fully connected layer with n Neurons
*/
public class FCLayer extends Layer { public class FCLayer extends Layer {
private SimpleMatrix weights; private SimpleMatrix weights;
private SimpleMatrix biases; private SimpleMatrix biases;
private int numNeurons; private int numNeurons;
private boolean isInitialized; private boolean isInitialized;
/**
* Creates a fully connected layer with numNeurons neurons
* @param numNeurons amount of neurons in this layer
*/
public FCLayer(int numNeurons) { public FCLayer(int numNeurons) {
this.numNeurons = numNeurons; this.numNeurons = numNeurons;
isInitialized = false; isInitialized = false;

View file

@ -14,12 +14,27 @@ import java.util.Random;
public class Utilities { public class Utilities {
private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d; private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d;
/**
* Creates a matrix filled with ones
* @param rows amount of rows
* @param columns amount of columns
* @return matrix filled with ones
*/
public static SimpleMatrix ones(int rows, int columns) { public static SimpleMatrix ones(int rows, int columns) {
SimpleMatrix mat = new SimpleMatrix(rows, columns); SimpleMatrix mat = new SimpleMatrix(rows, columns);
Arrays.fill(mat.getDDRM().data, 1); Arrays.fill(mat.getDDRM().data, 1);
return mat; return mat;
} }
/**
* Creates a matrix with Gaussian distributed values
* @param rows amount of rows
* @param columns amount of columns
* @param mean mean of Gaussian distribution
* @param stddev standard deviation of Gaussian distribution
* @param factor factor to multiply each Gaussian distributed value with
* @return matrix of Gaussian distributed values
*/
public static SimpleMatrix gaussianMatrix(int rows, int columns, double mean, double stddev, double factor) { public static SimpleMatrix gaussianMatrix(int rows, int columns, double mean, double stddev, double factor) {
SimpleMatrix mat = new SimpleMatrix(rows, columns); SimpleMatrix mat = new SimpleMatrix(rows, columns);
Random random = new Random(); Random random = new Random();
@ -35,6 +50,13 @@ public class Utilities {
return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR); return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR);
} }
/**
* Creates an array of evenly spaced values from the interval [start, end)
* @param start start value
* @param end end value
* @param num amount of values
* @return array of values
*/
public static double[] linspace(double start, double end, int num) { public static double[] linspace(double start, double end, int num) {
double[] result = new double[num]; double[] result = new double[num];
double stepSize = Math.abs(end - start) / num; double stepSize = Math.abs(end - start) / num;