Added support for different activation and loss functions

lluni 2022-05-22 01:11:38 +02:00
parent 645a8baf2c
commit 5aa9313776
7 changed files with 119 additions and 61 deletions
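In short: the activation and loss functions are no longer hard-coded into ActivationLayer and Network, but are passed in as method references. A minimal usage sketch of the new API, taken from the updated ExampleXOR below:

    Network network = new Network();
    network.add(new FCLayer(2, 3));
    network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
    network.add(new FCLayer(3, 1));
    network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
    network.use(LossFunctions::MSE, LossFunctions::MSEPrime);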

.idea/gradle.xml

@@ -1,5 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>

.idea/workspace.xml

@@ -4,11 +4,13 @@
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Initial commit">
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Commented out not yet working code">
<change afterPath="$PROJECT_DIR$/src/main/java/ActivationFunctions.java" afterDir="false" />
<change afterPath="$PROJECT_DIR$/src/main/java/LossFunctions.java" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/gradle.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/gradle.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/FCLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/Layer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Layer.java" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" afterDir="false" />
<change beforePath="$PROJECT_DIR$/src/main/java/Network.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
@@ -21,6 +23,23 @@
<ProjectState />
</projectState>
</component>
<component name="ExternalProjectsManager">
<system id="GRADLE">
<state>
<projects_view>
<tree_state>
<expand>
<path>
<item name="" type="6a2764b6:ExternalProjectsStructure$RootNode" />
<item name="JavaNN" type="f1a62948:ProjectNode" />
</path>
</expand>
<select />
</tree_state>
</projects_view>
</state>
</system>
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
@@ -29,6 +48,11 @@
</option>
</component>
<component name="Git.Settings">
<option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="a431e08264a1ba49995b5b45ab2e2233f40a9a1d" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="MarkdownSettingsMigration">
@@ -59,6 +83,7 @@
<recent_temporary>
<list>
<item itemvalue="Application.ExampleXOR" />
<item itemvalue="Application.ExampleXOR" />
</list>
</recent_temporary>
</component>
@@ -71,6 +96,14 @@
<option name="presentableId" value="Default" />
<updated>1653154357924</updated>
</task>
<task id="LOCAL-00001" summary="Commented out not yet working code">
<created>1653168727185</created>
<option name="number" value="00001" />
<option name="presentableId" value="LOCAL-00001" />
<option name="project" value="LOCAL" />
<updated>1653168727186</updated>
</task>
<option name="localTasksCounter" value="2" />
<servers />
</component>
<component name="Vcs.Log.Tabs.Properties">
@@ -86,9 +119,10 @@
</component>
<component name="VcsManagerConfiguration">
<MESSAGE value="Initial commit" />
<option name="LAST_COMMIT_MESSAGE" value="Initial commit" />
<MESSAGE value="Commented out not yet working code" />
<option name="LAST_COMMIT_MESSAGE" value="Commented out not yet working code" />
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653164239216" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653174521293" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
</component>
</project>

src/main/java/ActivationFunctions.java

@@ -0,0 +1,39 @@
import org.ejml.simple.SimpleMatrix;
public class ActivationFunctions {
public static SimpleMatrix tanh(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, Math.tanh(A.get(i)));
}
return B;
}
public static SimpleMatrix tanhPrime(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
}
return B;
}
public static SimpleMatrix logistic(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, sigma(A.get(i)));
}
return B;
}
public static SimpleMatrix logisticPrime(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, sigma(A.get(i)) * (1 - sigma(A.get(i))));
}
return B;
}
private static double sigma(double value) {
return 1 / (1 + Math.exp(-value));
}
}
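Other element-wise activations can be added in the same shape: one method for the function, one for its derivative, both mapping SimpleMatrix to SimpleMatrix. As an illustration only (ReLU is not part of this commit), such a pair could look like:

    public static SimpleMatrix relu(SimpleMatrix A) {
        SimpleMatrix B = new SimpleMatrix(A);
        for (int i = 0; i < A.getNumElements(); i++) {
            // max(0, x), applied element-wise
            B.set(i, Math.max(0, A.get(i)));
        }
        return B;
    }
    public static SimpleMatrix reluPrime(SimpleMatrix A) {
        SimpleMatrix B = new SimpleMatrix(A);
        for (int i = 0; i < A.getNumElements(); i++) {
            // derivative: 1 for x > 0, else 0
            B.set(i, A.get(i) > 0 ? 1 : 0);
        }
        return B;
    }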

src/main/java/ActivationLayer.java

@@ -1,7 +1,8 @@
import org.ejml.simple.SimpleMatrix;
import java.util.function.Function;
public class ActivationLayer extends Layer {
/* custom activation functions not yet supported
Function<SimpleMatrix, SimpleMatrix> activation;
Function<SimpleMatrix, SimpleMatrix> activationPrime;
@@ -9,35 +10,16 @@ public class ActivationLayer extends Layer {
this.activation = activation;
this.activationPrime = activationPrime;
}
*/
public ActivationLayer() {}
@Override
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
this.input = input;
this.output = tanh(input);
this.output = activation.apply(input);
return this.output;
}
@Override
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
return tanhPrime(this.input).elementMult(outputError);
}
private SimpleMatrix tanh(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, Math.tanh(A.get(i)));
}
return B;
}
private SimpleMatrix tanhPrime(SimpleMatrix A) {
SimpleMatrix B = new SimpleMatrix(A);
for (int i = 0; i < A.getNumElements(); i++) {
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
}
return B;
return activationPrime.apply(this.input).elementMult(outputError);
}
}
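Since the layer now only stores Function<SimpleMatrix, SimpleMatrix> references, an activation does not have to live in ActivationFunctions; an inline lambda works as well. A hypothetical identity activation, shown purely to illustrate the interface (not part of this commit):

    network.add(new ActivationLayer(
            A -> new SimpleMatrix(A),          // identity activation: copy the input unchanged
            A -> {                             // its derivative: a matrix of ones
                SimpleMatrix B = new SimpleMatrix(A);
                for (int i = 0; i < B.getNumElements(); i++) {
                    B.set(i, 1);
                }
                return B;
            }));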

src/main/java/ExampleXOR.java

@@ -1,7 +1,5 @@
import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
public class ExampleXOR {
public static void main(String[] args) {
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
@@ -15,13 +13,14 @@ public class ExampleXOR {
Network network = new Network();
network.add(new FCLayer(2, 3));
network.add(new ActivationLayer());
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.add(new FCLayer(3, 1));
network.add(new ActivationLayer());
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
network.fit(X_train, y_train, 1000, 0.1d);
ArrayList<SimpleMatrix> output = network.predict(X_train);
SimpleMatrix[] output = network.predict(X_train);
for (SimpleMatrix entry : output) {
System.out.println("Prediction:");
for (int i = 0; i < entry.getNumElements(); i++) {

src/main/java/LossFunctions.java

@@ -0,0 +1,17 @@
import org.ejml.simple.SimpleMatrix;
public class LossFunctions {
public static double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
SimpleMatrix temp = y_true.minus(y_pred);
temp = temp.elementMult(temp);
double sum = 0;
for (int i = 0; i < temp.getNumElements(); i++) {
sum += temp.get(i);
}
return sum / temp.getNumElements();
}
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
}
}
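Further losses would follow the same two-method convention: a scalar value for reporting and a matrix-valued derivative for backpropagation. Purely as an illustration (not part of this commit), a mean-absolute-error pair might look like this, keeping the (y_true - y_pred) sign convention used by MSEPrime above:

    public static double MAE(SimpleMatrix y_true, SimpleMatrix y_pred) {
        SimpleMatrix diff = y_true.minus(y_pred);
        double sum = 0;
        for (int i = 0; i < diff.getNumElements(); i++) {
            sum += Math.abs(diff.get(i));
        }
        return sum / diff.getNumElements();
    }
    public static SimpleMatrix MAEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
        SimpleMatrix diff = y_true.minus(y_pred);
        SimpleMatrix grad = new SimpleMatrix(diff);
        for (int i = 0; i < diff.getNumElements(); i++) {
            // sign of (y_true - y_pred), scaled by 1/n
            grad.set(i, Math.signum(diff.get(i)) / diff.getNumElements());
        }
        return grad;
    }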

src/main/java/Network.java

@@ -1,20 +1,13 @@
import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
import java.util.function.BiFunction;
public class Network {
ArrayList<Layer> layers;
/* custom loss functions not yet supported
Function<SimpleMatrix, Double> loss;
Function<SimpleMatrix, Double> lossPrime;
public void use(Function<SimpleMatrix, Double> loss, Function<SimpleMatrix, Double> lossPrime) {
this.loss = loss;
this.lossPrime = lossPrime;
}
*/
BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
public Network() {
layers = new ArrayList<>();
@@ -24,16 +17,23 @@ public class Network {
layers.add(layer);
}
public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
ArrayList<SimpleMatrix> result = new ArrayList<>();
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
this.loss = loss;
this.lossPrime = lossPrime;
}
public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
SimpleMatrix[] result = new SimpleMatrix[inputs.length];
SimpleMatrix output;
int i = 0;
for (SimpleMatrix input : inputs) {
output = input;
for (Layer l : layers) {
output = l.forwardPropagation(output);
}
result.add(output);
result[i] = output;
i++;
}
return result;
@@ -52,10 +52,10 @@ public class Network {
}
// compute loss (for display purpose only)
err = MSE(y_train[j], output);
err = loss.apply(y_train[j], output);
// backward propagation
SimpleMatrix error = MSEPrime(y_train[j], output);
SimpleMatrix error = lossPrime.apply(y_train[j], output);
for (int k = layers.size() - 1; k >= 0; k--) {
error = layers.get(k).backwardPropagation(error, learningRate);
}
@@ -65,18 +65,4 @@
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
}
}
private double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
SimpleMatrix temp = y_true.minus(y_pred);
temp = temp.elementMult(temp);
double sum = 0;
for (int i = 0; i < temp.getNumElements(); i++) {
sum += temp.get(i);
}
return sum / temp.getNumElements();
}
private SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
}
}
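With the private MSE helpers removed, Network is now agnostic to the loss: the training loop goes through loss.apply and lossPrime.apply, so use() has to be called before fit(). A short usage sketch mirroring the updated ExampleXOR:

    network.use(LossFunctions::MSE, LossFunctions::MSEPrime);   // wire up the loss before training
    network.fit(X_train, y_train, 1000, 0.1d);
    SimpleMatrix[] predictions = network.predict(X_train);      // predict() now returns an array instead of an ArrayList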