#include<iostream>
#include"../headers/rand.h" //my homemade random number functions
#include"../headers/rtrl.h"

using namespace std;

template <class Metric>
/**
 * Validates the optimizer configuration and sizes every work buffer.
 *
 * Sets dimV to the number of transition-parameter blocks and makes sure that
 * g, vup (only when UpdateAct is set), dVdtheta, wd, gout and the two metric
 * arrays match the registered parameter pointers.
 *
 * Buffers are rebuilt when the number of blocks changes and also, per block,
 * when a single block's parameter count no longer matches its buffers.
 * Comparing only the outer block counts would leave stale buffers after a
 * block is resized, and PrepareTransition / MakeGradStep would then index
 * them out of range. Buffers that already match are left untouched, so the
 * accumulated sensitivities and metrics survive the call made at every step.
 *
 * @return false (printing a hint when verbose) if either parameter list is
 *         empty, true otherwise.
 */
bool RTRLOptim<Metric>::Check()
{
 if(!transparams.size()){
	 if(verbose)cout<<"RTRLOptim: Please set the pointers to transition parameters."<<endl;
	 return false;
 }
 if(!outparams.size()){
	 if(verbose)cout<<"RTRLOptim: Please set the pointers to output parameters."<<endl;
	 return false;
 }
 dimV=transparams.size();
 if(UpdateAct&&vup.size()!=dimV)vup=vector<long double>(dimV,0.);
 if(g.size()!=dimV)g=vector<long double>(dimV,0.);

 // --- Transition side: one block of parameters per state component. ---
 const size_t nTrans=transparams.size();
 const bool newSens=dVdtheta.size()!=nTrans;
 const bool newWd=wd.size()!=nTrans;
 const bool newTransMetric=transmetric.size()!=nTrans;
 if(newSens)dVdtheta.resize(nTrans);
 if(newWd)wd=vector<Vector >(nTrans);
 if(newTransMetric)transmetric=vector<Metric>(nTrans);
 for(size_t i=0;i<nTrans;i++){
	 const size_t nParams=transparams[i].size();
	 // Block i changed size although the number of blocks did not.
	 const bool resized=!newSens&&dVdtheta[i].size()!=nParams;
	 if(newSens||resized){
		 // One dimV-long sensitivity vector per parameter of block i.
		 dVdtheta[i].resize(nParams);
		 for(auto& component:dVdtheta[i])component=vector<long double>(nTrans,0.);
	 }
	 if(newWd||static_cast<size_t>(wd[i].size())!=nParams)
		 wd[i]=Vector::Zero(nParams);
	 // A metric sized for the old parameter count is unusable.
	 if(newTransMetric||resized)
		 transmetric[i]=Metric(nParams);
 }

 // --- Output side. ---
 const size_t nOut=outparams.size();
 const bool newGout=gout.size()!=nOut;
 const bool newOutMetric=outmetric.size()!=nOut;
 if(newGout)gout=vector<Vector >(nOut);
 if(newOutMetric)outmetric=vector<Metric>(nOut);
 for(size_t i=0;i<nOut;i++){
	 const size_t nParams=outparams[i].size();
	 const bool resized=!newGout&&static_cast<size_t>(gout[i].size())!=nParams;
	 if(newGout||resized)gout[i]=Vector::Zero(nParams);
	 if(newOutMetric||resized)outmetric[i]=Metric(nParams);
 }
 return true;
}

template <class Metric>
/**
 * Advances the stored RTRL sensitivities dVdtheta by one transition step.
 *
 * First every sensitivity vector dVdtheta[unit][p] is passed to the
 * ApplyDFDV callback (applies dF/dV). Then the DFDTheta callback fills wd
 * (the dF/dTheta term), and wd[unit](p) is added to component `unit` of
 * dVdtheta[unit][p]: block `unit`'s direct term only enters state
 * component `unit`.
 * Does nothing when Check() fails.
 */
void RTRLOptim<Metric>::PrepareTransition()
{
 if(!Check())
	 return;

 // Propagate the previous sensitivities through the state Jacobian.
 for(size_t unit=0;unit<dVdtheta.size();++unit)
	 for(size_t p=0;p<dVdtheta[unit].size();++p)
		 ApplyDFDV(dVdtheta[unit][p]);

 // Add the direct parameter contribution, diagonal in the block index.
 DFDTheta(wd);
 for(size_t unit=0;unit<dVdtheta.size();++unit){
	 auto& sensBlock=dVdtheta[unit];
	 const long nParams=static_cast<long>(sensBlock.size());
	 for(long p=0;p<nParams;++p)
		 sensBlock[p][unit]+=wd[unit](p);
 }
}

template <class Metric>
/**
 * Applies one metric-preconditioned gradient step to the output and
 * transition parameters. Nothing happens unless `update` is true and
 * Check() succeeds.
 *
 * Output parameters (only when DLossDoutparams is set): the callback fills
 * gout. For each block the metric is scaled by (1 - outmetric_gamma), fed
 * gout[i] through OuterProductUpdate, and solveInPlace (with outdiagregul[i])
 * rewrites gout[i] in place. Then outparams[i][j] -= outrate * gout[i](j).
 *
 * Transition parameters (only when DLossDV is set): the callback fills g.
 * The chain rule through the RTRL sensitivities gives
 *   dLdtheta_i(j) = sum_k g[k] * dVdtheta[i][j][k],
 * which is preconditioned the same way (transmetric_gamma, transdiagregul)
 * and applied with transrate. When UpdateAct is set, vup accumulates the
 * first-order change of V implied by that step,
 *   vup[k] = -sum_{i,j} transrate * dLdtheta_i(j) * dVdtheta[i][j][k],
 * and is handed to UpdateAct.
 *
 * NOTE(review): the exact meaning of OuterProductUpdate / solveInPlace
 * (presumably accumulate g g^T and solve metric * x = g) is defined by the
 * Metric classes in rtrl.h — confirm there.
 */
void RTRLOptim<Metric>::MakeGradStep(bool update)
{
 if(!update||!Check())return;

 if(DLossDoutparams){
	 DLossDoutparams(gout);
	 for(size_t i=0;i<outparams.size();i++){
		 outmetric[i]*=1.-outmetric_gamma;
		 outmetric[i].OuterProductUpdate(gout[i]);
		 outmetric[i].solveInPlace(gout[i], outdiagregul[i]);
		 const long dimi=static_cast<long>(outparams[i].size());
		 for(long j=0;j<dimi;j++)
			 *outparams[i][j]-=outrate*gout[i](j);
	 }
 }

 if(DLossDV){
	 DLossDV(g);
	 // Reuse the buffer (Check() already sized it) instead of reallocating.
	 if(UpdateAct)vup.assign(dimV,0.);
	 for(size_t i=0;i<dVdtheta.size();i++){
		 const auto& sens=dVdtheta[i];
		 const long nParams=static_cast<long>(sens.size());
		 // Chain rule: dLoss/dtheta_ij = sum_k dLoss/dV_k * dV_k/dtheta_ij.
		 Vector dLdthetai=Vector::Zero(nParams);
		 for(long j=0;j<nParams;j++)
			 for(long k=0;k<dimV;k++)
				 dLdthetai(j)+=g[k]*sens[j][k];
		 transmetric[i]*=1.-transmetric_gamma;
		 transmetric[i].OuterProductUpdate(dLdthetai);
		 transmetric[i].solveInPlace(dLdthetai,transdiagregul[i]);
		 for(long j=0;j<nParams;j++){
			 const auto step=transrate*dLdthetai(j);
			 *transparams[i][j]-=step;
			 // First-order prediction of how this parameter change moves V.
			 if(UpdateAct)
				 for(long k=0;k<dimV;k++)
					 vup[k]-=step*sens[j][k];
		 }
	 }
	 if(UpdateAct)UpdateAct(vup);
 }
}

// Explicit instantiations of RTRLOptim for the supported Metric types.
// The member functions above are defined in this .cpp file. Assuming rtrl.h
// only declares them, any other Metric type needs its own line here, or it
// will fail to link.
template class RTRLOptim<QDMetric>;
template class RTRLOptim<UnitWiseMetric>;
template class RTRLOptim<ConstantMetric>;
