From db0481e9cf8b36eb6404ed43a1ac98b5c3f2e6e5 Mon Sep 17 00:00:00 2001 From: lluni Date: Wed, 25 May 2022 02:39:18 +0200 Subject: [PATCH] Added gradient descent for vector-valued functions --- src/main/java/GradientDescent.java | 98 +++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 10 deletions(-) diff --git a/src/main/java/GradientDescent.java b/src/main/java/GradientDescent.java index cd0cb88..d018323 100644 --- a/src/main/java/GradientDescent.java +++ b/src/main/java/GradientDescent.java @@ -1,37 +1,115 @@ +import org.ejml.simple.SimpleMatrix; + import java.util.function.Function; public class GradientDescent { + private static final double STANDARD_PRECISION = 0.000001; + private static final double STANDARD_STEP_COEFFICIENT = 0.5; + private static final int STANDARD_MAX_ITERATIONS = 1000; - private final double precision = 0.000001; + private double precision; + private double stepCoefficient; - public double findLocalMinimum(Function f, double initialX) { - double stepCoefficient = 0.5; + public GradientDescent(double precision, double stepCoefficient) { + this.precision = precision; + this.stepCoefficient = stepCoefficient; + } + + public GradientDescent() { + this(STANDARD_PRECISION, STANDARD_STEP_COEFFICIENT); + } + + /** + * Performs gradient descent on a function f: ℝ -> ℝ + * @param f vector-valued function + * @param initialX initial X vector + * @param maxIterations maximum number of iterations + * @return approximation of the nearest local minimum + */ + public double findLocalMinimum(Function f, double initialX, int maxIterations) { double previousStep = 1.0; double currentX = initialX; double previousX = initialX; double previousY = f.apply(previousX); - int iter = 1000; - currentX += stepCoefficient * previousY; + currentX += this.stepCoefficient * previousY; - while (previousStep > precision && iter > 0) { - iter--; + while (previousStep > this.precision && maxIterations > 0) { + maxIterations--; double currentY = f.apply(currentX); if (currentY > previousY) { - stepCoefficient = -stepCoefficient / 2; + this.stepCoefficient = -this.stepCoefficient / 2; } previousX = currentX; - currentX += stepCoefficient * previousY; + currentX += this.stepCoefficient * previousY; previousY = currentY; previousStep = StrictMath.abs(currentX - previousX); } return currentX; } + public double findLocalMinimum(Function f, double initialX) { + return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS); + } + + /** + * Performs gradient descent on a function f: ℝⁿ -> ℝⁿ. + * @param f vector-valued function + * @param initialX initial X vector + * @param maxIterations maximum number of iterations + * @return approximation of the nearest local minimum + */ + public SimpleMatrix findLocalMinimum(Function f, + SimpleMatrix initialX, int maxIterations) { + double previousStep = 1.0; + SimpleMatrix currentX = initialX; + SimpleMatrix previousX = initialX; + SimpleMatrix previousY = f.apply(previousX); + + currentX = currentX.plus(this.stepCoefficient, previousY); + + while (previousStep > this.precision && maxIterations > 0) { + maxIterations--; + SimpleMatrix currentY = f.apply(currentX); + if (currentY.normF() > previousY.normF()) { + this.stepCoefficient = -this.stepCoefficient / 2; + } + previousX = currentX; + currentX = currentX.plus(this.stepCoefficient, previousY); + previousY = currentY; + previousStep = currentX.minus(previousX).normF(); + } + return currentX; + } + + public SimpleMatrix findLocalMinimum(Function f, SimpleMatrix initialX) { + return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS); + } + public static void main(String[] args) { GradientDescent gd = new GradientDescent(); - Function f = x -> x*x; + Function f = x -> x*x; System.out.println(gd.findLocalMinimum(f, 1)); + + Function g = x -> x.elementMult(x); + SimpleMatrix initialX = new SimpleMatrix(2, 1, true, new double[]{1, 0.5}); + System.out.println(gd.findLocalMinimum(g, initialX)); + } + + public double getPrecision() { + return precision; + } + + public void setPrecision(double precision) { + this.precision = precision <= 0 ? STANDARD_PRECISION : precision; + } + + public double getStepCoefficient() { + return stepCoefficient; + } + + public void setStepCoefficient(double stepCoefficient) { + this.stepCoefficient = stepCoefficient; } }