#include"../headers/rand.h" //homemade random number functions
#include"../headers/bptt.h"
#include"../headers/qdmetric.h"
#include <iostream>

//Lazily (re)allocates every work buffer so that it matches the currently
//registered parameter groups. A buffer is reallocated (and zeroed / reset)
//only when its number of groups differs from the parameter container's.
//Returns false when there is nothing to optimize (no transition parameters
//or no output parameters), true otherwise.
template<class Metric>
bool BPTTOptimBase<Metric>::Check()
{
 if(transparams.size()==0 || outparams.size()==0)
	 return false;

 const long nTrans=transparams.size();
 const long nOut=outparams.size();
 dimV=transparams.size();

 //Buffers with one entry per transition-parameter group.
 if(g.size()!=dimV)
	 g=vector<long double>(dimV,0.);
 if(deltatransparams.size()!=transparams.size()){
	 deltatransparams.resize(transparams.size());
	 for(long k=0;k<nTrans;k++)
		 deltatransparams[k]=Vector::Zero(transparams[k].size());
 }
 if(wd.size()!=dimV){
	 wd=vector<Vector>(dimV);
	 for(long k=0;k<nTrans;k++)
		 wd[k]=Vector::Zero(transparams[k].size());
 }
 if(transmetric.size()!=transparams.size()){
	 transmetric=vector<Metric>(transparams.size());
	 for(long k=0;k<nTrans;k++)
		 transmetric[k]=Metric(transparams[k].size());
 }

 //Buffers with one entry per output-parameter group.
 if(deltaoutparams.size()!=outparams.size()){
	 deltaoutparams.resize(outparams.size());
	 for(long k=0;k<nOut;k++)
		 deltaoutparams[k]=Vector::Zero(outparams[k].size());
 }
 if(gout.size()!=outparams.size()){
	 gout=vector<Vector>(outparams.size());
	 for(long k=0;k<nOut;k++)
		 gout[k]=Vector::Zero(outparams[k].size());
 }
 if(outmetric.size()!=outparams.size()){
	 outmetric=vector<Metric>(outparams.size());
	 for(long k=0;k<nOut;k++)
		 outmetric[k]=Metric(outparams[k].size());
 }

 return true;
}

//Hook called before a transition. This optimizer has nothing to prepare, so
//the body is intentionally empty; the method exists only so that BPTT exposes
//the same interface as the other algorithms.
template<class Metric>
void BPTTOptimBase<Metric>::PrepareTransition()
{
 //Intentionally a no-op.
}

//Per-epoch truncated BPTT: a backward pass through the most recent transitions
//is performed at every epoch.
//
//buff is used as a circular buffer indexed modulo truncate: buff[k]() restores
//the state that was saved when epoch was k. Because the current state is
//written into slot `epoch` (overwriting the state saved truncate epochs ago),
//only truncate-1 previous states are available, hence at most truncate-1
//transitions can be back-propagated through.
//
//update: when false, only the state bookkeeping is performed (no gradient step).
template<long truncate,class Metric>
void BPTTOptim<truncate,true,Metric>::MakeGradStep(bool update){
 //Save the current state. Further calls to buff[epoch] will restore it.
 buff[epoch]=this->SaveState();
 if(update){
  if(!this->Check())return;
  long i,j,dimi;
  if(this->DLossDoutparams){
   //Metric-preconditioned step on the output parameters: decay the metric,
   //add the outer product of the gradient, solve in place, then step.
   this->DLossDoutparams(this->gout);
   for(i=0;i<this->outparams.size();i++){
    this->outmetric[i]*=1.-this->outmetric_gamma;
    this->outmetric[i].OuterProductUpdate(this->gout[i]);
    this->outmetric[i].solveInPlace(this->gout[i],this->outdiagregul[i]);
    dimi=this->outparams[i].size();
    for(j=0;j<dimi;j++)*this->outparams[i][j]-=this->outrate*this->gout[i](j);
   }
  }
  if(this->DLossDV){
   //g is filled by DLossDV (presumably dLoss/dV at the current state).
   this->DLossDV(this->g);
   //Back-propagate through the truncate-1 transitions whose input state is
   //still stored. Going up to t==truncate-1 would wrap around to buff[epoch],
   //i.e. the current state, and accumulate a gradient evaluated at the wrong
   //state.
   for(long t=0;t<truncate-1;t++){
    //Slot holding the state saved t+1 epochs ago: the input of the transition
    //differentiated at this step. epoch-t-1 > -truncate, so one wrap suffices.
    long prev=epoch-t-1;
    if(prev<0)prev+=truncate;
    buff[prev]();
    this->DFDTheta(this->wd);
    for(i=0;i<this->transparams.size();i++){
     dimi=this->transparams[i].size();
     for(j=0;j<dimi;j++)
      this->deltatransparams[i](j)+=this->g[i]*this->wd[i](j);
    }
    //Carry g one step further back in time.
    this->MultiplyRightDFDV(this->g);
   }
   //Metric-preconditioned step on the transition parameters, then reset the
   //accumulator for the next epoch.
   for(i=0;i<this->transparams.size();i++){
    dimi=this->transparams[i].size();
    this->transmetric[i]*=1.-this->transmetric_gamma;
    this->transmetric[i].OuterProductUpdate(this->deltatransparams[i]);
    this->transmetric[i].solveInPlace(this->deltatransparams[i],this->transdiagregul[i]);
    for(j=0;j<dimi;j++){
     *this->transparams[i][j]-=this->transrate*this->deltatransparams[i][j];
     this->deltatransparams[i][j]=0.;
    }
   }
  }
 }
 //Restore the current state and advance the circular index.
 buff[epoch]();
 epoch=(epoch+1)%truncate;
}

//Windowed truncated BPTT: gradients are accumulated over a window of truncate
//epochs, and all parameters are updated once per window, when epoch is 1.
//For this version, epoch runs from 1 to truncate (not from 0): buff[epoch-1]()
//restores the state saved at that epoch.
//
//When epoch==1, slot 0 holds the current state and slot k (k>0) the state
//saved at epoch k+1 of the window that is ending. The state preceding slot 1
//has been overwritten, so only truncate-1 transitions can be differentiated.
//
//update: when false, only the state bookkeeping is performed (no gradient step).
template<long truncate,class Metric>
void BPTTOptim<truncate,false,Metric>::MakeGradStep(bool update){
 buff[epoch-1]=this->SaveState();
 if(update){
  if(!this->Check())return;
  long i,j,dimi;
  if(this->DLossDoutparams){
   //Accumulate the metric-preconditioned output gradient; it is applied
   //once per window below.
   this->DLossDoutparams(this->gout);
   for(i=0;i<this->outparams.size();i++){
    this->outmetric[i]*=1.-this->outmetric_gamma;
    this->outmetric[i].OuterProductUpdate(this->gout[i]);
    this->outmetric[i].solveInPlace(this->gout[i],this->outdiagregul[i]);
    dimi=this->outparams[i].size();
    for(j=0;j<dimi;j++)
     this->deltaoutparams[i](j)+=this->gout[i](j);
   }
  }
  if(epoch==1){
   if(this->DLossDV){
    //Running sum of the loss gradients, carried backward through time.
    vector<long double> g2(this->dimV,0.);
    for(long t=0;t<truncate-1;t++){
     //Loss gradient at the currently loaded state: the current state for t==0,
     //otherwise the one restored at the previous iteration (slot truncate-t).
     this->DLossDV(this->g);
     for(i=0;i<this->dimV;i++)g2[i]+=this->g[i];
     //Load the input state of the transition being differentiated.
     //Stopping at t==truncate-2 avoids wrapping around to buff[0], which
     //holds the current state rather than the predecessor of slot 1.
     buff[truncate-1-t]();
     this->DFDTheta(this->wd);
     for(i=0;i<this->transparams.size();i++){
      dimi=this->transparams[i].size();
      for(j=0;j<dimi;j++){
       this->deltatransparams[i](j)+=g2[i]*this->wd[i](j);
      }
     }
     this->MultiplyRightDFDV(g2);
    }
   }
   //Apply and reset the output step accumulated over the window. This must
   //not depend on DLossDV: otherwise, when DLossDV is unset, deltaoutparams
   //keeps growing and the output parameters are never updated.
   for(i=0;i<this->outparams.size();i++){
    dimi=this->outparams[i].size();
    for(j=0;j<dimi;j++){
     *this->outparams[i][j]-=this->outrate*this->deltaoutparams[i](j);
     this->deltaoutparams[i](j)=0.;
    }
   }
   if(this->DLossDV){
    //Metric-preconditioned step on the transition parameters, then reset
    //the accumulator for the next window.
    for(i=0;i<this->transparams.size();i++){
     this->transmetric[i]*=1.-this->transmetric_gamma;
     this->transmetric[i].OuterProductUpdate(this->deltatransparams[i]);
     this->transmetric[i].solveInPlace(this->deltatransparams[i],this->transdiagregul[i]);
     dimi=this->transparams[i].size();
     for(j=0;j<dimi;j++){
      *this->transparams[i][j]-=this->transrate*this->deltatransparams[i][j];
      this->deltatransparams[i][j]=0.;
     }
    }
   }
  }
 }
 //Restore the current state and advance epoch within [1, truncate].
 buff[epoch-1]();
 epoch=epoch%truncate+1;
}


template class BPTTOptimBase<QDMetric>;
template class BPTTOptimBase<ConstantMetric>;

template class BPTTOptim<5,true,QDMetric>; 
template class BPTTOptim<10,true,QDMetric>;
template class BPTTOptim<15,true,QDMetric>;
template class BPTTOptim<20,true,QDMetric>;
template class BPTTOptim<25,true,QDMetric>;
template class BPTTOptim<21,true,QDMetric>;
template class BPTTOptim<5,false,QDMetric>; 
template class BPTTOptim<10,false,QDMetric>;
template class BPTTOptim<15,false,QDMetric>;
template class BPTTOptim<20,false,QDMetric>;
template class BPTTOptim<21,false,QDMetric>;
template class BPTTOptim<25,false,QDMetric>;


template class BPTTOptim<5,true,ConstantMetric>; 
template class BPTTOptim<10,true,ConstantMetric>;
template class BPTTOptim<15,true,ConstantMetric>;
template class BPTTOptim<20,true,ConstantMetric>;
template class BPTTOptim<21,true,ConstantMetric>;
template class BPTTOptim<25,true,ConstantMetric>;
template class BPTTOptim<5,false,ConstantMetric>; 
template class BPTTOptim<10,false,ConstantMetric>;
template class BPTTOptim<15,false,ConstantMetric>;
template class BPTTOptim<20,false,ConstantMetric>;
template class BPTTOptim<21,false,ConstantMetric>;
template class BPTTOptim<25,false,ConstantMetric>;
