Added support for choosing whether the step size should decrease for each subsequent epoch
parent faa547564c
commit c7154817ee
4 changed files with 10 additions and 5 deletions
@@ -79,8 +79,9 @@ public class Network {
      * @param y_train labels
      * @param epochs amount of training iterations
      * @param learningRate step size of gradient descent
+     * @param optimize if step size should decrease for each subsequent epoch
      */
-    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate) {
+    public void fit(SimpleMatrix[] X_train, SimpleMatrix[] y_train, int epochs, double learningRate, boolean optimize) {
         int samples = X_train.length;
 
         for (int i = 0; i < epochs; i++) {
@@ -98,7 +99,11 @@ public class Network {
             // backward propagation
             SimpleMatrix error = lossPrime.apply(y_train[j], output);
             for (int k = layers.size() - 1; k >= 0; k--) {
-                error = layers.get(k).backwardPropagation(error, learningRate);
+                if (optimize) {
+                    error = layers.get(k).backwardPropagation(error, learningRate / (i+1));
+                } else {
+                    error = layers.get(k).backwardPropagation(error, learningRate);
+                }
             }
         }
         // calculate average error on all samples
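The schedule the new flag enables is a simple harmonic (1/t) decay: epoch i (zero-based) scales the base rate to learningRate / (i+1), so the first epoch uses the full rate and later epochs take progressively smaller steps. A minimal standalone sketch of that schedule follows; the class name is hypothetical and not part of this repository.

// Standalone sketch (not part of this repository) printing the effective
// step size that fit() uses per epoch when `optimize` is true.
public class DecayScheduleSketch {
    public static void main(String[] args) {
        double learningRate = 0.05d; // base rate, same as ExampleSine below
        int epochs = 5;
        for (int i = 0; i < epochs; i++) {
            // matches the diff: learningRate / (i+1)
            System.out.printf("epoch %d: step size %.5f%n", i + 1, learningRate / (i + 1));
        }
    }
}

Because the divisor is i + 1 rather than i, the schedule is well defined at the first epoch, where it equals the unmodified learning rate.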
@@ -57,7 +57,7 @@ public class ExampleSine {
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
 
         // train network on X_train and y_train
-        network.fit(X_train, y_train, 100, 0.05d);
+        network.fit(X_train, y_train, 100, 0.05d, true);
 
         // predict X_test and output results to console
         SimpleMatrix[] output = network.predict(X_test);
@@ -26,7 +26,7 @@ public class ExampleXOR {
         network.addLayer(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
 
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {
@@ -27,7 +27,7 @@ public class ExampleXORAddedNeurons {
         network.addNeuron(0, 2);
 
         network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
-        network.fit(X_train, y_train, 1000, 0.1d);
+        network.fit(X_train, y_train, 1000, 0.1d, false);
 
         SimpleMatrix[] output = network.predict(X_train);
         for (SimpleMatrix entry : output) {
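To see what toggling the flag changes in practice, here is a toy, self-contained comparison that is independent of the repository's Network class (plain doubles stand in for the matrices): it minimizes f(x) = x^2, whose gradient is 2x, under both step-size schedules.

// Toy 1-D gradient descent (hypothetical, not part of this repository)
// comparing a constant step size with the 1/(i+1) decay that `optimize`
// enables; both runs start from x = 1.
public class FixedVsDecayedSteps {
    public static void main(String[] args) {
        double learningRate = 0.1d;
        double fixed = 1.0d, decayed = 1.0d;
        for (int i = 0; i < 10; i++) {
            fixed -= learningRate * 2 * fixed;                 // constant step
            decayed -= (learningRate / (i + 1)) * 2 * decayed; // shrinking step
        }
        System.out.printf("fixed: %.6f  decayed: %.6f%n", fixed, decayed);
    }
}

After the first epoch the decayed run moves more cautiously, which is the usual motivation for such schedules: large early steps for fast progress, small late steps for stability near a minimum.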