Compare commits

...

3 commits

2 changed files with 96 additions and 13 deletions

View file

@ -0,0 +1,16 @@
import org.ejml.simple.SimpleMatrix;
import java.util.function.Function;
public class ExampleGradientDescent {
public static void main(String[] args) {
GradientDescent gd = new GradientDescent();
Function<Double, Double> f = x -> x*x;
System.out.println(gd.findLocalMinimum(f, 1));
Function<SimpleMatrix, SimpleMatrix> g = x -> x.elementMult(x);
SimpleMatrix initialX = new SimpleMatrix(2, 1, true, new double[]{1, 0.5});
System.out.println(gd.findLocalMinimum(g, initialX));
}
}

View file

@ -1,37 +1,104 @@
import org.ejml.simple.SimpleMatrix;
import java.util.function.Function; import java.util.function.Function;
public class GradientDescent { public class GradientDescent {
private static final double STANDARD_PRECISION = 0.000001;
private static final double STANDARD_STEP_COEFFICIENT = 0.5;
private static final int STANDARD_MAX_ITERATIONS = 1000;
private final double precision = 0.000001; private double precision;
private double stepCoefficient;
public double findLocalMinimum(Function<Double, Double> f, double initialX) { public GradientDescent(double precision, double stepCoefficient) {
double stepCoefficient = 0.5; this.precision = precision;
this.stepCoefficient = stepCoefficient;
}
public GradientDescent() {
this(STANDARD_PRECISION, STANDARD_STEP_COEFFICIENT);
}
/**
* Performs gradient descent on a function f: ->
* @param f real-valued function
* @param initialX initial X vector
* @param maxIterations maximum number of iterations
* @return approximation of the nearest local minimum
*/
public double findLocalMinimum(Function<Double, Double> f, double initialX, int maxIterations) {
double previousStep = 1.0; double previousStep = 1.0;
double currentX = initialX; double currentX = initialX;
double previousX = initialX; double previousX = initialX;
double previousY = f.apply(previousX); double previousY = f.apply(previousX);
int iter = 1000;
currentX += stepCoefficient * previousY; currentX += this.stepCoefficient * previousY;
while (previousStep > precision && iter > 0) { while (previousStep > this.precision && maxIterations > 0) {
iter--; maxIterations--;
double currentY = f.apply(currentX); double currentY = f.apply(currentX);
if (currentY > previousY) { if (currentY > previousY) {
stepCoefficient = -stepCoefficient / 2; this.stepCoefficient = -this.stepCoefficient / 2;
} }
previousX = currentX; previousX = currentX;
currentX += stepCoefficient * previousY; currentX += this.stepCoefficient * previousY;
previousY = currentY; previousY = currentY;
previousStep = StrictMath.abs(currentX - previousX); previousStep = StrictMath.abs(currentX - previousX);
} }
return currentX; return currentX;
} }
public static void main(String[] args) { public double findLocalMinimum(Function<Double, Double> f, double initialX) {
GradientDescent gd = new GradientDescent(); return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS);
Function<Double, Double> f = x -> x*x; }
System.out.println(gd.findLocalMinimum(f, 1)); /**
* Performs gradient descent on a function f: ℝⁿ -> ℝⁿ.
* @param f vector-valued function
* @param initialX initial X vector
* @param maxIterations maximum number of iterations
* @return approximation of the nearest local minimum
*/
public SimpleMatrix findLocalMinimum(Function<SimpleMatrix, SimpleMatrix> f,
SimpleMatrix initialX, int maxIterations) {
double previousStep = 1.0;
SimpleMatrix currentX = initialX;
SimpleMatrix previousX = initialX;
SimpleMatrix previousY = f.apply(previousX);
currentX = currentX.plus(this.stepCoefficient, previousY);
while (previousStep > this.precision && maxIterations > 0) {
maxIterations--;
SimpleMatrix currentY = f.apply(currentX);
if (currentY.normF() > previousY.normF()) {
this.stepCoefficient = -this.stepCoefficient / 2;
}
previousX = currentX;
currentX = currentX.plus(this.stepCoefficient, previousY);
previousY = currentY;
previousStep = currentX.minus(previousX).normF();
}
return currentX;
}
public SimpleMatrix findLocalMinimum(Function<SimpleMatrix, SimpleMatrix> f, SimpleMatrix initialX) {
return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS);
}
public double getPrecision() {
return precision;
}
public void setPrecision(double precision) {
this.precision = precision <= 0 ? STANDARD_PRECISION : precision;
}
public double getStepCoefficient() {
return stepCoefficient;
}
public void setStepCoefficient(double stepCoefficient) {
this.stepCoefficient = stepCoefficient;
} }
} }