Added support for choosing whether the step size should decrease with each subsequent epoch

lluni 2022-05-28 03:28:00 +02:00
parent faa547564c
commit c7154817ee
4 changed files with 10 additions and 5 deletions

Network.java

@@ -79,8 +79,9 @@ public class Network {
      * @param y_train labels
      * @param epochs number of training iterations
      * @param learningRate step size of gradient descent
+     * @param optimize whether the step size should decrease with each subsequent epoch
      */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
        int samples = X_train.length;
        for (int i = 0; i < epochs; i++) {
@@ -98,7 +99,11 @@ public class Network {
                // backward propagation
                SimpleMatrix error = lossPrime.apply(y_train[j], output);
                for (int k = layers.size() - 1; k >= 0; k--) {
-                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                    if (optimize) {
+                        error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                    } else {
+                        error = layers.get(k).backwardPropagation(error, learningRate);
+                    }
                }
            }
            // calculate average error on all samples

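With optimize set to true, the step size passed to backwardPropagation follows a harmonic decay: zero-based epoch i trains with learningRate / (i+1), so ExampleSine's initial rate of 0.05 halves after the first epoch and falls to 0.0005 by epoch 100. The following standalone sketch of that schedule is illustrative only, not part of this commit; the class name DecaySchedule is made up:

// Illustrative only, not part of this commit: prints the effective
// step size that fit() uses in each epoch when optimize is true.
public class DecaySchedule {
    public static void main(String[] args) {
        double learningRate = 0.05; // initial rate, as in ExampleSine
        int epochs = 100;
        for (int i = 0; i < epochs; i++) {
            double effectiveRate = learningRate / (i + 1); // harmonic decay
            System.out.printf("epoch %d: step size %.6f%n", i + 1, effectiveRate);
        }
    }
}

Note that before this commit the rate always decayed, so passing true reproduces the old behavior exactly, while false is the new constant-rate option; that is why every existing call site below had to gain the extra argument.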
ExampleSine.java

@@ -57,7 +57,7 @@ public class ExampleSine {
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
        // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d);
+        network.fit(X_train, y_train, 100, 0.05d, true);
        // predict X_test and output results to console
        SimpleMatrix[] output = network.predict(X_test);

ExampleXOR.java

@@ -26,7 +26,7 @@ public class ExampleXOR {
        network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {

ExampleXORAddedNeurons.java

@@ -27,7 +27,7 @@ public class ExampleXORAddedNeurons {
        network.addNeuron(0, 2);
        network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
        SimpleMatrix[] output = network.predict(X_train);
        for (SimpleMatrix entry : output) {