From 67d94efc39bc23a6968499da2a9a92ce1d8b1b73 Mon Sep 17 00:00:00 2001 From: lluni Date: Sat, 28 May 2022 02:10:02 +0200 Subject: [PATCH] Added JavaDoc comments --- src/main/java/de/lluni/javann/Network.java | 21 ++++++++++++++++-- .../lluni/javann/layers/ActivationLayer.java | 8 +++++++ .../java/de/lluni/javann/layers/FCLayer.java | 7 ++++++ .../java/de/lluni/javann/util/Utilities.java | 22 +++++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/main/java/de/lluni/javann/Network.java b/src/main/java/de/lluni/javann/Network.java index a7c5f40..d7e2f6d 100644 --- a/src/main/java/de/lluni/javann/Network.java +++ b/src/main/java/de/lluni/javann/Network.java @@ -21,8 +21,7 @@ public class Network { } /** - * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases. - * Only works if there are two successive BlankLayers. + * Adds n neurons to a specific layer and also updates this and the next layer's weights and biases * @param layer index of layer in the ArrayList layers * @param n amount how many new neurons should be added */ @@ -36,11 +35,22 @@ public class Network { ((FCLayer) this.layers.get(layer + 2)).updateInputSize(n); } + /** + * Sets loss function used to evaluate the network + * Chosen from {@link de.lluni.javann.functions.LossFunctions} + * @param loss loss function + * @param lossPrime derivative of the loss function + */ public void use(BiFunction loss, BiFunction lossPrime) { this.loss = loss; this.lossPrime = lossPrime; } + /** + * Predicts labels corresponding to the given inputs + * @param inputs inputs + * @return array of predictions + */ public SimpleMatrix[] predict(SimpleMatrix[] inputs) { SimpleMatrix[] result = new SimpleMatrix[inputs.length]; SimpleMatrix output; @@ -58,6 +68,13 @@ public class Network { return result; } + /** + * Trains the network on a set of training data + * @param X_train inputs + * @param y_train labels + * @param epochs amount of training iterations + * @param learningRate step size of gradient descent + */ public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) { int samples = X_train.length; diff --git a/src/main/java/de/lluni/javann/layers/ActivationLayer.java b/src/main/java/de/lluni/javann/layers/ActivationLayer.java index 49c193c..b274b0c 100644 --- a/src/main/java/de/lluni/javann/layers/ActivationLayer.java +++ b/src/main/java/de/lluni/javann/layers/ActivationLayer.java @@ -4,10 +4,18 @@ import org.ejml.simple.SimpleMatrix; import java.util.function.Function; +/** + * Layer to pass each output of every neuron from the previous layer through an activation function + */ public class ActivationLayer extends Layer { Function activation; Function activationPrime; + /** + * Creates new activation layer with an activation function and its derivative from {@link de.lluni.javann.functions.ActivationFunctions} + * @param activation activation function + * @param activationPrime derivative of the activation function to compute gradient for backpropagation + */ public ActivationLayer(Function activation, Function activationPrime) { this.activation = activation; this.activationPrime = activationPrime; diff --git a/src/main/java/de/lluni/javann/layers/FCLayer.java b/src/main/java/de/lluni/javann/layers/FCLayer.java index 0f33a9a..21bdc7b 100644 --- a/src/main/java/de/lluni/javann/layers/FCLayer.java +++ b/src/main/java/de/lluni/javann/layers/FCLayer.java @@ -5,12 +5,19 @@ import org.ejml.simple.SimpleMatrix; import java.util.Random; +/** + * Fully connected layer with n Neurons + */ public class FCLayer extends Layer { private SimpleMatrix weights; private SimpleMatrix biases; private int numNeurons; private boolean isInitialized; + /** + * Creates a fully connected layer with numNeurons neurons + * @param numNeurons amount of neurons in this layer + */ public FCLayer(int numNeurons) { this.numNeurons = numNeurons; isInitialized = false; diff --git a/src/main/java/de/lluni/javann/util/Utilities.java b/src/main/java/de/lluni/javann/util/Utilities.java index 2daf693..cadd566 100644 --- a/src/main/java/de/lluni/javann/util/Utilities.java +++ b/src/main/java/de/lluni/javann/util/Utilities.java @@ -14,12 +14,27 @@ import java.util.Random; public class Utilities { private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d; + /** + * Creates a matrix filled with ones + * @param rows amount of rows + * @param columns amount of columns + * @return matrix filled with ones + */ public static SimpleMatrix ones(int rows, int columns) { SimpleMatrix mat = new SimpleMatrix(rows, columns); Arrays.fill(mat.getDDRM().data, 1); return mat; } + /** + * Creates a matrix with Gaussian distributed values + * @param rows amount of rows + * @param columns amount of columns + * @param mean mean of Gaussian distribution + * @param stddev standard deviation of Gaussian distribution + * @param factor factor to multiply each Gaussian distributed value with + * @return matrix of Gaussian distributed values + */ public static SimpleMatrix gaussianMatrix(int rows, int columns, double mean, double stddev, double factor) { SimpleMatrix mat = new SimpleMatrix(rows, columns); Random random = new Random(); @@ -35,6 +50,13 @@ public class Utilities { return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR); } + /** + * Creates an array of evenly spaced values from the interval [start, end) + * @param start start value + * @param end end value + * @param num amount of values + * @return array of values + */ public static double[] linspace(double start, double end, int num) { double[] result = new double[num]; double stepSize = Math.abs(end - start) / num;