#include<iostream>

#include"../headers/rand.h" //my homemade random number functions
#include"../headers/rankone.h"

using namespace std;

//Validates that at least one set of parameters has been registered, then
//(re)allocates every per-unit buffer, gradient vector and metric whose number
//of entries no longer matches transparams/outparams. Containers that already
//have the right number of entries are left untouched.
//Returns false (printing a hint when verbose) if there is nothing to optimize.
template<class Metric>
bool OnlineRk1Optim<Metric>::Check()
{
 const bool haveParams=(transparams.size()!=0)||(outparams.size()!=0);
 if(!haveParams){
	 if(verbose)
		 cout<<"OnlineRk1Optim: Please set the pointers to transition and output parameters. There are currently no parameters to optimize."<<endl;
	 return false;
 }

 //One unit per block of transition parameters.
 dimV=transparams.size();

 //Scalar per-unit buffers; vup is only allocated when UpdateAct is set.
 if(vbar.size()!=dimV)vbar=vector<long double>(dimV,0.);
 if(vd.size()!=dimV)vd=vector<long double>(dimV,0.);
 if(g.size()!=dimV)g=vector<long double>(dimV,0.);
 if(UpdateAct&&vup.size()!=dimV)vup=vector<long double>(dimV,0.);

 //Rank-one factors: one zero vector per unit, shaped like that unit's parameters.
 if(wbar.size()!=dimV){
	 wbar=vector<Vector>(dimV);
	 for(long unit=0;unit<dimV;++unit)
		 wbar[unit]=Vector::Zero(transparams[unit].size());
 }
 if(wd.size()!=dimV){
	 wd=vector<Vector>(dimV);
	 for(long unit=0;unit<dimV;++unit)
		 wd[unit]=Vector::Zero(transparams[unit].size());
 }

 //Output-parameter gradients and metrics, one per output block.
 const size_t nOut=outparams.size();
 if(gout.size()!=nOut){
	 gout=vector<Vector>(nOut);
	 for(size_t blk=0;blk<nOut;++blk)
		 gout[blk]=Vector::Zero(outparams[blk].size());
 }
 if(outmetric.size()!=nOut){
	 outmetric=vector<Metric>(nOut);
	 for(size_t blk=0;blk<nOut;++blk)
		 outmetric[blk]=Metric(outparams[blk].size());
 }

 //Transition-parameter metrics and the per-unit weights used on v.
 const size_t nTrans=transparams.size();
 if(transmetric.size()!=nTrans){
	 transmetric=vector<Metric>(nTrans);
	 for(size_t blk=0;blk<nTrans;++blk)
		 transmetric[blk]=Metric(transparams[blk].size());
 }
 if(vmetric.size()!=nTrans)
	 vmetric=vector<long double>(nTrans,0.);
 return true;
}

//Prepares the rank-one estimate for a new transition: applies dF/dV to vbar
//(wbar is left unchanged), stores dF/dTheta in wd and sets every vd entry to 1,
//so that the unit-local term vd.wd holds the fresh dF/dTheta.
//Does nothing if Check() reports that there are no parameters.
template<class Metric>
void OnlineRk1Optim<Metric>::PrepareTransition()
{
 if(Check()){
	 ApplyDFDV(vbar);
	 DFDTheta(wd);
	 for(size_t k=0;k<vd.size();++k)
		 vd[k]=1;
 }
}

//Performs one online optimization step, then always calls Reduce() unless
//Check() fails.
//When update is true:
// - if DLossDoutparams is set, the output parameters take a metric-
//   preconditioned gradient step of size outrate;
// - if DLossDV is set, dLoss/dtheta_i is assembled from the current rank-one
//   estimate of dV/dtheta, the transition parameters take a preconditioned
//   gradient step of size transrate, and, if UpdateAct is set, the resulting
//   first-order change of the activities V is passed to UpdateAct.
template<class Metric>
void OnlineRk1Optim<Metric>::MakeGradStep(bool update)
{
 if(update){
  if(!Check())return;
  long i,j,dimi;
  if(DLossDoutparams){
	  DLossDoutparams(gout);
	  const long nout=outparams.size();
	  for(i=0;i<nout;i++){
		  //Decay the accumulated metric, then add the outer product of the new gradient.
		  outmetric[i]*=1.-outmetric_gamma;
		  outmetric[i].OuterProductUpdate(gout[i]);
		  //Presumably replaces gout[i] with the regularized-metric-preconditioned
		  //gradient (see Metric::solveInPlace).
		  //Earlier regularization choices: 1, or outmetric[i].rows().
		  outmetric[i].solveInPlace(gout[i],outdiagregul[i]);
		  dimi=outparams[i].size();
		  for(j=0;j<dimi;j++)
			  *outparams[i][j]-=outrate*gout[i](j);
	  }
  }
  if(DLossDV){
	  DLossDV(g);
	  //Rank-one model of the activity Jacobian used throughout:
	  //  dV_k/dtheta_i ~= delta_{ki} vd[k] wd[k] + vbar[k] wbar[i]
	  //hence dLoss/dtheta_i = g[i] vd[i] wd[i] + (g.vbar) wbar[i].
	  long double gDotVbar=0;
	  for(i=0;i<dimV;i++)
		  gDotVbar+=g[i]*vbar[i];
	  long double sum2=0;//sum_i wbar[i].dLdthetai (preconditioned); used for the activity update
	  for(i=0;i<dimV;i++){
		  Vector dLdthetai=(g[i]*vd[i])*wd[i]+gDotVbar*wbar[i];
		  transmetric[i]*=1.-transmetric_gamma;
		  transmetric[i].OuterProductUpdate(dLdthetai);
		  //Earlier regularization choices: 1; (i+1)^2 (prior that only a few
		  //weights are significantly non-zero); dimV+transmetric[i].rows().
		  transmetric[i].solveInPlace(dLdthetai,transdiagregul[i]);
		  dimi=transparams[i].size();
		  for(j=0;j<dimi;j++)
			  *transparams[i][j]-=transrate*dLdthetai(j);
		  if(UpdateAct){
			  //Unit-local part of the change of V_i; the shared part is added below.
			  vup[i]=vd[i]*wd[i].dot(dLdthetai);
			  sum2+=wbar[i].dot(dLdthetai);
		  }
	  }
	  if(UpdateAct){
		  //Delta theta_i = -transrate*dLdthetai, so to first order
		  //Delta V_i = -transrate*(vd[i] wd[i].dLdthetai + vbar[i] sum_j wbar[j].dLdthetaj).
		  //(These loops previously tested i>dimV and never ran.)
		  for(i=0;i<dimV;i++)
			  vup[i]=-transrate*(vup[i]+vbar[i]*sum2);
		  UpdateAct(vup);
	  }
  }
 }

 Reduce();
}

//Reduces the rank-one estimate of dV/dtheta in place.
//vmetric is refreshed first. The shared factor (vbar,wbar) is then rescaled as a
//whole, and each unit-local factor (vd[i],wd[i]) is rescaled on its own so the
//two halves of every product carry comparable norms.
//Each unit-local term is then folded into the shared factor with an
//independent random sign s_i = 2*alea(2)-1. Presumably alea(2) returns 0 or 1,
//so s_i = +-1 and the cross terms cancel in expectation.
//vd[i] is zeroed afterwards. wd[i] is left as is, because a zero vd[i] removes
//its contribution and PrepareTransition() overwrites wd via DFDTheta.
template<class Metric>
void OnlineRk1Optim<Metric>::Reduce()
{
 ComputeVMetric();
 EqualizeNorm(vbar,wbar);
 //The per-unit EqualizeNorm only touches unit k's data, so normalizing and
 //folding can share a single pass.
 for(long k=0;k<dimV;k++){
	 EqualizeNorm(k,vd[k],wd[k]);
	 const long double sign=2*alea(2)-1;
	 vbar[k]+=sign*vd[k];
	 wbar[k]+=sign*wd[k];
	 vd[k]=0;
 }
}

//Fills vmetric[i], the per-unit weight that EqualizeNorm() applies to v entries.
//With USE_METRIC defined, the weight comes from the rank-one estimate
//  dV_i/dtheta = vd[i] wd[i] (on unit i's block) + vbar[i] wbar (on every block).
//Its squared norm is measured with the inverse transition metrics
//(regularization 1) and clamped to at least epsilon. Then
//  vmetric[i] = 1/(that squared norm) + 1.
//Without USE_METRIC, every vmetric[i] is 1.
template<class Metric>
void OnlineRk1Optim<Metric>::ComputeVMetric()
{
#ifdef USE_METRIC
 const long double epsilon=1e-10;
 long i;
 long double nwbar=epsilon,nwd,dp;
 //Squared inverse-metric norm of the whole shared factor wbar (shared by all units).
 for(i=0;i<dimV;i++)
	 nwbar+=transmetric[i].inverseMetricSquareNorm(wbar[i],1);//1=regul
 for(i=0;i<dimV;i++){
	 nwd=transmetric[i].inverseMetricSquareNorm(wd[i],1);
	 dp=transmetric[i].inverseMetricDotProd(wd[i],wbar[i],1);
	 //||vd wd + vbar wbar||^2 expanded: local term + shared term + cross term on block i.
	 vmetric[i]=nwd*vd[i]*vd[i]
		 +vbar[i]*vbar[i]*nwbar
		 +2*vbar[i]*vd[i]*dp
		 ;
	 //TODO: could do it faster, with only one inverseMetricDotProd,
	 //because
	 //vmetric[i]=transmetric[i].inverseMetricSquareNorm(vd[i]*wd[i]+vbar[i]*wbar[i])+(nwbar-transmetric[i].inverseMetricSquareNorm(wbar[i],1))*vbar[i]*vbar[i];
	 if(vmetric[i]<0)vmetric[i]=epsilon;else vmetric[i]+=epsilon;
	 vmetric[i]=1./vmetric[i]+1;//1=regul. Typically vmetric should grow like t or sqrt(t) or... and be at least the inverse covariance of the prior
 }
#else
 for(long i=0;i<dimV;i++)vmetric[i]=1.;
#endif
}

//Rescales the shared rank-one factor so that v and w carry the same norm,
//leaving every product v[k]*w[k] unchanged.
//The norm of v is weighted by vmetric. The norm of w sums the inverse-metric
//squared norms of all blocks (regularization 1). Both start at epsilon to avoid
//division by zero.
//Scaling v by 1/s and w by s with s = (|v|^2/|w|^2)^(1/4) makes both squared
//norms equal to sqrt(|v|^2 |w|^2).
template<class Metric>
void OnlineRk1Optim<Metric>::EqualizeNorm(vector<long double>&v,vector<Vector>&w)
{
 const long double epsilon=1e-10;
 long double normV=epsilon;
 long double normW=epsilon;
 //Note: ComputeVMetric() accumulates a similar sum over w under USE_METRIC.
 for(long k=0;k<dimV;k++){
	 normW+=transmetric[k].inverseMetricSquareNorm(w[k],1);
	 normV+=vmetric[k]*v[k]*v[k];
 }
 const long double scale=sqrtl(sqrtl(normV/normW));
 for(long k=0;k<dimV;k++){
	 v[k]/=scale;
	 w[k]*=scale;
 }
}

//Per-unit variant: rescales the unit-local pair (vi, wi) of unit i so that
//vi (weighted by vmetric[i]) and wi (inverse-metric norm, regularization 1)
//carry comparable norms, keeping vi*wi unchanged.
//Note: normW already includes epsilon, so the extra epsilon in the divisor is
//redundant but harmless. The trailing +epsilon keeps the scale strictly positive.
template<class Metric>
void OnlineRk1Optim<Metric>::EqualizeNorm(long i,long double&vi,Vector& wi)
{
 const long double epsilon=1e-10;
 const long double normV=epsilon+vi*vi*vmetric[i];
 const long double normW=epsilon+transmetric[i].inverseMetricSquareNorm(wi,1);
 const long double scale=sqrtl(sqrtl(normV/(normW+epsilon)))+epsilon;
 vi/=scale;
 wi*=scale;
}

//Explicitly instantiate the optimizer for the supported metric types, since the
//template member functions above are defined in this .cpp file.
template class OnlineRk1Optim<ConstantMetric>;
template class OnlineRk1Optim<UnitWiseMetric>;
template class OnlineRk1Optim<QDMetric>;
