25 lines
836 B
Java
25 lines
836 B
Java
import org.ejml.simple.SimpleMatrix;
|
|
|
|
import java.util.function.Function;
|
|
|
|
public class ActivationLayer extends Layer {
|
|
Function<SimpleMatrix, SimpleMatrix> activation;
|
|
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
|
|
|
public ActivationLayer(Function<SimpleMatrix, SimpleMatrix> activation, Function<SimpleMatrix, SimpleMatrix> activationPrime) {
|
|
this.activation = activation;
|
|
this.activationPrime = activationPrime;
|
|
}
|
|
|
|
@Override
|
|
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
|
this.input = input;
|
|
this.output = activation.apply(input);
|
|
return this.output;
|
|
}
|
|
|
|
@Override
|
|
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
|
return activationPrime.apply(this.input).elementMult(outputError);
|
|
}
|
|
}
|