#include<iostream>
#include"../headers/rand.h" //my homemade random number functions
#include"../headers/rankn.h"

using namespace std;

// Makes sure all vectors, matrices, etc. are correctly sized before use.
// Returns false (printing a message when verbose) if neither transition nor
// output parameters have been registered. Otherwise sets dimV to the number of
// transition parameter blocks, rebuilds every container whose size does not
// match what it should be, and returns true. A container whose size is already
// right is left untouched; a rebuilt one is zero-filled (Metric objects are
// constructed from the size of the corresponding parameter block).
template<long N, class Metric>
bool OnlineOptim<N,Metric>::Check()
{
 if(transparams.size()==0&&outparams.size()==0){
	 if(verbose)cout<<"OnlineOptim: Please set the pointers to transition and output parameters. There are currently no parameters to optimize."<<endl;
	 return false;
 }
 dimV=transparams.size();

 // A fresh list of zero Vectors, one per transition-parameter block,
 // each with the length of that block.
 auto zeroTransBlocks=[this](){
	 vector<Vector> blocks(dimV);
	 for(long i=0;i<dimV;i++)blocks[i]=Vector::Zero(transparams[i].size());
	 return blocks;
 };

 // Per-sample containers: exactly N entries, each covering dimV transition blocks.
 if(vbar.size()!=N)vbar=vector<vector<long double>>(N);
 if(vd.size()!=N)vd=vector<vector<long double>>(N);
 if(wbar.size()!=N)wbar=vector<vector<Vector>>(N);
 if(wd.size()!=N)wd=vector<vector<Vector>>(N);
 for(long n=0;n<N;n++){
	 if(vbar[n].size()!=dimV)vbar[n]=vector<long double>(dimV,0.);
	 if(vd[n].size()!=dimV)vd[n]=vector<long double>(dimV,0.);
	 if(wbar[n].size()!=dimV)wbar[n]=zeroTransBlocks();
	 if(wd[n].size()!=dimV)wd[n]=zeroTransBlocks();
 }

 // Scratch buffers filled by DLossDV (g) and, only when UpdateAct is set, vup.
 if(g.size()!=dimV)g=vector<long double>(dimV,0.);
 if(UpdateAct&&vup.size()!=dimV)vup=vector<long double>(dimV,0.);

 // Output-parameter buffers: one gradient Vector and one Metric per block.
 const auto nOut=outparams.size();
 if(gout.size()!=nOut){
	 gout=vector<Vector>(nOut);
	 for(long i=0;i<nOut;i++)gout[i]=Vector::Zero(outparams[i].size());
 }
 if(outmetric.size()!=nOut){
	 outmetric=vector<Metric>(nOut);
	 for(long i=0;i<nOut;i++)outmetric[i]=Metric(outparams[i].size());
 }

 // Transition-parameter metrics, plus one vmetric scalar per block.
 const auto nTrans=transparams.size();
 if(transmetric.size()!=nTrans){
	 transmetric=vector<Metric>(nTrans);
	 for(long i=0;i<nTrans;i++)transmetric[i]=Metric(transparams[i].size());
 }
 if(vmetric.size()!=nTrans)vmetric=vector<long double>(nTrans,0.);
 return true;
}

// Prepares the running estimates for a new transition step:
//  1. ApplyDFDV is applied to every vbar[n];
//  2. DFDTheta is applied to every wd[n];
//  3. every vd[n][i] is set to 1.
// The intent, per the original author's note, is that the products
// vd[n][i]*wd[n][i] then carry dF/dTheta. Does nothing if Check() fails.
// All ApplyDFDV calls happen before any DFDTheta call.
template<long N, class Metric>
void OnlineOptim<N,Metric>::PrepareTransition()
{
 if(!Check())return;
 // Check() guarantees vbar, wd and vd each hold exactly N entries here.
 for(auto&vbarSample:vbar)ApplyDFDV(vbarSample);
 for(auto&wdSample:wd)DFDTheta(wdSample);
 for(auto&vdSample:vd)
	 for(auto&coef:vdSample)coef=1;
}

// One online optimization step.
//  update==true : queries the loss gradients through the user callbacks and
//                 moves the parameters along the metric-preconditioned
//                 gradient:
//                 - output parameters via DLossDoutparams / outmetric;
//                 - transition parameters via DLossDV / transmetric.
//                 When UpdateAct is set, it also receives vup, the estimated
//                 first-order change of V caused by the transition-parameter
//                 move.
//  update==false: no parameter is touched.
// In both cases the step ends with Reduce().
// Check() runs unconditionally because Reduce() indexes vbar, wbar, vd, wd and
// vmetric, which only Check() sizes; without it, MakeGradStep(false) on a
// fresh optimizer would read out-of-range elements.
template<long N, class Metric>
void OnlineOptim<N,Metric>::MakeGradStep(bool update)
{
 if(!Check())return;
 if(update){
  if(DLossDoutparams){
   DLossDoutparams(gout);
   const long nout=static_cast<long>(outparams.size());
   for(long i=0;i<nout;i++){
    outmetric[i]*=1.-outmetric_gamma;//decay the accumulated metric before adding the new gradient
    outmetric[i].OuterProductUpdate(gout[i]);
    outmetric[i].solveInPlace(gout[i],outdiagregul[i]);//presumably gout[i] <- (metric+regul)^-1 gout[i]; see Metric
    const long dimi=static_cast<long>(outparams[i].size());
    for(long j=0;j<dimi;j++)
     *outparams[i][j]-=outrate*gout[i](j);
   }
  }
  if(DLossDV){
   DLossDV(g);
   // dV/dtheta is estimated, per sample n, by
   //   vbar[n] (x) wbar[n]  +  sum_i vd[n][i] * wd[n][i] (in block i only),
   // averaged over the N samples.
   vector<long double> sum(N,0.);//sum[n]  = (dLoss/dV).vbar[n]
   vector<long double> sum2(N,0.);//sum2[n] = sum_i wbar[n][i].(preconditioned dLoss/dtheta_i)
   for(long n=0;n<N;n++)for(long i=0;i<dimV;i++)sum[n]+=g[i]*vbar[n][i];
   for(long i=0;i<dimV;i++){
    Vector dLdthetai=Vector::Zero(wbar[0][i].rows());
    for(long n=0;n<N;n++)dLdthetai+=(g[i]*vd[n][i])*wd[n][i]+sum[n]*wbar[n][i];
    dLdthetai/=N;//dLoss/dV.dV/dtheta_i using the current estimate of dV/dtheta_i
    transmetric[i]*=1.-transmetric_gamma;
    transmetric[i].OuterProductUpdate(dLdthetai);
    transmetric[i].solveInPlace(dLdthetai,transdiagregul[i]);
    const long dimi=static_cast<long>(transparams[i].size());
    for(long j=0;j<dimi;j++)
     *transparams[i][j]-=transrate*dLdthetai(j);
    if(UpdateAct){
     // Contributions of the parameter move -transrate*dLdthetai to V,
     // through the estimated Jacobian (the scaling is applied below).
     vup[i]=0.;
     for(long n=0;n<N;n++)vup[i]+=vd[n][i]*wd[n][i].dot(dLdthetai);
     for(long n=0;n<N;n++)sum2[n]+=wbar[n][i].dot(dLdthetai);
    }
   }
   if(UpdateAct){
    for(long i=0;i<dimV;i++)for(long n=0;n<N;n++)vup[i]+=vbar[n][i]*sum2[n];
    for(long i=0;i<dimV;i++)vup[i]*=-transrate/N;
    UpdateAct(vup);
   }
  }
 }
 Reduce();
}

// Folds the per-block correction terms (vd, wd) into the running estimates
// (vbar, wbar):
//  1. ComputeVMetric() refreshes vmetric, which both EqualizeNorm overloads read.
//  2. Each pair (vbar[n], wbar[n]) and each pair (vd[n][i], wd[n][i]) is
//     rescaled by EqualizeNorm.
//  3. For every block i and sample n, a sign s = 2*alea(2)-1 (presumably +-1;
//     TODO confirm alea(2) yields 0 or 1) is drawn:
//       vbar[n][i] += s*vd[n][i];  wbar[n][i] += s*wd[n][i];  vd[n][i] = 0.
//     wd is left as-is since vd[n][i]==0 already cancels its contribution.
// Signs are drawn in (i outer, n inner) order.
template<long N, class Metric>
void OnlineOptim<N,Metric>::Reduce()
{
 ComputeVMetric();
 for(long sample=0;sample<N;sample++)EqualizeNorm(vbar[sample],wbar[sample]);
 for(long block=0;block<dimV;block++)
	 for(long sample=0;sample<N;sample++)
		 EqualizeNorm(block,vd[sample][block],wd[sample][block]);
 for(long block=0;block<dimV;block++){
	 for(long sample=0;sample<N;sample++){
		 const long double sign=2*alea(2)-1;
		 long double&vdCoef=vd[sample][block];
		 vbar[sample][block]+=sign*vdCoef;
		 wbar[sample][block]+=sign*wd[sample][block];
		 vdCoef=0;
	 }
 }
}

// Recomputes vmetric[i] for every transition block i.
// For each sample n, W2 is the inverse-metric squared norm of wbar[n] summed
// over all blocks, offset by 1e-10. Then, per block i,
//   R_i = |wd[n][i]|^2 vd[n][i]^2 + vbar[n][i]^2 W2
//         + 2 vbar[n][i] vd[n][i] <wd[n][i], wbar[n][i]>
// is the squared (inverse-metric) norm of
//   vbar[n][i]*wbar[n] + vd[n][i]*wd[n][i] (placed in block i).
// R_i is clamped to 1e-10 if negative (presumably only via rounding),
// otherwise offset by 1e-10, and summed over n.
// Finally vmetric[i] = N / sum + 1.
template<long N, class Metric>
void OnlineOptim<N,Metric>::ComputeVMetric()
{
 const long double eps=1e-10;
 for(long i=0;i<dimV;i++)vmetric[i]=0.;
 for(long n=0;n<N;n++){
	 long double wbarSq=eps;
	 for(long i=0;i<dimV;i++)
		 wbarSq+=transmetric[i].inverseMetricSquareNorm(wbar[n][i],1);
	 for(long i=0;i<dimV;i++){
		 const long double vdi=vd[n][i];
		 const long double vbari=vbar[n][i];
		 const long double wdSq=transmetric[i].inverseMetricSquareNorm(wd[n][i],1);
		 const long double cross=transmetric[i].inverseMetricDotProd(wd[n][i],wbar[n][i],1);
		 const long double rowSq=wdSq*vdi*vdi
		 +vbari*vbari*wbarSq
		 +2*vbari*vdi*cross;
		 vmetric[i]+=(rowSq<0)?eps:rowSq+eps;
	 }
 }
 for(long i=0;i<dimV;i++)vmetric[i]=N/vmetric[i]+1.;
}

// Rescales v by 1/s and every block of w by s, so that afterwards both
// squared norms equal sqrt(|v|^2 * |w|^2) (each computed with a 1e-10 offset).
// Here s = (|v|^2/|w|^2)^(1/4), where:
//  - |v|^2 = sum_i vmetric[i]*v[i]^2;
//  - |w|^2 = sum_i of transmetric[i]'s inverse-metric squared norm of w[i].
// The product v (x) w is unchanged.
template<long N, class Metric>
void OnlineOptim<N,Metric>::EqualizeNorm(vector<long double>&v, vector<Vector>&w)
{
 const long double eps=1e-10;
 long double vSq=eps;
 long double wSq=eps;
 for(long block=0;block<dimV;block++){
	 wSq+=transmetric[block].inverseMetricSquareNorm(w[block],1);
	 vSq+=vmetric[block]*v[block]*v[block];
 }
 const long double scale=sqrtl(sqrtl(vSq/wSq));
 for(long block=0;block<dimV;block++){
	 v[block]/=scale;
	 w[block]*=scale;
 }
}

// Single-block variant: rescales the scalar vi by 1/s and the Vector wi by s,
// balancing their norms while leaving the product vi*wi unchanged.
// Here s = (|vi|^2/|wi|^2)^(1/4), with:
//  - |vi|^2 = vmetric[i]*vi^2 + 1e-10;
//  - |wi|^2 = transmetric[i]'s inverse-metric squared norm of wi, + 1e-10.
// These are metric-weighted, not Euclidean, norms.
template<long N, class Metric>
void OnlineOptim<N,Metric>::EqualizeNorm(long i,long double&vi,Vector&wi)
{
 const long double eps=1e-10;
 const long double vSq=eps+vi*vi*vmetric[i];
 const long double wSq=eps+transmetric[i].inverseMetricSquareNorm(wi,1);
 const long double scale=sqrtl(sqrtl(vSq/wSq));
 vi/=scale;
 wi*=scale;
}
// Explicit instantiations: each sample count N in {1, 2, 5, 10, 100, 1000}
// is compiled for each of the three metric types defined in the project headers.
template class OnlineOptim<1,QDMetric>;
template class OnlineOptim<1,UnitWiseMetric>;
template class OnlineOptim<1,ConstantMetric>;

template class OnlineOptim<2,QDMetric>;
template class OnlineOptim<2,UnitWiseMetric>;
template class OnlineOptim<2,ConstantMetric>;

template class OnlineOptim<5,QDMetric>;
template class OnlineOptim<5,UnitWiseMetric>;
template class OnlineOptim<5,ConstantMetric>;

template class OnlineOptim<10,QDMetric>;
template class OnlineOptim<10,UnitWiseMetric>;
template class OnlineOptim<10,ConstantMetric>;

template class OnlineOptim<100,QDMetric>;
template class OnlineOptim<100,UnitWiseMetric>;
template class OnlineOptim<100,ConstantMetric>;

template class OnlineOptim<1000,QDMetric>;
template class OnlineOptim<1000,UnitWiseMetric>;
template class OnlineOptim<1000,ConstantMetric>;
