Compare commits
No commits in common. "ffcf9fa975a02e53b19e838d00566750e2026134" and "1c66f1b72feba721e1f34bf9fcfaaa6d9f6db038" have entirely different histories.
ffcf9fa975
...
1c66f1b72f
2 changed files with 13 additions and 96 deletions
|
@ -1,16 +0,0 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
|
||||||
|
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
public class ExampleGradientDescent {
|
|
||||||
public static void main(String[] args) {
|
|
||||||
GradientDescent gd = new GradientDescent();
|
|
||||||
|
|
||||||
Function<Double, Double> f = x -> x*x;
|
|
||||||
System.out.println(gd.findLocalMinimum(f, 1));
|
|
||||||
|
|
||||||
Function<SimpleMatrix, SimpleMatrix> g = x -> x.elementMult(x);
|
|
||||||
SimpleMatrix initialX = new SimpleMatrix(2, 1, true, new double[]{1, 0.5});
|
|
||||||
System.out.println(gd.findLocalMinimum(g, initialX));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,104 +1,37 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
|
||||||
|
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
public class GradientDescent {
|
public class GradientDescent {
|
||||||
private static final double STANDARD_PRECISION = 0.000001;
|
|
||||||
private static final double STANDARD_STEP_COEFFICIENT = 0.5;
|
|
||||||
private static final int STANDARD_MAX_ITERATIONS = 1000;
|
|
||||||
|
|
||||||
private double precision;
|
private final double precision = 0.000001;
|
||||||
private double stepCoefficient;
|
|
||||||
|
|
||||||
public GradientDescent(double precision, double stepCoefficient) {
|
public double findLocalMinimum(Function<Double, Double> f, double initialX) {
|
||||||
this.precision = precision;
|
double stepCoefficient = 0.5;
|
||||||
this.stepCoefficient = stepCoefficient;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GradientDescent() {
|
|
||||||
this(STANDARD_PRECISION, STANDARD_STEP_COEFFICIENT);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs gradient descent on a function f: ℝ -> ℝ
|
|
||||||
* @param f real-valued function
|
|
||||||
* @param initialX initial X vector
|
|
||||||
* @param maxIterations maximum number of iterations
|
|
||||||
* @return approximation of the nearest local minimum
|
|
||||||
*/
|
|
||||||
public double findLocalMinimum(Function<Double, Double> f, double initialX, int maxIterations) {
|
|
||||||
double previousStep = 1.0;
|
double previousStep = 1.0;
|
||||||
double currentX = initialX;
|
double currentX = initialX;
|
||||||
double previousX = initialX;
|
double previousX = initialX;
|
||||||
double previousY = f.apply(previousX);
|
double previousY = f.apply(previousX);
|
||||||
|
int iter = 1000;
|
||||||
|
|
||||||
currentX += this.stepCoefficient * previousY;
|
currentX += stepCoefficient * previousY;
|
||||||
|
|
||||||
while (previousStep > this.precision && maxIterations > 0) {
|
while (previousStep > precision && iter > 0) {
|
||||||
maxIterations--;
|
iter--;
|
||||||
double currentY = f.apply(currentX);
|
double currentY = f.apply(currentX);
|
||||||
if (currentY > previousY) {
|
if (currentY > previousY) {
|
||||||
this.stepCoefficient = -this.stepCoefficient / 2;
|
stepCoefficient = -stepCoefficient / 2;
|
||||||
}
|
}
|
||||||
previousX = currentX;
|
previousX = currentX;
|
||||||
currentX += this.stepCoefficient * previousY;
|
currentX += stepCoefficient * previousY;
|
||||||
previousY = currentY;
|
previousY = currentY;
|
||||||
previousStep = StrictMath.abs(currentX - previousX);
|
previousStep = StrictMath.abs(currentX - previousX);
|
||||||
}
|
}
|
||||||
return currentX;
|
return currentX;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double findLocalMinimum(Function<Double, Double> f, double initialX) {
|
public static void main(String[] args) {
|
||||||
return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS);
|
GradientDescent gd = new GradientDescent();
|
||||||
}
|
Function<Double, Double> f = x -> x*x;
|
||||||
|
|
||||||
/**
|
System.out.println(gd.findLocalMinimum(f, 1));
|
||||||
* Performs gradient descent on a function f: ℝⁿ -> ℝⁿ.
|
|
||||||
* @param f vector-valued function
|
|
||||||
* @param initialX initial X vector
|
|
||||||
* @param maxIterations maximum number of iterations
|
|
||||||
* @return approximation of the nearest local minimum
|
|
||||||
*/
|
|
||||||
public SimpleMatrix findLocalMinimum(Function<SimpleMatrix, SimpleMatrix> f,
|
|
||||||
SimpleMatrix initialX, int maxIterations) {
|
|
||||||
double previousStep = 1.0;
|
|
||||||
SimpleMatrix currentX = initialX;
|
|
||||||
SimpleMatrix previousX = initialX;
|
|
||||||
SimpleMatrix previousY = f.apply(previousX);
|
|
||||||
|
|
||||||
currentX = currentX.plus(this.stepCoefficient, previousY);
|
|
||||||
|
|
||||||
while (previousStep > this.precision && maxIterations > 0) {
|
|
||||||
maxIterations--;
|
|
||||||
SimpleMatrix currentY = f.apply(currentX);
|
|
||||||
if (currentY.normF() > previousY.normF()) {
|
|
||||||
this.stepCoefficient = -this.stepCoefficient / 2;
|
|
||||||
}
|
|
||||||
previousX = currentX;
|
|
||||||
currentX = currentX.plus(this.stepCoefficient, previousY);
|
|
||||||
previousY = currentY;
|
|
||||||
previousStep = currentX.minus(previousX).normF();
|
|
||||||
}
|
|
||||||
return currentX;
|
|
||||||
}
|
|
||||||
|
|
||||||
public SimpleMatrix findLocalMinimum(Function<SimpleMatrix, SimpleMatrix> f, SimpleMatrix initialX) {
|
|
||||||
return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getPrecision() {
|
|
||||||
return precision;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setPrecision(double precision) {
|
|
||||||
this.precision = precision <= 0 ? STANDARD_PRECISION : precision;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getStepCoefficient() {
|
|
||||||
return stepCoefficient;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setStepCoefficient(double stepCoefficient) {
|
|
||||||
this.stepCoefficient = stepCoefficient;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue