#include"../headers/lrnn.h"

void LRNN::computeNextAct()
{
 int i,j,e;
 for(j=0;j<n;j++)
	 sigmalpha[j]=1./(1.+exp(-alpha[j]));
 for(j=0;j<n;j++)
	 V[j]=V[j]*(1.-sigmalpha[j])+read[last_symbol_read][j]+bias[j];
 for(e=0;e<nedges;e++)
	 V[edgedest[e]]+=tau[e]*act[edgesrc[e]];
 for(i=0;i<n;i++)act[i]=actfunc(V[i]);
}

void LRNN::Setup()
{
 RNN::Setup();
 alpha=vector<long double>(n,0.);
 sigmalpha=vector<long double>(n,.5);
}

// Backs up the parameters: delegates to the base class first, then copies
// alpha and sigmalpha into prev_alpha and prev_sigmalpha.
void LRNN::saveTauParams()
{
	RNN::saveTauParams();
	prev_sigmalpha=sigmalpha;
	prev_alpha=alpha;
}

// Rolls the parameters back: delegates to the base class first, then copies
// prev_alpha and prev_sigmalpha back into alpha and sigmalpha.
void LRNN::restoreTauParams()
{
	RNN::restoreTauParams();
	sigmalpha=prev_sigmalpha;
	alpha=prev_alpha;
}

void LRNN::initWeights()
{
 int i,e,j;long y;
 long double s;
 for(y=0;y<alph;y++)wbias[y]=0;
 for(i=0;i<n;i++){
	 for(y=0;y<alph;y++)w[y][i]=0;//.1/(i+1.)/(i+1.)*alea_norm();
	 alpha[i]=log(alea()+epsilon);
	 //alpha[i]=log(1./(i+2.));
	 //if(alea(2))alpha[i]=0;else alpha[i]=log(alea()+epsilon);
	 //alpha[i]=0;
	 sigmalpha[i]=1./(1.+exp(-alpha[i]));
	 cout<<sigmalpha[i]<<endl;
 }
 for(i=0;i<n;i++){
	 for(y=0;y<alph;y++)read[y][i]=(alea()*2-1)*sqrt(sigmalpha[i]);
	 bias[i]=0;
	 start[i]=0;
	 long counter=0;
	 for(auto edgeto:edgesto[i]){
		 e=edgeto.second;j=edgeto.first;
		 counter++;tau[e]=sqrt(sigmalpha[i]*(sigmalpha[i]+sigmalpha[j]))*alea_norm()*inftheorysequence(counter);
		 if(j==i)tau[e]=-sigmalpha[i];
	 }
 }
}

/*void LRNN::initWeights()
{
 int i,e,j;long y;
 long double s;
 for(y=0;y<alph;y++)wbias[y]=0;
 for(i=0;i<n;i++){
	 for(y=0;y<alph;y++)w[y][i]=0;//.1/(i+1.)/(i+1.)*alea_norm();
	 //alpha[i]=log(alea()+epsilon);
	 alpha[i]=0;
	 sigmalpha[i]=1./(1.+exp(-alpha[i]));
	 cout<<sigmalpha[i]<<endl;
	 //alea() returns a uniform random value in [0;1]
	 for(y=0;y<alph;y++)read[y][i]=(alea()*2-1)*sigmalpha[i];
	 bias[i]=0;
	 //starting at equilibrium value
	 start[i]=0;
	 long counter=0;
	 for(auto edgeto:edgesto[i]){
		 e=edgeto.second;j=edgeto.first;
		 counter++;tau[e]=sigmalpha[i]*alea_norm()*inftheorysequence(counter);
	 }
 }
}
*/

// Transforms dv in place. With dv_in denoting dv on entry:
//   dv[i] = (1 - sigmalpha[i]) * dv_in[i]
//           + sum over edges e with edgedest[e]==i of
//             dv_in[edgesrc[e]] * deractfunc(act[edgesrc[e]]) * tau[e]
void LRNN::ApplyDFDV(vector<long double>&dv){
	// Copy of the incoming vector, scaled by the activation derivative.
	vector<long double> weighted(dv);
	for(long unit=0;unit<n;++unit){
		weighted[unit]*=deractfunc(act[unit]);
		dv[unit]*=1.-sigmalpha[unit];
	}
	for(long edge=0;edge<nedges;++edge)
		dv[edgedest[edge]]+=weighted[edgesrc[edge]]*tau[edge];
}

// Transforms dv in place, following each edge from destination back to source.
// With dv_in denoting dv on entry:
//   dv[j] = (1 - sigmalpha[j]) * dv_in[j]
//           + sum over edges e with edgesrc[e]==j of
//             deractfunc(act[j]) * tau[e] * dv_in[edgedest[e]]
void LRNN::MultiplyRightDFDV(vector<long double>&dv){
	const vector<long double> incoming(dv);
	for(long unit=0;unit<n;++unit)
		dv[unit]*=1.-sigmalpha[unit];
	for(long edge=0;edge<nedges;++edge){
		const auto from=edgesrc[edge];
		dv[from]+=deractfunc(act[from])*tau[edge]*incoming[edgedest[edge]];
	}
}

// Fills grad[i] with the partial derivatives of V[i](t+1) with respect to the
// parameters attached to unit i. Slot order within grad[i]:
//   0                : bias[i]
//   1 .. #edges      : tau[e] for every edge into i, in edgesto[i] order
//   next alph slots  : read[y][i] for y = 0 .. alph-1
//   last slot        : alpha[i]
// This is the same order in which setTransparams lists the parameters.
void LRNN::DFDTheta(vector<Vector>&grad){
	for(long unit=0;unit<n;++unit){
		Vector& row=grad[unit];
		long slot=0;
		// d V[i](t+1) / d bias[i]
		row(slot++)=1;
		// d V[i](t+1) / d tau[e]: activity of the edge's source unit
		for(const auto& incoming:edgesto[unit])
			row(slot++)=act[incoming.first];
		// d V[i](t+1) / d read[y][i]: 1 only for the symbol that was read
		for(long symbol=0;symbol<alph;++symbol)
			row(slot++)=(symbol==last_symbol_read)?1:0;
		// d V[i](t+1) / d alpha[i]
		row(slot++)=-sigmalpha[unit]*(1.-sigmalpha[unit])*V[unit];
	}
}

// Snapshots the current recurrent state and returns a closure that puts it
// back when invoked.
// The snapshot covers V, act, pred and last_symbol_read. V has to be included:
// computeNextAct derives the next potential from the previous V, and UpdateAct
// modifies V, so restoring act alone would leave V at its later value and the
// next step would start from the wrong potentials.
// The closure holds copies of the data plus `this`; it must not be invoked
// after this object has been destroyed.
function<void()> LRNN::SaveState(){
	auto potentials=V;
	vector<long double> activities=act;
	vector<long double> predec=pred;
	long symb=last_symbol_read;
	return [potentials,activities,predec,symb,this](){
		V=potentials;
		act=activities;
		last_symbol_read=symb;
		pred=predec;
	};
}

// Adds dv[i] to the potential V[i] of every unit and recomputes the unit's
// activity from the shifted potential.
void LRNN::UpdateAct(const vector<long double>& dv){
	for(long unit=0;unit<n;++unit){
		V[unit]+=dv[unit];
		act[unit]=actfunc(V[unit]);
	}
}

// Rebuilds, for every unit i, the list of pointers to its parameters
// (transparams[i]) and a parallel list of coefficients (transdiagregul[i]).
// Both outputs are reset to n empty rows first. Row layout:
//   bias[i]            -> 1/inftheorysequence(0)/sigmalpha[i]/sigmalpha[i]
//   tau[e], k-th edge  -> 1/inftheorysequence(k)/(sigmalpha[i]*(sigmalpha[i]+sigmalpha[j]))
//                         (k counted from 1, j = source unit of the edge)
//   read[x][i], each x -> 1/inftheorysequence(0)/sigmalpha[i]
//   alpha[i]           -> 100000000000.
// The stored pointers address elements of bias, tau, read and alpha directly.
void LRNN::setTransparams(vector<vector<long double*>>&transparams,vector<vector<long double>>&transdiagregul){
	transparams.assign(n,vector<long double*>());
	transdiagregul.assign(n,vector<long double>());
	for(long unit=0;unit<n;++unit){
		vector<long double*>& params=transparams[unit];
		vector<long double>& coeffs=transdiagregul[unit];
		const auto gate=sigmalpha[unit];

		params.push_back(&bias[unit]);
		coeffs.push_back(1./inftheorysequence(0)/gate/gate);

		// parameters of the edges arriving at this unit
		long double rank=1;
		for(const auto& incoming:edgesto[unit]){
			params.push_back(&tau[incoming.second]);
			const long source=incoming.first;
			const long double scale=gate*(gate+sigmalpha[source]);
			coeffs.push_back(1./inftheorysequence(rank)/scale);
			++rank;
		}

		// input parameters of this unit, one per symbol
		for(long symbol=0;symbol<alph;++symbol){
			params.push_back(&read[symbol][unit]);
			coeffs.push_back(1./inftheorysequence(0)/gate);
		}

		// decay coefficient of this unit
		params.push_back(&alpha[unit]);
		coeffs.push_back(100000000000.);
	}
}
