extern crate ndarray; use ndarray::{Array1, Array2, arr1, arr2, Array, ArrayView1, ShapeBuilder}; use ndarray_rand::RandomExt; use ndarray_rand::rand_distr::{Normal, Uniform}; use super::Layer; pub enum Initializer { Zeros, Ones, Gaussian(f64, f64), GaussianWFactor(f64, f64, f64), Uniform(f64, f64) } impl Initializer { pub fn init(&self, shape: Sh) -> Array where Sh: ShapeBuilder, D: ndarray::Dimension { match self { Self::Zeros => Array::zeros(shape), Self::Ones => Array::ones(shape), Self::Gaussian(mean, stddev) => Array::random(shape, Normal::new(*mean, *stddev).unwrap()), Self::GaussianWFactor(mean, stddev, factor) => Array::random(shape, Normal::new(*mean, *stddev).unwrap()) * *factor, Self::Uniform(low, high) => Array::random(shape, Uniform::new(low, high)) } } } pub struct FCLayer { num_neurons: usize, is_initialized: bool, weight_initializer: Initializer, bias_initializer: Initializer, input: Array1, output: Array1, weights: Array2, biases: Array1, } impl FCLayer { pub fn new(num_neurons: usize, weight_initializer: Initializer, bias_initializer: Initializer) -> Self { FCLayer { num_neurons, is_initialized: false, weight_initializer, bias_initializer, input: arr1(&[]), output: arr1(&[]), weights: arr2(&[[]]), biases: arr1(&[]) } } fn initialize(&mut self, input_size: usize) { self.weights = self.weight_initializer.init((input_size, self.num_neurons)); self.biases = self.bias_initializer.init(self.num_neurons); self.is_initialized = true; } } impl Layer for FCLayer { fn forward_pass(&mut self, input: ArrayView1) -> Array1 { if !self.is_initialized { self.initialize(input.len()); } self.input = input.to_owned(); self.output = self.input.dot(&self.weights) + &self.biases; self.output.clone() } fn backward_pass(&mut self, output_error: ArrayView1, learning_rate: f64) -> Array1 { let input_error = output_error.dot(&self.weights.t()); let delta_weights = self.input.to_owned().into_shape((self.input.len(), 1usize)).unwrap() .dot(&output_error.into_shape((1usize, output_error.len())).unwrap()); self.weights = &self.weights + learning_rate * &delta_weights; self.biases = &self.biases + learning_rate * &output_error; input_error } }