Commented out not yet working code
This commit is contained in:
parent
a431e08264
commit
645a8baf2c
5 changed files with 27 additions and 39 deletions
|
@ -5,24 +5,11 @@
|
||||||
</component>
|
</component>
|
||||||
<component name="ChangeListManager">
|
<component name="ChangeListManager">
|
||||||
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Initial commit">
|
<list default="true" id="75a8c215-b746-4a4d-aa16-f3223c12b1ed" name="Changes" comment="Initial commit">
|
||||||
<change afterPath="$PROJECT_DIR$/.gitignore" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
<change afterPath="$PROJECT_DIR$/.idea/gradle.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
|
||||||
<change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/FCLayer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" />
|
||||||
<change afterPath="$PROJECT_DIR$/.idea/uiDesigner.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/Layer.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Layer.java" afterDir="false" />
|
||||||
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/src/main/java/Network.java" beforeDir="false" afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
|
||||||
<change afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/build.gradle" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/gradle/wrapper/gradle-wrapper.jar" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/gradle/wrapper/gradle-wrapper.properties" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/gradlew" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/gradlew.bat" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/settings.gradle" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/ActivationLayer.java" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/ExampleXOR.java" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/FCLayer.java" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/GradientDescent.java" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/Layer.java" afterDir="false" />
|
|
||||||
<change afterPath="$PROJECT_DIR$/src/main/java/Network.java" afterDir="false" />
|
|
||||||
</list>
|
</list>
|
||||||
<option name="SHOW_DIALOG" value="false" />
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
@ -53,14 +40,14 @@
|
||||||
<option name="hideEmptyMiddlePackages" value="true" />
|
<option name="hideEmptyMiddlePackages" value="true" />
|
||||||
<option name="showLibraryContents" value="true" />
|
<option name="showLibraryContents" value="true" />
|
||||||
</component>
|
</component>
|
||||||
<component name="PropertiesComponent"><![CDATA[{
|
<component name="PropertiesComponent">{
|
||||||
"keyToString": {
|
"keyToString": {
|
||||||
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
||||||
"project.structure.last.edited": "Project",
|
"project.structure.last.edited": "Project",
|
||||||
"project.structure.proportion": "0.15",
|
"project.structure.proportion": "0.15",
|
||||||
"project.structure.side.proportion": "0.2"
|
"project.structure.side.proportion": "0.2"
|
||||||
}
|
}
|
||||||
}]]></component>
|
}</component>
|
||||||
<component name="RunManager">
|
<component name="RunManager">
|
||||||
<configuration name="ExampleXOR" type="Application" factoryName="Application" temporary="true" nameIsGenerated="true">
|
<configuration name="ExampleXOR" type="Application" factoryName="Application" temporary="true" nameIsGenerated="true">
|
||||||
<option name="MAIN_CLASS_NAME" value="ExampleXOR" />
|
<option name="MAIN_CLASS_NAME" value="ExampleXOR" />
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
public class ActivationLayer extends Layer {
|
public class ActivationLayer extends Layer {
|
||||||
|
/* custom activation functions not yet supported
|
||||||
Function<SimpleMatrix, SimpleMatrix> activation;
|
Function<SimpleMatrix, SimpleMatrix> activation;
|
||||||
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
Function<SimpleMatrix, SimpleMatrix> activationPrime;
|
||||||
|
|
||||||
public ActivationLayer() {}
|
|
||||||
|
|
||||||
public ActivationLayer(Function<SimpleMatrix, SimpleMatrix> activation, Function<SimpleMatrix, SimpleMatrix> activationPrime) {
|
public ActivationLayer(Function<SimpleMatrix, SimpleMatrix> activation, Function<SimpleMatrix, SimpleMatrix> activationPrime) {
|
||||||
this.activation = activation;
|
this.activation = activation;
|
||||||
this.activationPrime = activationPrime;
|
this.activationPrime = activationPrime;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
public ActivationLayer() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
public SimpleMatrix forwardPropagation(SimpleMatrix input) {
|
||||||
|
@ -21,7 +21,7 @@ public class ActivationLayer extends Layer {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate) {
|
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
||||||
return tanhPrime(this.input).elementMult(outputError);
|
return tanhPrime(this.input).elementMult(outputError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ public class FCLayer extends Layer {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate) {
|
public SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate) {
|
||||||
SimpleMatrix inputError = outputError.mult(this.weights.transpose());
|
SimpleMatrix inputError = outputError.mult(this.weights.transpose());
|
||||||
SimpleMatrix weightsError = this.input.transpose().mult(outputError);
|
SimpleMatrix weightsError = this.input.transpose().mult(outputError);
|
||||||
|
|
||||||
|
|
|
@ -6,5 +6,5 @@ public abstract class Layer {
|
||||||
|
|
||||||
public abstract SimpleMatrix forwardPropagation(SimpleMatrix inputs);
|
public abstract SimpleMatrix forwardPropagation(SimpleMatrix inputs);
|
||||||
|
|
||||||
public abstract SimpleMatrix backwardPropagation(SimpleMatrix outputError, Double learningRate);
|
public abstract SimpleMatrix backwardPropagation(SimpleMatrix outputError, double learningRate);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +1,21 @@
|
||||||
import org.ejml.simple.SimpleMatrix;
|
import org.ejml.simple.SimpleMatrix;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
public class Network {
|
public class Network {
|
||||||
|
|
||||||
ArrayList<Layer> layers;
|
ArrayList<Layer> layers;
|
||||||
|
|
||||||
|
/* custom loss functions not yet supported
|
||||||
Function<SimpleMatrix, Double> loss;
|
Function<SimpleMatrix, Double> loss;
|
||||||
Function<SimpleMatrix, Double> lossPrime;
|
Function<SimpleMatrix, Double> lossPrime;
|
||||||
|
|
||||||
|
public void use(Function<SimpleMatrix, Double> loss, Function<SimpleMatrix, Double> lossPrime) {
|
||||||
|
this.loss = loss;
|
||||||
|
this.lossPrime = lossPrime;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
public Network() {
|
public Network() {
|
||||||
layers = new ArrayList<>();
|
layers = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
@ -17,13 +24,7 @@ public class Network {
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void use(Function<SimpleMatrix, Double> loss, Function<SimpleMatrix, Double> lossPrime) {
|
|
||||||
this.loss = loss;
|
|
||||||
this.lossPrime = lossPrime;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
|
public ArrayList<SimpleMatrix> predict(SimpleMatrix[] inputs) {
|
||||||
int samples = inputs.length;
|
|
||||||
ArrayList<SimpleMatrix> result = new ArrayList<>();
|
ArrayList<SimpleMatrix> result = new ArrayList<>();
|
||||||
SimpleMatrix output;
|
SimpleMatrix output;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue