Initial commit
This commit is contained in:
commit
961626616f
12 changed files with 569 additions and 0 deletions
40
src/layers/activation_layer.rs
Normal file
40
src/layers/activation_layer.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use ndarray::{Array1, arr1, ArrayView1};
|
||||
|
||||
use crate::functions::activation_functions::*;
|
||||
use super::Layer;
|
||||
|
||||
pub struct ActivationLayer {
|
||||
input: Array1<f64>,
|
||||
output: Array1<f64>,
|
||||
activation: fn(&Array1<f64>) -> Array1<f64>,
|
||||
activation_prime: fn(&Array1<f64>) -> Array1<f64>
|
||||
}
|
||||
|
||||
impl ActivationLayer {
|
||||
pub fn new(activation_fn: Type) -> Self {
|
||||
let (activation, activation_prime) = parse_type(activation_fn);
|
||||
ActivationLayer {
|
||||
input: arr1(&[]),
|
||||
output: arr1(&[]),
|
||||
activation,
|
||||
activation_prime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for ActivationLayer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
||||
self.input = input.to_owned();
|
||||
self.output = (self.activation)(&self.input);
|
||||
self.output.clone()
|
||||
}
|
||||
|
||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, _learning_rate: f64) -> Array1<f64> {
|
||||
// (self.activation_prime)(&self.input).into_shape((1 as usize, output_error.len() as usize)).unwrap().dot(&output_error)
|
||||
// (self.activation_prime)(&self.input) * &output_error
|
||||
let mut temp = (self.activation_prime)(&self.input);
|
||||
temp.zip_mut_with(&output_error, |x, y| *x *= y);
|
||||
temp
|
||||
}
|
||||
|
||||
}
|
104
src/layers/fc_layer.rs
Normal file
104
src/layers/fc_layer.rs
Normal file
|
@ -0,0 +1,104 @@
|
|||
extern crate ndarray;
|
||||
|
||||
use ndarray::{Array1, Array2, arr1, arr2, Array, ArrayView1, ShapeBuilder};
|
||||
use ndarray_rand::RandomExt;
|
||||
use ndarray_rand::rand_distr::{Normal, Uniform};
|
||||
|
||||
use super::Layer;
|
||||
|
||||
pub enum Initializer {
|
||||
Zeros,
|
||||
Ones,
|
||||
Gaussian(f64, f64),
|
||||
GaussianWFactor(f64, f64, f64),
|
||||
Uniform(f64, f64)
|
||||
}
|
||||
|
||||
impl Initializer {
|
||||
pub fn init<Sh, D>(&self, shape: Sh) -> Array<f64, D>
|
||||
where
|
||||
Sh: ShapeBuilder<Dim = D>, D: ndarray::Dimension
|
||||
{
|
||||
match self {
|
||||
Self::Zeros => Array::zeros(shape),
|
||||
Self::Ones => Array::ones(shape),
|
||||
Self::Gaussian(mean, stddev) => Array::random(shape, Normal::new(*mean, *stddev).unwrap()),
|
||||
Self::GaussianWFactor(mean, stddev, factor)
|
||||
=> Array::random(shape, Normal::new(*mean, *stddev).unwrap()) * *factor,
|
||||
Self::Uniform(low, high) => Array::random(shape, Uniform::new(low, high))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FCLayer {
|
||||
num_neurons: usize,
|
||||
is_initialized: bool,
|
||||
weight_initializer: Initializer,
|
||||
bias_initializer: Initializer,
|
||||
input: Array1<f64>,
|
||||
output: Array1<f64>,
|
||||
weights: Array2<f64>,
|
||||
biases: Array1<f64>,
|
||||
}
|
||||
|
||||
impl FCLayer {
|
||||
pub fn new(num_neurons: usize, weight_initializer: Initializer, bias_initializer: Initializer) -> Self {
|
||||
FCLayer {
|
||||
num_neurons,
|
||||
is_initialized: false,
|
||||
weight_initializer,
|
||||
bias_initializer,
|
||||
input: arr1(&[]),
|
||||
output: arr1(&[]),
|
||||
weights: arr2(&[[]]),
|
||||
biases: arr1(&[])
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize(&mut self, input_size: usize) {
|
||||
self.weights = self.weight_initializer.init((input_size, self.num_neurons));
|
||||
self.biases = self.bias_initializer.init(self.num_neurons);
|
||||
self.is_initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Layer for FCLayer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64> {
|
||||
if !self.is_initialized {
|
||||
self.initialize(input.len());
|
||||
}
|
||||
|
||||
self.input = input.to_owned();
|
||||
self.output = self.input.dot(&self.weights) + &self.biases;
|
||||
self.output.clone()
|
||||
}
|
||||
|
||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64> {
|
||||
//let input_error = output_error.dot(&self.weights.clone().reversed_axes());
|
||||
/* let input_error = stack(Axis(0), &vec![output_error; self.num_neurons]).unwrap().dot(&self.weights.clone().reversed_axes());
|
||||
|
||||
// let weights_error = self.input.clone().into_shape((1 as usize, self.num_neurons as usize)).unwrap().dot(&output_error);
|
||||
// let weights_error = self.input.clone().reversed_axes().dot(&output_error);
|
||||
// let mut weights_error = self.input.clone();
|
||||
// weights_error.zip_mut_with(&output_error, |x, y| *x *= y);
|
||||
let weights_error = self.input.clone().t().dot(&output_error.broadcast((self.input.len(),)).unwrap());
|
||||
|
||||
self.weights = &self.weights + learning_rate * weights_error;
|
||||
self.biases = &self.biases + learning_rate * &output_error;
|
||||
let len = input_error.len();
|
||||
let a = input_error.into_shape((len, )).unwrap();
|
||||
a */
|
||||
/* let delta_weights = &self.output.t() * &output_error;
|
||||
let delta_biases = output_error.sum_axis(Axis(0));
|
||||
self.weights = &self.weights + learning_rate * delta_weights;
|
||||
self.biases = &self.biases + learning_rate * delta_biases;
|
||||
output_error.dot(&self.weights.t()) */
|
||||
let input_error = output_error.dot(&self.weights.t());
|
||||
let delta_weights =
|
||||
self.input.to_owned().into_shape((self.input.len(), 1usize)).unwrap()
|
||||
.dot(&output_error.into_shape((1usize, output_error.len())).unwrap());
|
||||
self.weights = &self.weights + learning_rate * &delta_weights;
|
||||
self.biases = &self.biases + learning_rate * &output_error;
|
||||
input_error
|
||||
}
|
||||
}
|
9
src/layers/mod.rs
Normal file
9
src/layers/mod.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
use ndarray::{Array1, ArrayView1};
|
||||
|
||||
pub mod activation_layer;
|
||||
pub mod fc_layer;
|
||||
|
||||
pub trait Layer {
|
||||
fn forward_pass(&mut self, input: ArrayView1<f64>) -> Array1<f64>;
|
||||
fn backward_pass(&mut self, output_error: ArrayView1<f64>, learning_rate: f64) -> Array1<f64>;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue