diff --git a/src/main/java/de/lluni/javann/Network.java b/src/main/java/de/lluni/javann/Network.java index f03555b..d62648d 100644 --- a/src/main/java/de/lluni/javann/Network.java +++ b/src/main/java/de/lluni/javann/Network.java @@ -79,8 +79,9 @@ public class Network { * @param y_train labels * @param epochs amount of training iterations * @param learningRate step size of gradient descent + * @param optimize if step size should decrease for each subsequent epoch */ - public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) { + public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) { int samples = X_train.length; for (int i = 0; i < epochs; i++) { @@ -98,7 +99,11 @@ public class Network { // backward propagation SimpleMatrix error = lossPrime.apply(y_train[j], output); for (int k = layers.size() - 1; k >= 0; k--) { - error = layers.get(k).backwardPropagation(error, learningRate / (i+1)); + if (optimize) { + error = layers.get(k).backwardPropagation(error, learningRate / (i+1)); + } else { + error = layers.get(k).backwardPropagation(error, learningRate); + } } } // calculate average error on all samples diff --git a/src/main/java/de/lluni/javann/examples/ExampleSine.java b/src/main/java/de/lluni/javann/examples/ExampleSine.java index c0701d2..1572b0a 100644 --- a/src/main/java/de/lluni/javann/examples/ExampleSine.java +++ b/src/main/java/de/lluni/javann/examples/ExampleSine.java @@ -57,7 +57,7 @@ public class ExampleSine { network.use(LossFunctions::MSE, LossFunctions::MSEPrime); // train network on X_train and y_train - network.fit(X_train, y_train, 100, 0.05d); + network.fit(X_train, y_train, 100, 0.05d, true); // predict X_test and output results to console SimpleMatrix[] output = network.predict(X_test); diff --git a/src/main/java/de/lluni/javann/examples/ExampleXOR.java b/src/main/java/de/lluni/javann/examples/ExampleXOR.java index f647673..7f4d56f 100644 --- a/src/main/java/de/lluni/javann/examples/ExampleXOR.java +++ b/src/main/java/de/lluni/javann/examples/ExampleXOR.java @@ -26,7 +26,7 @@ public class ExampleXOR { network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime)); network.use(LossFunctions::MSE, LossFunctions::MSEPrime); - network.fit(X_train, y_train, 1000, 0.1d); + network.fit(X_train, y_train, 1000, 0.1d, false); SimpleMatrix[] output = network.predict(X_train); for (SimpleMatrix entry : output) { diff --git a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java index b19ab58..af8c31a 100644 --- a/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java +++ b/src/main/java/de/lluni/javann/examples/ExampleXORAddedNeurons.java @@ -27,7 +27,7 @@ public class ExampleXORAddedNeurons { network.addNeuron(0, 2); network.use(LossFunctions::MSE, LossFunctions::MSEPrime); - network.fit(X_train, y_train, 1000, 0.1d); + network.fit(X_train, y_train, 1000, 0.1d, false); SimpleMatrix[] output = network.predict(X_train); for (SimpleMatrix entry : output) {