rust-nn/src/layers/fc_layer.rs

85 lines
No EOL
2.6 KiB
Rust

extern crate ndarray;
use ndarray::{Array1, Array2, arr1, arr2, Array, ArrayView1, ShapeBuilder};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::{Normal, Uniform};
use super::Layer;
/// Strategy used to fill a freshly allocated parameter array.
///
/// The numeric payloads are interpreted by [`Initializer::init`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Initializer {
    /// Every element is `0.0`.
    Zeros,
    /// Every element is `1.0`.
    Ones,
    /// Elements are drawn from a normal distribution: `Gaussian(mean, stddev)`.
    Gaussian(f64, f64),
    /// Like [`Initializer::Gaussian`], but every sample is additionally
    /// multiplied by a constant: `GaussianWFactor(mean, stddev, factor)`.
    GaussianWFactor(f64, f64, f64),
    /// Elements are drawn from a uniform distribution: `Uniform(low, high)`.
    Uniform(f64, f64),
}
impl Initializer {
    /// Builds a new `f64` array of the requested `shape`, filled according to
    /// the selected variant.
    ///
    /// `shape` may be anything ndarray accepts as a shape (e.g. a `usize` for a
    /// vector or a `(rows, cols)` tuple for a matrix); the dimensionality of
    /// the result follows from it.
    ///
    /// # Panics
    ///
    /// Panics if `Normal::new` rejects the parameters of a Gaussian variant
    /// (the result is unwrapped). NOTE(review): `Uniform::new` may also panic
    /// on an invalid range depending on the `rand` version in use — confirm.
    pub fn init<Sh, D>(&self, shape: Sh) -> Array<f64, D>
    where
        Sh: ShapeBuilder<Dim = D>,
        D: ndarray::Dimension,
    {
        match *self {
            Self::Zeros => Array::zeros(shape),
            Self::Ones => Array::ones(shape),
            Self::Gaussian(mean, stddev) => {
                let distribution = Normal::new(mean, stddev).unwrap();
                Array::random(shape, distribution)
            }
            Self::GaussianWFactor(mean, stddev, factor) => {
                let distribution = Normal::new(mean, stddev).unwrap();
                // Scale every sample after drawing it.
                Array::random(shape, distribution) * factor
            }
            Self::Uniform(low, high) => {
                let distribution = Uniform::new(low, high);
                Array::random(shape, distribution)
            }
        }
    }
}
/// A fully connected layer whose parameters are allocated on first use.
///
/// The input width is not known at construction time; the weight matrix and
/// bias vector are created during the first forward pass.
pub struct FCLayer {
    /// Number of output units of the layer.
    num_neurons: usize,
    /// `true` once `weights` and `biases` have been allocated by `initialize`.
    is_initialized: bool,
    /// Strategy used to fill `weights` on initialization.
    weight_initializer: Initializer,
    /// Strategy used to fill `biases` on initialization.
    bias_initializer: Initializer,
    /// Copy of the input seen by the most recent forward pass; read again by
    /// the backward pass to build the weight update.
    input: Array1<f64>,
    /// Result of the most recent forward pass.
    output: Array1<f64>,
    /// Weight matrix of shape `(input_size, num_neurons)` once initialized.
    weights: Array2<f64>,
    /// Bias vector of length `num_neurons` once initialized.
    biases: Array1<f64>,
}
impl FCLayer {
pub fn new(num_neurons: usize, weight_initializer: Initializer, bias_initializer: Initializer) -> Self {
FCLayer {
num_neurons,
is_initialized: false,
weight_initializer,
bias_initializer,
input: arr1(&[]),
output: arr1(&[]),
weights: arr2(&[[]]),
biases: arr1(&[])
}
}
fn initialize(&mut self, input_size: usize) {
self.weights = self.weight_initializer.init((input_size, self.num_neurons));
self.biases = self.bias_initializer.init(self.num_neurons);
self.is_initialized = true;
}
}
impl Layer for FCLayer {
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
if !self.is_initialized {
self.initialize(input.len());
}
self.input = input.to_owned();
self.output = self.input.dot(&self.weights) + &self.biases;
self.output.clone()
}
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
let input_error = output_error.dot(&self.weights.t());
let delta_weights =
self.input.to_owned().into_shape((self.input.len(), 1usize)).unwrap()
.dot(&output_error.into_shape((1usize, output_error.len())).unwrap());
self.weights = &self.weights + learning_rate * &delta_weights;
self.biases = &self.biases + learning_rate * &output_error;
input_error
}
}