Added support for different activation and loss functions
This commit is contained in:
parent
645a8baf2c
commit
5aa9313776
7 changed files with 119 additions and 61 deletions
|
@ -1,5 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
|
|
|
@ -4,11 +4,13 @@
|
|||
<option name="autoReloadType" value="SELECTIVE" />
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Initial commit">
|
||||
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Commented out not yet working code">
|
||||
<change afterPath="$PROJECT_DIR$/src/main/java/ActivationFunctions.java" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/src/main/java/LossFunctions.java" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/gradle.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/gradle.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/src/main/java/FCLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/src/main/java/Layer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Layer.java" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/src/main/java/Network.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
|
@ -21,6 +23,23 @@
|
|||
<ProjectState />
|
||||
</projectState>
|
||||
</component>
|
||||
<component name="ExternalProjectsManager">
|
||||
<system id="GRADLE">
|
||||
<state>
|
||||
<projects_view>
|
||||
<tree_state>
|
||||
<expand>
|
||||
<path>
|
||||
<item name="" type="6a2764b6:ExternalProjectsStructure$RootNode" />
|
||||
<item name="JavaNN" type="f1a62948:ProjectNode" />
|
||||
</path>
|
||||
</expand>
|
||||
<select />
|
||||
</tree_state>
|
||||
</projects_view>
|
||||
</state>
|
||||
</system>
|
||||
</component>
|
||||
<component name="FileTemplateManagerImpl">
|
||||
<option name="RECENT_TEMPLATES">
|
||||
<list>
|
||||
|
@ -29,6 +48,11 @@
|
|||
</option>
|
||||
</component>
|
||||
<component name="Git.Settings">
|
||||
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
||||
<map>
|
||||
<entry key="$PROJECT_DIR$" value="a431e08264a1ba49995b5b45ab2e2233f40a9a1d" />
|
||||
</map>
|
||||
</option>
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="MarkdownSettingsMigration">
|
||||
|
@ -59,6 +83,7 @@
|
|||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Application.ExampleXOR" />
|
||||
<item itemvalue="Application.ExampleXOR" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
</component>
|
||||
|
@ -71,6 +96,14 @@
|
|||
<option name="presentableId" value="Default" />
|
||||
<updated>1653154357924</updated>
|
||||
</task>
|
||||
<task id="LOCAL-00001" summary="Commented out not yet working code">
|
||||
<created>1653168727185</created>
|
||||
<option name="number" value="00001" />
|
||||
<option name="presentableId" value="LOCAL-00001" />
|
||||
<option name="project" value="LOCAL" />
|
||||
<updated>1653168727186</updated>
|
||||
</task>
|
||||
<option name="localTasksCounter" value="2" />
|
||||
<servers />
|
||||
</component>
|
||||
<component name="Vcs.Log.Tabs.Properties">
|
||||
|
@ -86,9 +119,10 @@
|
|||
</component>
|
||||
<component name="VcsManagerConfiguration">
|
||||
<MESSAGE value="Initial commit" />
|
||||
<option name="LAST_COMMIT_MESSAGE" value="Initial commit" />
|
||||
<MESSAGE value="Commented out not yet working code" />
|
||||
<option name="LAST_COMMIT_MESSAGE" value="Commented out not yet working code" />
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653164239216" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
|
||||
<SUITE FILE_PATH="coverage/JavaNN$ExampleXOR.ic" NAME="ExampleXOR Coverage Results" MODIFIED="1653174521293" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="idea" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" />
|
||||
</component>
|
||||
</project>
|
39
src/main/java/ActivationFunctions.java
Normal file
39
src/main/java/ActivationFunctions.java
Normal file
|
@ -0,0 +1,39 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class ActivationFunctions {
|
||||
public static SimpleMatrix tanh(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, Math.tanh(A.get(i)));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
public static SimpleMatrix tanhPrime(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
public static SimpleMatrix logistic(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, sigma(-A.get(i)));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
public static SimpleMatrix logisticPrime(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, sigma(A.get(i) * (1 - sigma(A.get(i)))));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
private static double sigma(double value) {
|
||||
return 1 / (1 + Math.exp(-value));
|
||||
}
|
||||
}
|
|
@ -1,7 +1,8 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
public class ActivationLayer extends Layer {
|
||||
/* custom activation functions not yet supported
|
||||
Function<SimpleMatrix, SimpleMatrix> activation;
|
||||
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
||||
|
||||
|
@ -9,35 +10,16 @@ public class ActivationLayer extends Layer {
|
|||
this.activation = activation;
|
||||
this.activationPrime = activationPrime;
|
||||
}
|
||||
*/
|
||||
|
||||
public ActivationLayer() {}
|
||||
|
||||
@Override
|
||||
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
||||
this.input = input;
|
||||
this.output = tanh(input);
|
||||
this.output = activation.apply(input);
|
||||
return this.output;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
||||
return tanhPrime(this.input).elementMult(outputError);
|
||||
}
|
||||
|
||||
private SimpleMatrix tanh(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, Math.tanh(A.get(i)));
|
||||
}
|
||||
return B;
|
||||
}
|
||||
|
||||
private SimpleMatrix tanhPrime(SimpleMatrix A) {
|
||||
SimpleMatrix B = new SimpleMatrix(A);
|
||||
for (int i = 0; i < A.getNumElements(); i++) {
|
||||
B.set(i, 1 - Math.pow(Math.tanh(A.get(i)), 2));
|
||||
}
|
||||
return B;
|
||||
return activationPrime.apply(this.input).elementMult(outputError);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public class ExampleXOR {
|
||||
public static void main(String[] args) {
|
||||
SimpleMatrix[] X_train = {new SimpleMatrix(new double[][]{{0, 0}}),
|
||||
|
@ -15,13 +13,14 @@ public class ExampleXOR {
|
|||
|
||||
Network network = new Network();
|
||||
network.add(new FCLayer(2, 3));
|
||||
network.add(new ActivationLayer());
|
||||
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||
network.add(new FCLayer(3, 1));
|
||||
network.add(new ActivationLayer());
|
||||
network.add(new ActivationLayer(ActivationFunctions::tanh, ActivationFunctions::tanhPrime));
|
||||
|
||||
network.use(LossFunctions::MSE, LossFunctions::MSEPrime);
|
||||
network.fit(X_train, y_train, 1000, 0.1d);
|
||||
|
||||
ArrayList<SimpleMatrix> output = network.predict(X_train);
|
||||
SimpleMatrix[] output = network.predict(X_train);
|
||||
for (SimpleMatrix entry : output) {
|
||||
System.out.println("Prediction:");
|
||||
for (int i = 0; i < entry.getNumElements(); i++) {
|
||||
|
|
17
src/main/java/LossFunctions.java
Normal file
17
src/main/java/LossFunctions.java
Normal file
|
@ -0,0 +1,17 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public class LossFunctions {
|
||||
public static double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||
SimpleMatrix temp = y_true.minus(y_pred);
|
||||
temp = temp.elementMult(temp);
|
||||
double sum = 0;
|
||||
for (int i = 0; i < temp.getNumElements(); i++) {
|
||||
sum += temp.get(i);
|
||||
}
|
||||
return sum / temp.getNumElements();
|
||||
}
|
||||
|
||||
public static SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
|
||||
}
|
||||
}
|
|
@ -1,20 +1,13 @@
|
|||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
public class Network {
|
||||
|
||||
ArrayList<Layer> layers;
|
||||
|
||||
/* custom loss functions not yet supported
|
||||
Function<SimpleMatrix, Double> loss;
|
||||
Function<SimpleMatrix, Double> lossPrime;
|
||||
|
||||
public void use(Function<SimpleMatrix, Double> loss, Function<SimpleMatrix, Double> lossPrime) {
|
||||
this.loss = loss;
|
||||
this.lossPrime = lossPrime;
|
||||
}
|
||||
*/
|
||||
BiFunction<SimpleMatrix, SimpleMatrix, Double> loss;
|
||||
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime;
|
||||
|
||||
public Network() {
|
||||
layers = new ArrayList<>();
|
||||
|
@ -24,16 +17,23 @@ public class Network {
|
|||
layers.add(layer);
|
||||
}
|
||||
|
||||
public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
|
||||
ArrayList<SimpleMatrix> result = new ArrayList<>();
|
||||
public void use(BiFunction<SimpleMatrix, SimpleMatrix, Double> loss, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> lossPrime) {
|
||||
this.loss = loss;
|
||||
this.lossPrime = lossPrime;
|
||||
}
|
||||
|
||||
public SimpleMatrix[] predict(SimpleMatrix[] inputs) {
|
||||
SimpleMatrix[] result = new SimpleMatrix[inputs.length];
|
||||
SimpleMatrix output;
|
||||
int i = 0;
|
||||
|
||||
for (SimpleMatrix input : inputs) {
|
||||
output = input;
|
||||
for (Layer l : layers) {
|
||||
output = l.forwardPropagation(output);
|
||||
}
|
||||
result.add(output);
|
||||
result[i] = output;
|
||||
i++;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -52,10 +52,10 @@ public class Network {
|
|||
}
|
||||
|
||||
// compute loss (for display purpose only)
|
||||
err = MSE(y_train[j], output);
|
||||
err = loss.apply(y_train[j], output);
|
||||
|
||||
// backward propagation
|
||||
SimpleMatrix error = MSEPrime(y_train[j], output);
|
||||
SimpleMatrix error = lossPrime.apply(y_train[j], output);
|
||||
for (int k = layers.size() - 1; k >= 0; k--) {
|
||||
error = layers.get(k).backwardPropagation(error, learningRate);
|
||||
}
|
||||
|
@ -65,18 +65,4 @@ public class Network {
|
|||
System.out.println("epoch " + (i+1) + "/" + epochs + " error=" + err);
|
||||
}
|
||||
}
|
||||
|
||||
private double MSE(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||
SimpleMatrix temp = y_true.minus(y_pred);
|
||||
temp = temp.elementMult(temp);
|
||||
double sum = 0;
|
||||
for (int i = 0; i < temp.getNumElements(); i++) {
|
||||
sum += temp.get(i);
|
||||
}
|
||||
return sum / temp.getNumElements();
|
||||
}
|
||||
|
||||
private SimpleMatrix MSEPrime(SimpleMatrix y_true, SimpleMatrix y_pred) {
|
||||
return y_true.minus(y_pred).divide((double) y_true.getNumElements()/2);
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue