Added support for choosing whether the step size should decrease for each subsequent epoch

lluni 2022-05-28 03:28:00 +02:00
parent faa547564c
commit c7154817ee
4 changed files with 10 additions and 5 deletions

Network.java

@@ -79,8 +79,9 @@ public class Network {
      * @param y_train labels
      * @param epochs amount of training iterations
      * @param learningRate step size of gradient descent
+     * @param optimize if step size should decrease for each subsequent epoch
      */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
         int samples = X_train.length;
         for (int i = 0; i < epochs; i++) {
@@ -98,7 +99,11 @@ public class Network {
             // backward propagation
             SimpleMatrix error = lossPrime.apply(y_train[j], output);
             for (int k = layers.size() - 1; k >= 0; k--) {
-                error = layers.get(k).backwardPropagation(error, learningRate);
+                if (optimize) {
+                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                } else {
+                    error = layers.get(k).backwardPropagation(error, learningRate);
+                }
             }
         }
         // calculate average error on all samples
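For orientation, the new optimize flag applies a simple 1/t decay: at zero-based epoch i, the effective step size handed to backwardPropagation is learningRate / (i+1). A minimal standalone sketch of that schedule (the class DecayScheduleSketch and its output are illustrative, not part of this commit):

    public class DecayScheduleSketch {
        public static void main(String[] args) {
            double learningRate = 0.05d;  // same base rate as ExampleSine below
            int epochs = 5;
            for (int i = 0; i < epochs; i++) {
                // effective step size used in epoch i when optimize == true
                double effectiveRate = learningRate / (i + 1);
                System.out.printf("epoch %d: step size %.4f%n", i, effectiveRate);
            }
        }
    }

Run as-is, this prints 0.0500, 0.0250, 0.0167, 0.0125, 0.0100, i.e. a harmonic decay of the base rate across epochs.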

ExampleSine.java

@@ -57,7 +57,7 @@ public class ExampleSine {
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
         // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d);
+        network.fit(X_train, y_train, 100, 0.05d, true);
         // predict X_test and output results to console
         SimpleMatrix[] output = network.predict(X_test);

ExampleXOR.java

@@ -26,7 +26,7 @@ public class ExampleXOR {
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {

ExampleXORAddedNeurons.java

@@ -27,7 +27,7 @@ public class ExampleXORAddedNeurons {
         network.addNeuron(0, 2);
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {
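An observation on the call sites, not part of the commit itself: ExampleSine enables the decay (true) for its 100-epoch run, while both XOR examples keep a constant step size (false). Under the 1/(i+1) schedule, their 1000-epoch runs would end with a step size of 0.1 / 1000 = 0.0001, so holding the rate constant there is a plausible choice.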