diff --git a/src/main/java/ExampleGradientDescent.java b/src/main/java/ExampleGradientDescent.java new file mode 100644 index 0000000..efd98c7 --- /dev/null +++ b/src/main/java/ExampleGradientDescent.java @@ -0,0 +1,16 @@ +import org.ejml.simple.SimpleMatrix; + +import java.util.function.Function; + +public class ExampleGradientDescent { + public static void main(String[] args) { + GradientDescent gd = new GradientDescent(); + + Function f = x -> x*x; + System.out.println(gd.findLocalMinimum(f, 1)); + + Function g = x -> x.elementMult(x); + SimpleMatrix initialX = new SimpleMatrix(2, 1, true, new double[]{1, 0.5}); + System.out.println(gd.findLocalMinimum(g, initialX)); + } +} diff --git a/src/main/java/GradientDescent.java b/src/main/java/GradientDescent.java index d018323..f4e0e68 100644 --- a/src/main/java/GradientDescent.java +++ b/src/main/java/GradientDescent.java @@ -86,17 +86,6 @@ public class GradientDescent { return findLocalMinimum(f, initialX, STANDARD_MAX_ITERATIONS); } - public static void main(String[] args) { - GradientDescent gd = new GradientDescent(); - - Function f = x -> x*x; - System.out.println(gd.findLocalMinimum(f, 1)); - - Function g = x -> x.elementMult(x); - SimpleMatrix initialX = new SimpleMatrix(2, 1, true, new double[]{1, 0.5}); - System.out.println(gd.findLocalMinimum(g, initialX)); - } - public double getPrecision() { return precision; }