import org.ejml.simple.SimpleMatrix;

import java.util.Objects;
import java.util.function.Function;

/**
 * A network layer that applies a fixed, parameter-free function to its input.
 *
 * <p>The forward pass applies {@code activation} to the input and caches it. The backward pass
 * multiplies the incoming error element-wise by {@code activationPrime} evaluated at that cached
 * input. The layer holds no trainable parameters, so the learning rate is unused.
 *
 * <p>NOTE(review): assumes {@code Layer} declares {@code input} and {@code output} as
 * {@code SimpleMatrix} fields. Confirm against {@code Layer}.
 */
public class ActivationLayer extends Layer {
    /** Function applied to the layer input during the forward pass. */
    Function<? super SimpleMatrix, ? extends SimpleMatrix> activation;

    /**
     * Function evaluated at the cached forward input during back-propagation. It is expected to
     * be the derivative of {@link #activation}.
     */
    Function<? super SimpleMatrix, ? extends SimpleMatrix> activationPrime;

    /**
     * Creates an activation layer.
     *
     * @param activation the function applied during the forward pass
     * @param activationPrime the function (expected to be the derivative of {@code activation})
     *     applied to the cached input during back-propagation
     * @throws NullPointerException if either argument is {@code null}
     */
    public ActivationLayer(
            Function<? super SimpleMatrix, ? extends SimpleMatrix> activation,
            Function<? super SimpleMatrix, ? extends SimpleMatrix> activationPrime) {
        // Fail at construction time rather than on the first forward/backward pass.
        this.activation = Objects.requireNonNull(activation, "activation");
        this.activationPrime = Objects.requireNonNull(activationPrime, "activationPrime");
    }

    /**
     * Applies the activation function to {@code input}, caching both the input and the result.
     *
     * @param input the layer input; cached for use by {@link #backwardPropagation}
     * @return the activated output, also stored in {@code this.output}
     */
    @Override
    public SimpleMatrix forwardPropagation(SimpleMatrix input) {
        this.input = input;
        this.output = activation.apply(input);
        return this.output;
    }

    /**
     * Propagates {@code outputError} back through the activation.
     *
     * <p>Returns the element-wise product of {@code activationPrime(input)} and
     * {@code outputError}, where {@code input} is the value cached by the most recent
     * {@link #forwardPropagation} call.
     *
     * @param outputError the error signal for this layer's output
     * @param learningRate ignored; this layer has no trainable parameters
     * @return the error signal for this layer's input
     * @throws IllegalStateException if {@link #forwardPropagation} has not been called yet
     */
    @Override
    public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
        // Without a cached input the derivative cannot be evaluated; report the misuse clearly
        // instead of passing null into the derivative function.
        if (this.input == null) {
            throw new IllegalStateException(
                    "backwardPropagation called before forwardPropagation");
        }
        SimpleMatrix localGradient = activationPrime.apply(this.input);
        return localGradient.elementMult(outputError);
    }
}