Added support for different activation and loss functions
This commit is contained in:
parent
645a8baf2c
commit
5aa9313776
7 changed files with 119 additions and 61 deletions
|
@ -1,5 +1,6 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
|
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||||
<component name="GradleSettings">
|
<component name="GradleSettings">
|
||||||
<option name="linkedExternalProjectsSettings">
|
<option name="linkedExternalProjectsSettings">
|
||||||
<GradleProjectSettings>
|
<GradleProjectSettings>
|
||||||
|
|
|
@ -4,11 +4,13 @@
|
||||||
<option name="autoReloadType" value="SELECTIVE" />
|
<option name="autoReloadType" value="SELECTIVE" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ChangeListManager">
|
<component name="ChangeListManager">
|
||||||
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Initial commit">
|
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Commented out not yet working code">
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/main/java/ActivationFunctions.java" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/main/java/LossFunctions.java" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/.idea/gradle.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/gradle.xml" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/src/main/java/FCLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/src/main/java/Layer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Layer.java" afterDir="false" />
|
|
||||||
<change beforePath="$PROJECT_DIR$/src/main/java/Network.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/Network.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
|
||||||
</list>
|
</list>
|
||||||
<option name="SHOW_DIALOG" value="false" />
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
|
@ -21,6 +23,23 @@
|
||||||
<ProjectState />
|
<ProjectState />
|
||||||
</projectState>
|
</projectState>
|
||||||
</component>
|
</component>
|
||||||
|
<component name="ExternalProjectsManager">
|
||||||
|
<system id="GRADLE">
|
||||||
|
<state>
|
||||||
|
<projects_view>
|
||||||
|
<tree_state>
|
||||||
|
<expand>
|
||||||
|
<path>
|
||||||
|
<item name="" type="6a2764b6:ExternalProjectsStructure$RootNode" />
|
||||||
|
<item name="JavaNN" type="f1a62948:ProjectNode" />
|
||||||
|
</path>
|
||||||
|
</expand>
|
||||||
|
<select />
|
||||||
|
</tree_state>
|
||||||
|
</projects_view>
|
||||||
|
</state>
|
||||||
|
</system>
|
||||||
|
</component>
|
||||||
<component name="FileTemplateManagerImpl">
|
<component name="FileTemplateManagerImpl">
|
||||||
<option name="RECENT_TEMPLATES">
|
<option name="RECENT_TEMPLATES">
|
||||||
<list>
|
<list>
|
||||||
|
@ -29,6 +48,11 @@
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
<component name="Git.Settings">
|
<component name="Git.Settings">
|
||||||
|
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
||||||
|
<map>
|
||||||
|
<entry key="$PROJECT_DIR$" value="a431e08264a1ba49995b5b45ab2e2233f40a9a1d" />
|
||||||
|
</map>
|
||||||
|
</option>
|
||||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
</component>
|
</component>
|
||||||
<component name="MarkdownSettingsMigration">
|
<component name="MarkdownSettingsMigration">
|
||||||
|
@ -59,6 +83,7 @@
|
||||||
<recent_temporary>
|
<recent_temporary>
|
||||||
<list>
|
<list>
|
||||||
<item itemvalue="Application.ExampleXOR" />
|
<item itemvalue="Application.ExampleXOR" />
|
||||||
|
<item itemvalue="Application.ExampleXOR" />
|
||||||
</list>
|
</list>
|
||||||
</recent_temporary>
|
</recent_temporary>
|
||||||
</component>
|
</component>
|
||||||
|
@ -71,6 +96,14 @@
|
||||||
<option name="presentableId" value="Default" />
|
<option name="presentableId" value="Default" />
|
||||||
<updated>1653154357924</updated>
|
<updated>1653154357924</updated>
|
||||||
</task>
|
</task>
|
||||||
|
<task id="LOCAL-00001" summary="Commented out not yet working code">
|
||||||
|
<created>1653168727185</created>
|
||||||
|
<option name="number" value="00001" />
|
||||||
|
<option name="presentableId" value="LOCAL-00001" />
|
||||||
|
<option name="project" value="LOCAL" />
|
||||||
|
<updated>1653168727186</updated>
|
||||||
|
</task>
|
||||||
|
<option name="localTasksCounter" value="2" />
|
||||||
<servers />
|
<servers />
|
||||||
</component>
|
</component>
|
||||||
<component name="Vcs.Log.Tabs.Properties">
|
<component name="Vcs.Log.Tabs.Properties">
|
||||||
|
@ -86,9 +119,10 @@
|
||||||
</component>
|
</component>
|
||||||
<component name="VcsManagerConfiguration">
|
<component name="VcsManagerConfiguration">
|
||||||
<MESSAGE value="Initial commit" />
|
<MESSAGE value="Initial commit" />
|
||||||
<option name="LAST_COMMIT_MESSAGE" value="Initial commit" />
|
<MESSAGE value="Commented out not yet working code" />
|
||||||
|
<option name="LAST_COMMIT_MESSAGE" value="Commented out not yet working code" />
|
||||||
</component>
|
</component>
|
||||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||||
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653164239216" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
|
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653174521293" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
39
src/main/java/ActivationFunctions.java
Normal file
39
src/main/java/ActivationFunctions.java
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
|
public class ActivationFunctions {
|
||||||
|
public static SimpleMatrix tanh(SimpleMatrix A) {
|
||||||
|
SimpleMatrix B = new SimpleMatrix(A);
|
||||||
|
for (int i = 0; i < A.getNumElements(); i++) {
|
||||||
|
B.set(i, Math.tanh(A.get(i)));
|
||||||
|
}
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SimpleMatrix tanhPrime(SimpleMatrix A) {
|
||||||
|
SimpleMatrix B = new SimpleMatrix(A);
|
||||||
|
for (int i = 0; i < A.getNumElements(); i++) {
|
||||||
|
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
|
||||||
|
}
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SimpleMatrix logistic(SimpleMatrix A) {
|
||||||
|
SimpleMatrix B = new SimpleMatrix(A);
|
||||||
|
for (int i = 0; i < A.getNumElements(); i++) {
|
||||||
|
B.set(i, sigma(-A.get(i)));
|
||||||
|
}
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SimpleMatrix logisticPrime(SimpleMatrix A) {
|
||||||
|
SimpleMatrix B = new SimpleMatrix(A);
|
||||||
|
for (int i = 0; i < A.getNumElements(); i++) {
|
||||||
|
B.set(i, sigma(A.get(i) * (1 - sigma(A.get(i)))));
|
||||||
|
}
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double sigma(double value) {
|
||||||
|
return 1 / (1 + Math.exp(-value));
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,8 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
public class ActivationLayer extends Layer {
|
public class ActivationLayer extends Layer {
|
||||||
/* custom activation functions not yet supported
|
|
||||||
Function<SimpleMatrix, SimpleMatrix> activation;
|
Function<SimpleMatrix, SimpleMatrix> activation;
|
||||||
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
||||||
|
|
||||||
|
@ -9,35 +10,16 @@ public class ActivationLayer extends Layer {
|
||||||
this.activation = activation;
|
this.activation = activation;
|
||||||
this.activationPrime = activationPrime;
|
this.activationPrime = activationPrime;
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
public ActivationLayer() {}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
||||||
this.input = input;
|
this.input = input;
|
||||||
this.output = tanh(input);
|
this.output = activation.apply(input);
|
||||||
return this.output;
|
return this.output;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
||||||
return tanhPrime(this.input).elementMult(outputError);
|
return activationPrime.apply(this.input).elementMult(outputError);
|
||||||
}
|
|
||||||
|
|
||||||
private SimpleMatrix tanh(SimpleMatrix A) {
|
|
||||||
SimpleMatrix B = new SimpleMatrix(A);
|
|
||||||
for (int i = 0; i < A.getNumElements(); i++) {
|
|
||||||
B.set(i, Math.tanh(A.get(i)));
|
|
||||||
}
|
|
||||||
return B;
|
|
||||||
}
|
|
||||||
|
|
||||||
private SimpleMatrix tanhPrime(SimpleMatrix A) {
|
|
||||||
SimpleMatrix B = new SimpleMatrix(A);
|
|
||||||
for (int i = 0; i < A.getNumElements(); i++) {
|
|
||||||
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
|
|
||||||
}
|
|
||||||
return B;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
|
|
||||||
public class ExampleXOR {
|
public class ExampleXOR {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
|
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
|
||||||
|
@ -15,13 +13,14 @@ public class ExampleXOR {
|
||||||
|
|
||||||
Network network = new Network();
|
Network network = new Network();
|
||||||
network.add(new FCLayer(2, 3));
|
network.add(new FCLayer(2, 3));
|
||||||
network.add(new ActivationLayer());
|
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
network.add(new FCLayer(3, 1));
|
network.add(new FCLayer(3, 1));
|
||||||
network.add(new ActivationLayer());
|
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||||
|
|
||||||
|
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||||
network.fit(X_train, y_train, 1000, 0.1d);
|
network.fit(X_train, y_train, 1000, 0.1d);
|
||||||
|
|
||||||
ArrayList<SimpleMatrix> output = network.predict(X_train);
|
SimpleMatrix[] output = network.predict(X_train);
|
||||||
for (SimpleMatrix entry : output) {
|
for (SimpleMatrix entry : output) {
|
||||||
System.out.println("Prediction:");
|
System.out.println("Prediction:");
|
||||||
for (int i = 0; i < entry.getNumElements(); i++) {
|
for (int i = 0; i < entry.getNumElements(); i++) {
|
||||||
|
|
17
src/main/java/LossFunctions.java
Normal file
17
src/main/java/LossFunctions.java
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
|
public class LossFunctions {
|
||||||
|
public static double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||||
|
SimpleMatrix temp = y_true.minus(y_pred);
|
||||||
|
temp = temp.elementMult(temp);
|
||||||
|
double sum = 0;
|
||||||
|
for (int i = 0; i < temp.getNumElements(); i++) {
|
||||||
|
sum += temp.get(i);
|
||||||
|
}
|
||||||
|
return sum / temp.getNumElements();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||||
|
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,20 +1,13 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.function.BiFunction;
|
||||||
|
|
||||||
public class Network {
|
public class Network {
|
||||||
|
|
||||||
ArrayList<Layer> layers;
|
ArrayList<Layer> layers;
|
||||||
|
BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
|
||||||
/* custom loss functions not yet supported
|
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
|
||||||
Function<SimpleMatrix, Double> loss;
|
|
||||||
Function<SimpleMatrix, Double> lossPrime;
|
|
||||||
|
|
||||||
public void use(Function<SimpleMatrix, Double> loss, Function<SimpleMatrix, Double> lossPrime) {
|
|
||||||
this.loss = loss;
|
|
||||||
this.lossPrime = lossPrime;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
public Network() {
|
public Network() {
|
||||||
layers = new ArrayList<>();
|
layers = new ArrayList<>();
|
||||||
|
@ -24,16 +17,23 @@ public class Network {
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
|
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
|
||||||
ArrayList<SimpleMatrix> result = new ArrayList<>();
|
this.loss = loss;
|
||||||
|
this.lossPrime = lossPrime;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
|
||||||
|
SimpleMatrix[] result = new SimpleMatrix[inputs.length];
|
||||||
SimpleMatrix output;
|
SimpleMatrix output;
|
||||||
|
int i = 0;
|
||||||
|
|
||||||
for (SimpleMatrix input : inputs) {
|
for (SimpleMatrix input : inputs) {
|
||||||
output = input;
|
output = input;
|
||||||
for (Layer l : layers) {
|
for (Layer l : layers) {
|
||||||
output = l.forwardPropagation(output);
|
output = l.forwardPropagation(output);
|
||||||
}
|
}
|
||||||
result.add(output);
|
result[i] = output;
|
||||||
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -52,10 +52,10 @@ public class Network {
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute loss (for display purpose only)
|
// compute loss (for display purpose only)
|
||||||
err = MSE(y_train[j], output);
|
err = loss.apply(y_train[j], output);
|
||||||
|
|
||||||
// backward propagation
|
// backward propagation
|
||||||
SimpleMatrix error = MSEPrime(y_train[j], output);
|
SimpleMatrix error = lossPrime.apply(y_train[j], output);
|
||||||
for (int k = layers.size() - 1; k >= 0; k--) {
|
for (int k = layers.size() - 1; k >= 0; k--) {
|
||||||
error = layers.get(k).backwardPropagation(error, learningRate);
|
error = layers.get(k).backwardPropagation(error, learningRate);
|
||||||
}
|
}
|
||||||
|
@ -65,18 +65,4 @@ public class Network {
|
||||||
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
|
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
||||||
SimpleMatrix temp = y_true.minus(y_pred);
|
|
||||||
temp = temp.elementMult(temp);
|
|
||||||
double sum = 0;
|
|
||||||
for (int i = 0; i < temp.getNumElements(); i++) {
|
|
||||||
sum += temp.get(i);
|
|
||||||
}
|
|
||||||
return sum / temp.getNumElements();
|
|
||||||
}
|
|
||||||
|
|
||||||
private SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
|
||||||
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
|
|
||||||
}
|
|
||||||
}
|
}
|
Loading…
Reference in a new issue