Improved forward_pass
This commit is contained in:
parent
d130c7cce1
commit
f2e54cfac1
4 changed files with 9 additions and 11 deletions
|
@ -23,8 +23,8 @@ impl ActivationLayer {
|
|||
}
|
||||
|
||||
impl Layer for ActivationLayer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
||||
self.input = input.to_owned();
|
||||
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
|
||||
self.input = input;
|
||||
// output isn't needed elsewhere
|
||||
// self.output = (self.activation)(&self.input);
|
||||
// self.output.clone()
|
||||
|
|
|
@ -71,12 +71,12 @@ impl FCLayer {
|
|||
}
|
||||
|
||||
impl Layer for FCLayer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
||||
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
|
||||
if !self.is_initialized {
|
||||
self.initialize(input.len());
|
||||
}
|
||||
|
||||
self.input = input.to_owned();
|
||||
self.input = input;
|
||||
// output isn't needed elsewhere
|
||||
// self.output = self.input.dot(&self.weights) + &self.biases;
|
||||
// self.output.clone()
|
||||
|
|
|
@ -4,6 +4,6 @@ pub mod activation_layer;
|
|||
pub mod fc_layer;
|
||||
|
||||
pub trait Layer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
|
||||
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
|
||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
|
||||
}
|
||||
|
|
10
src/lib.rs
10
src/lib.rs
|
@ -30,10 +30,9 @@ impl Network {
|
|||
let mut result = vec![];
|
||||
|
||||
for input in inputs.iter() {
|
||||
let mut output = Array1::default(inputs[0].raw_dim());
|
||||
output.assign(input);
|
||||
let mut output = input.to_owned();
|
||||
for layer in &mut self.layers {
|
||||
output = layer.forward_pass(output.view());
|
||||
output = layer.forward_pass(output);
|
||||
}
|
||||
result.push(output);
|
||||
}
|
||||
|
@ -57,10 +56,9 @@ impl Network {
|
|||
let mut err = 0.0;
|
||||
for j in 0..num_samples {
|
||||
// forward propagation
|
||||
let mut output = Array1::default(x_train[0].raw_dim());
|
||||
output.assign(&x_train[j]);
|
||||
let mut output = x_train[j].to_owned();
|
||||
for layer in self.layers.iter_mut() {
|
||||
output = layer.forward_pass(output.view());
|
||||
output = layer.forward_pass(output);
|
||||
}
|
||||
|
||||
// compute loss
|
||||
|
|
Loading…
Reference in a new issue