Improved forward_pass

This commit is contained in:
lluni 2023-02-04 18:14:06 +01:00
parent d130c7cce1
commit f2e54cfac1
Signed by: lluni
GPG key ID: ACEEB468BC325D35
4 changed files with 9 additions and 11 deletions

View file

@ -23,8 +23,8 @@ impl ActivationLayer {
}
impl Layer for ActivationLayer {
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
self.input = input.to_owned();
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
self.input = input;
// output isn't needed elsewhere
// self.output = (self.activation)(&self.input);
// self.output.clone()

View file

@ -71,12 +71,12 @@ impl FCLayer {
}
impl Layer for FCLayer {
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64> {
if !self.is_initialized {
self.initialize(input.len());
}
self.input = input.to_owned();
self.input = input;
// output isn't needed elsewhere
// self.output = self.input.dot(&self.weights) + &self.biases;
// self.output.clone()

View file

@ -4,6 +4,6 @@ pub mod activation_layer;
pub mod fc_layer;
pub trait Layer {
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
fn forward_pass(&mut self, input: Array1<f64>) -> Array1<f64>;
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
}

View file

@ -30,10 +30,9 @@ impl Network {
let mut result = vec![];
for input in inputs.iter() {
let mut output = Array1::default(inputs[0].raw_dim());
output.assign(input);
let mut output = input.to_owned();
for layer in &mut self.layers {
output = layer.forward_pass(output.view());
output = layer.forward_pass(output);
}
result.push(output);
}
@ -57,10 +56,9 @@ impl Network {
let mut err = 0.0;
for j in 0..num_samples {
// forward propagation
let mut output = Array1::default(x_train[0].raw_dim());
output.assign(&x_train[j]);
let mut output = x_train[j].to_owned();
for layer in self.layers.iter_mut() {
output = layer.forward_pass(output.view());
output = layer.forward_pass(output);
}
// compute loss