diff --git a/src/main/java/Utilities.java b/src/main/java/Utilities.java index f444927..4c734fc 100644 --- a/src/main/java/Utilities.java +++ b/src/main/java/Utilities.java @@ -7,14 +7,45 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Random; public class Utilities { + private static final double STANDARD_GAUSSIAN_FACTOR = 1.0d; + public static SimpleMatrix ones(int rows, int columns) { SimpleMatrix mat = new SimpleMatrix(rows, columns); Arrays.fill(mat.getDDRM().data, 1); return mat; } + public static SimpleMatrix gaussianMatrix(int rows, int columns, double mean, double stddev, double factor) { + SimpleMatrix mat = new SimpleMatrix(rows, columns); + Random random = new Random(); + + for (int i = 0; i < mat.getNumElements(); i++) { + mat.set(i, factor * random.nextGaussian(mean, stddev)); + } + + return mat; + } + + public static SimpleMatrix gaussianMatrix(int rows, int columns, double mean, double stddev) { + return gaussianMatrix(rows, columns, mean, stddev, STANDARD_GAUSSIAN_FACTOR); + } + + public static double[] linspace(double start, double end, int num) { + double[] result = new double[num]; + double stepSize = Math.abs(end - start) / num; + double nextEntry = start; + + for (int i = 0; i < num; i++) { + result[i] = nextEntry; + nextEntry += stepSize; + } + + return result; + } + public static List> readCSV(String filename) { List> entries = new ArrayList<>(); try (CSVReader csvReader = new CSVReader(new FileReader(filename))) {