diff --git a/src/main/java/de/lluni/javann/Network.java b/src/main/java/de/lluni/javann/Network.java index 680cd36..a7c5f40 100644 --- a/src/main/java/de/lluni/javann/Network.java +++ b/src/main/java/de/lluni/javann/Network.java @@ -76,7 +76,7 @@ public class Network { // backward propagation SimpleMatrix error = lossPrime.apply(y_train[j], output); for (int k = layers.size() - 1; k >= 0; k--) { - error = layers.get(k).backwardPropagation(error, learningRate); + error = layers.get(k).backwardPropagation(error, learningRate / (i+1)); } } // calculate average error on all samples