#include"../headers/rnn.h"

void RNN::Setup()
{
 NN_SoftMaxOutput::Setup();
 tau=vector<long double> (nedges,0.);
 bias=vector<long double>(n,0.);
 read=vector<vector<long double> > (alph,vector<long double>(n,0.));
 start=vector<long double> (n,0.);
}

// Copies the current tau, bias, start and read parameters into their prev_*
// counterparts; restoreTauParams() copies them back.
void RNN::saveTauParams()
{
	prev_read=read;
	prev_start=start;
	prev_bias=bias;
	prev_tau=tau;
}

// Overwrites tau, bias, start and read with the copies held in the prev_*
// members (as filled by saveTauParams()).
void RNN::restoreTauParams()
{
	read=prev_read;
	start=prev_start;
	bias=prev_bias;
	tau=prev_tau;
}

// Returns alph/2, but never less than 3.
// The first experiments returned alph itself instead of alph/2; that turned
// out to be a bit too slow.
int RNN::suggestedConn()
{
	const auto halfAlph=alph/2;
	if(halfAlph>3)return halfAlph;
	return 3;
}

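//Commented-out function body kept for reference (presumably the body of
//inftheorysequence(n) -- TODO confirm):
//{
// return log(2.)*(1./log(n+2.)-1./log(n+3.));//sums to 1, decreases like 1/(n*(log n)^2)
// //return 1./(n+1.)-1./(n+2.);//sums to 1, decreases like 1/n^2
//}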

// Initialises every parameter of the network.
//  - Output layer: wbias[y] and w[y][i] are set to 0.
//  - For each unit i a value invtimescale is drawn with alea(), and the input
//    weights are set to read[y][i] = (alea()*2-1)*invtimescale.
//  - bias[i] and start[i] are set to 0.
//  - Incoming edges of unit i (edgesto[i], pairs of source unit / edge index):
//      * the self-loop (source == i) gets tau = 1 - invtimescale;
//      * the k-th other incoming edge (k = 1,2,...) gets
//        tau = invtimescale * alea_norm() * inftheorysequence(k).
// The alea()/alea_norm() calls happen in the same order as before, so a given
// random stream yields the same weights.
void RNN::initWeights()
{
	for(long y=0;y<alph;y++)wbias[y]=0;
	for(int i=0;i<n;i++){
		for(long y=0;y<alph;y++)w[y][i]=0;
		//alea() returns a uniform random value in [0;1]
		//Only needed while unit i is being processed, so a scalar is enough.
		const long double invtimescale=alea();
		for(long y=0;y<alph;y++)read[y][i]=(alea()*2-1)*invtimescale;
		bias[i]=0;
		//starting at equilibrium value
		start[i]=0;
		long counter=0;//number of non-self incoming edges seen so far
		for(const auto&edgeto:edgesto[i]){
			const auto src=edgeto.first;
			const auto e=edgeto.second;
			if(src==i)
				tau[e]=1.-invtimescale;
			else{
				counter++;
				tau[e]=invtimescale*alea_norm()*inftheorysequence(counter);
			}
		}
	}
}

//Previous version of initWeights(), kept for reference
//void RNN::initWeights()
//{
// int i,e;long y;
// long double s;
// long double read_epsilon;
// //Put the right global frequencies
//// for(y=0;y<alph;y++)wbias[y]=logl(predfreq[y]);
// for(y=0;y<alph;y++)wbias[y]=0;
// for(i=0;i<n;i++){
//	 //for(y=0;y<alph;y++)w[y][i]=0;
//	 for(y=0;y<alph;y++)w[y][i]=.1/n*alea_norm();
//	 read_epsilon=.5;
//	 //alea() returns a uniform random value in [0;1]
//	 for(y=0;y<alph;y++)read[y][i]=alea()*read_epsilon;
//	 //centering
//	 //s=0;for(y=0;y<alph;y++)s+=read[y][i]*readfreq[y];
//	 s=0;for(y=0;y<alph;y++)s+=read[y][i]/alph;
//	 for(y=0;y<alph;y++)read[y][i]-=s;
//	 bias[i]=0;
//	 //starting at equilibrium value
//	 start[i]=0;
// }
// //This code assumes there are no double edges in the graph
// for(e=0;e<nedges;e++)
//	 if(edgesrc[e]!=edgedest[e])tau[e]=0;
//	 //else tau[e]=.5;//Memory 2
// else tau[e]=1.-1./(edgesrc[e]+1.);//Memory j for unit j
//}

// Resets the network state to its starting point: for every unit, V is loaded
// from start and act is recomputed as actfunc(V).
void RNN::setToStartAct()
{
	for(int unit=0;unit<n;++unit){
		V[unit]=start[unit];
		act[unit]=actfunc(V[unit]);
	}
}

void RNN::computeNextAct()
{
 int i,j,e;
 for(j=0;j<n;j++)
	 V[j]=read[last_symbol_read][j]+bias[j];
 for(e=0;e<nedges;e++)
	 V[edgedest[e]]+=tau[e]*act[edgesrc[e]];
 for(i=0;i<n;i++)act[i]=actfunc(V[i]);
}

// Writes into grad[i] the derivatives of V[i](t+1) with respect to the
// parameters feeding unit i, in this slot order:
//   slot 0                               : d/d bias[i]    = 1
//   one slot per entry of edgesto[i]     : d/d tau[e]     = act[source of e]
//   alph slots, one per symbol y         : d/d read[y][i] = 1 if y is
//                                          last_symbol_read, else 0
// This is the same ordering that setTransparams() uses for its pointer lists.
// grad[i] is indexed but never resized here, so it is assumed to be large
// enough already -- TODO confirm against the caller.
void RNN::DFDTheta(vector<Vector>&grad){
	for(long unit=0;unit<n;++unit){
		Vector&row=grad[unit];
		long slot=0;
		row(slot++)=1;
		for(const auto&incoming:edgesto[unit])
			row(slot++)=act[incoming.first];
		for(long symbol=0;symbol<alph;++symbol){
			if(symbol==last_symbol_read)row(slot++)=1;
			else row(slot++)=0;
		}
	}
}

// Replaces dv, in place, by
//   dv_new[d] = sum over edges e with edgedest[e]==d of
//               dv_old[edgesrc[e]] * deractfunc(act[edgesrc[e]]) * tau[e]
// for the first n entries of dv.
void RNN::ApplyDFDV(vector<long double>&dv){
	//Work on a copy scaled by deractfunc, and clear the output as we go.
	vector<long double> scaled=dv;
	for(long unit=0;unit<n;++unit){
		scaled[unit]*=deractfunc(act[unit]);
		dv[unit]=0;
	}
	//Scatter each scaled source entry to the destination of its edge.
	for(long edge=0;edge<nedges;++edge){
		const auto from=edgesrc[edge];
		const auto to=edgedest[edge];
		dv[to]+=scaled[from]*tau[edge];
	}
}

// Replaces dv, in place, by
//   dv_new[s] = sum over edges e with edgesrc[e]==s of
//               deractfunc(act[s]) * tau[e] * dv_old[edgedest[e]]
// for the first n entries of dv. Edges are traversed in the opposite direction
// to ApplyDFDV().
void RNN::MultiplyRightDFDV(vector<long double>&dv){
	vector<long double> incoming=dv;
	for(long unit=0;unit<n;++unit)
		dv[unit]=0;
	for(long edge=0;edge<nedges;++edge){
		const auto from=edgesrc[edge];
		const auto to=edgedest[edge];
		dv[from]+=deractfunc(act[from])*tau[edge]*incoming[to];
	}
}

// Takes copies of act, pred and last_symbol_read and returns a callable that
// writes those copies back into this object when invoked.
// The callable captures `this`, so it must not be invoked after this RNN has
// been destroyed.
// NOTE(review): V is not part of the snapshot, although UpdateAct() reads it;
// confirm that callers do not depend on V after restoring.
function<void()> RNN::SaveState(){
	long savedSymbol=last_symbol_read;
	vector<long double> savedPred=pred;
	vector<long double> savedAct=act;
	return [this,savedAct,savedPred,savedSymbol](){
		last_symbol_read=savedSymbol;
		pred=savedPred;
		act=savedAct;
	};
}

// Adds dv[i] to V[i] for each of the n units and recomputes act[i] as
// actfunc(V[i]).
void RNN::UpdateAct(const vector<long double>&dv){
	for(long unit=0;unit<n;++unit){
		V[unit]+=dv[unit];
		act[unit]=actfunc(V[unit]);
	}
}

// Rebuilds, for each unit i, the list of pointers to the parameters feeding
// that unit (transparams[i]) and a parallel list of coefficients
// (transdiagregul[i]). Both outputs are reset to n empty lists first.
// Order within each list:
//   1. bias[i]                               coefficient 1/inftheorysequence(0)
//   2. tau[e] for each entry of edgesto[i]   coefficient 1/inftheorysequence(k),
//                                            k = 1,2,... in iteration order
//   3. read[x][i] for x = 0..alph-1          coefficient 1/inftheorysequence(0)
// The stored pointers refer to elements of bias, tau and read, so they stop
// being valid if those containers are reallocated.
void RNN::setTransparams(vector<vector<long double*>>&transparams,vector<vector<long double>>&transdiagregul){
	transparams.assign(n,vector<long double*>());
	transdiagregul.assign(n,vector<long double>());
	for(long unit=0;unit<n;++unit){
		vector<long double*>&params=transparams[unit];
		vector<long double>&regul=transdiagregul[unit];

		//bias of the unit
		params.push_back(&bias[unit]);
		regul.push_back(1./inftheorysequence(0));

		//params for edges to the unit
		long double rank=1;
		for(const auto&incoming:edgesto[unit]){
			params.push_back(&tau[incoming.second]);
			regul.push_back(1./inftheorysequence(rank));
			++rank;
		}

		//input params for the unit
		for(long symbol=0;symbol<alph;++symbol){
			params.push_back(&read[symbol][unit]);
			regul.push_back(1./inftheorysequence(0));
		}
	}
}
