Improved forward_pass
This commit is contained in:
parent
d130c7cce1
commit
f2e54cfac1
4 changed files with 9 additions and 11 deletions
|
@ -23,8 +23,8 @@ impl ActivationLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for ActivationLayer {
|
impl Layer for ActivationLayer {
|
||||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
|
||||||
self.input = input.to_owned();
|
self.input = input;
|
||||||
// output isn't needed elsewhere
|
// output isn't needed elsewhere
|
||||||
// self.output = (self.activation)(&self.input);
|
// self.output = (self.activation)(&self.input);
|
||||||
// self.output.clone()
|
// self.output.clone()
|
||||||
|
|
|
@ -71,12 +71,12 @@ impl FCLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for FCLayer {
|
impl Layer for FCLayer {
|
||||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
|
||||||
if !self.is_initialized {
|
if !self.is_initialized {
|
||||||
self.initialize(input.len());
|
self.initialize(input.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
self.input = input.to_owned();
|
self.input = input;
|
||||||
// output isn't needed elsewhere
|
// output isn't needed elsewhere
|
||||||
// self.output = self.input.dot(&self.weights) + &self.biases;
|
// self.output = self.input.dot(&self.weights) + &self.biases;
|
||||||
// self.output.clone()
|
// self.output.clone()
|
||||||
|
|
|
@ -4,6 +4,6 @@ pub mod activation_layer;
|
||||||
pub mod fc_layer;
|
pub mod fc_layer;
|
||||||
|
|
||||||
pub trait Layer {
|
pub trait Layer {
|
||||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
|
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
|
||||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
|
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
|
||||||
}
|
}
|
||||||
|
|
10
src/lib.rs
10
src/lib.rs
|
@ -30,10 +30,9 @@ impl Network {
|
||||||
let mut result = vec![];
|
let mut result = vec![];
|
||||||
|
|
||||||
for input in inputs.iter() {
|
for input in inputs.iter() {
|
||||||
let mut output = Array1::default(inputs[0].raw_dim());
|
let mut output = input.to_owned();
|
||||||
output.assign(input);
|
|
||||||
for layer in &mut self.layers {
|
for layer in &mut self.layers {
|
||||||
output = layer.forward_pass(output.view());
|
output = layer.forward_pass(output);
|
||||||
}
|
}
|
||||||
result.push(output);
|
result.push(output);
|
||||||
}
|
}
|
||||||
|
@ -57,10 +56,9 @@ impl Network {
|
||||||
let mut err = 0.0;
|
let mut err = 0.0;
|
||||||
for j in 0..num_samples {
|
for j in 0..num_samples {
|
||||||
// forward propagation
|
// forward propagation
|
||||||
let mut output = Array1::default(x_train[0].raw_dim());
|
let mut output = x_train[j].to_owned();
|
||||||
output.assign(&x_train[j]);
|
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
output = layer.forward_pass(output.view());
|
output = layer.forward_pass(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute loss
|
// compute loss
|
||||||
|
|
Loading…
Reference in a new issue