diff --git a/src/main/java/de/lluni/javann/Network.java b/src/main/java/de/lluni/javann/Network.java index d7e2f6d..f03555b 100644 --- a/src/main/java/de/lluni/javann/Network.java +++ b/src/main/java/de/lluni/javann/Network.java @@ -27,12 +27,17 @@ public class Network { */ public void addNeuron(int layer, int n) { if (!(this.layers.get(layer) instanceof FCLayer)) { - System.out.println("This layer is not a de.lluni.javann.layers.BlankLayer"); - } else if (!(this.layers.get(layer + 2) instanceof FCLayer)) { - System.out.println("The next layer is not a de.lluni.javann.layers.BlankLayer"); + System.out.println("This layer is not a FCLayer"); } + ((FCLayer) this.layers.get(layer)).addNeuron(n); - ((FCLayer) this.layers.get(layer + 2)).updateInputSize(n); + + for (int i = layer + 1; i < this.layers.size(); i++) { + if (this.getLayers().get(i) instanceof FCLayer) { + ((FCLayer) this.layers.get(i)).updateInputSize(n); + break; + } + } } /** diff --git a/src/main/java/de/lluni/javann/layers/FCLayer.java b/src/main/java/de/lluni/javann/layers/FCLayer.java index 21bdc7b..2c78c5c 100644 --- a/src/main/java/de/lluni/javann/layers/FCLayer.java +++ b/src/main/java/de/lluni/javann/layers/FCLayer.java @@ -46,7 +46,7 @@ public class FCLayer extends Layer { } for (int i = 0; i < newWeights.getNumElements(); i++) { if (newWeights.get(i) == 0) { - newWeights.set(i, random.nextDouble(-1, 1)); + newWeights.set(i, 0.1d * random.nextGaussian(0, 1)); } } this.weights = newWeights;