#include"../headers/qdmetric.h"
#include<iostream>

//QD inverse: replaces vec with (M+regul.Id)^{-1}.vec
//regulamount is added to every diagonal entry d[k] before solving.
//Each component k>=1 is computed from entries (0,0), (0,k)=q[k] and (k,k)
//only; component 0 is then recovered from the first row by substitution.
//A component whose 2x2 determinant is not strictly positive is set to 0.
void QDMetric::solveInPlace(Vector& vec,long double regulamount)
{
 const long double pivot=d[0]+regulamount; //regularized (0,0) entry
 const long double head=vec(0);            //first component before overwrite
 long double coupling=0;                   //sum over k>=1 of q[k]*solved vec(k)
 for(long k=1;k<rows();++k){
	 const long double offdiag=q[k];
	 const long double det=(d[k]+regulamount)*pivot-offdiag*offdiag;
	 vec(k)=det>0?(vec(k)*pivot-head*offdiag)/det:0;
	 coupling+=vec(k)*offdiag;
 }
 vec(0)=(head-coupling)/pivot;
}

//QD inverse: replaces vec with (M+diagregul)^{-1}.vec
//diagregul[k] is added to the diagonal entry d[k] before solving; it is
//indexed from 0 to rows()-1.
//Each component k>=1 is computed from entries (0,0), (0,k)=q[k] and (k,k)
//only; component 0 is then recovered from the first row by substitution.
//A component whose 2x2 determinant is not strictly positive is set to 0.
void QDMetric::solveInPlace(Vector& vec,const std::vector<long double>&diagregul)
{
 const long double pivot=d[0]+diagregul[0]; //regularized (0,0) entry
 const long double head=vec(0);             //first component before overwrite
 long double coupling=0;                    //sum over k>=1 of q[k]*solved vec(k)
 for(long k=1;k<rows();++k){
	 const long double offdiag=q[k];
	 const long double det=(d[k]+diagregul[k])*pivot-offdiag*offdiag;
	 vec(k)=det>0?(vec(k)*pivot-head*offdiag)/det:0;
	 coupling+=vec(k)*offdiag;
 }
 vec(0)=(head-coupling)/pivot;
}

//Squared norm of vec in the QD metric.
//QD reduction of a matrix M is diag(M)+v.v^T/M00-diag(v.v^T)/M00 where
//v is the first column of M (here v(0)=d[0] and v(k)=q[k] for k>=1).
//When d[0] is exactly 0 only the diagonal part is returned, so nothing is
//divided by zero.
long double QDMetric::squareNorm(const Vector&vec)
{
 //diagonal part: sum over k of d[k].vec(k)^2
 long double result=0;
 for(long k=0;k<rows();k++)result+=d[k]*vec(k)*vec(k);
 const long double pivot=d[0];
 if(pivot==0)return result;

 //v.vec and the diagonal of v.v^T applied to vec, gathered in a single pass
 long double projection=pivot*vec(0);
 long double diagCorrection=pivot*pivot*vec(0)*vec(0);
 for(long k=1;k<rows();k++){
	 projection+=q[k]*vec(k);
	 diagCorrection+=q[k]*q[k]*vec(k)*vec(k);
 }
 result+=projection*projection/pivot;
 result-=diagCorrection/pivot;
 return result;
}

//Dot product of vec1 and vec2 in the QD metric.
//QD reduction of a matrix M is diag(M)+v.v^T/M00-diag(v.v^T)/M00 where
//v is the first column of M (here v(0)=d[0] and v(k)=q[k] for k>=1).
//When d[0] is exactly 0 only the diagonal part is returned, so nothing is
//divided by zero.
long double QDMetric::dotProd(const Vector& vec1,const Vector& vec2)
{
 //diagonal part: sum over k of d[k].vec1(k).vec2(k)
 long double result=0;
 for(long k=0;k<rows();k++)result+=d[k]*vec1(k)*vec2(k);
 const long double pivot=d[0];
 if(pivot==0)return result;

 //v.vec1, v.vec2 and the diagonal of v.v^T term, gathered in a single pass
 long double projection1=pivot*vec1(0);
 long double projection2=pivot*vec2(0);
 long double diagCorrection=pivot*pivot*vec1(0)*vec2(0);
 for(long k=1;k<rows();k++){
	 projection1+=q[k]*vec1(k);
	 projection2+=q[k]*vec2(k);
	 diagCorrection+=q[k]*q[k]*vec1(k)*vec2(k);
 }
 result+=projection1*projection2/pivot;
 result-=diagCorrection/pivot;
 return result;
}

//Squared norm of vec in the inverse metric, i.e. inverseMetricDotProd(vec,vec),
//clamped so that the result is never negative.
long double QDMetric::inverseMetricSquareNorm(const Vector& vec,long double regulamount)
{
 const long double value=inverseMetricDotProd(vec,vec,regulamount);
 //anything that is not strictly positive (NaN included) is reported as 0
 if(value>0)return value;
 return 0;
}

//dot product in the inverse metric: vec1^T.(M+regul.Id)^{-1}.vec2
long double QDMetric::inverseMetricDotProd(const Vector&vec1,const Vector&vec2,long double regul)
{
 Vector vec3=vec1;
 solveInPlace(vec3,regul);
 long double dotprod=0;
 for(long i=0;i<rows();i++)dotprod+=vec3(i)*vec2(i);
 return dotprod;
}


//Read access to entry (i,j).
//Diagonal entries come from d, entries of the first row/column from q, and
//every other entry is reconstructed as q[i]*q[j]/d[0].
long double QDMetric::operator()(long i,long j)const
{
 if(i==j)return d[i];
 //one index is 0 and the other is positive: first row or first column
 const long smaller=i<j?i:j;
 const long larger=i<j?j:i;
 if(smaller==0&&larger>0)return q[larger];
 return q[i]*q[j]/d[0];
}

//Mutable access to entry (i,j).
//Diagonal entries refer to d, entries of the first row/column refer to q.
//Any other entry is not stored: q[0] is overwritten with q[i]*q[j]/d[0] and
//returned as a sink, so assigning through the result only changes q[0].
long double& QDMetric::operator()(long i,long j)
{
 if(i==j)return d[i];
 //one index is 0 and the other is positive: first row or first column
 const long smaller=i<j?i:j;
 const long larger=i<j?j:i;
 if(smaller==0&&larger>0)return q[larger];
 //sink value, used when trying to set illegal values...
 long double& sink=q[0];
 sink=q[i]*q[j]/d[0];
 return sink;
}

//Adds the stored entries of factor.vec.vec^T to the metric:
//d[k]+=factor.vec(k)^2 for every k, and q[k]+=factor.vec(0).vec(k) for k>=1.
//q[0] is left untouched.
void QDMetric::OuterProductUpdate(const Vector&vec,long double factor)
{
 const long count=static_cast<long>(d.size());
 for(long k=0;k<count;++k){
	 d[k]+=factor*vec(k)*vec(k);
	 if(k>0)q[k]+=factor*vec(0)*vec(k);
 }
}

//Scales the metric by lambda: every element of d and of q is multiplied
//in place. Returns *this to allow chaining.
QDMetric& QDMetric::operator*=(long double lambda)
{
 for(auto it=d.begin();it!=d.end();++it)*it*=lambda;
 for(auto it=q.begin();it!=q.end();++it)*it*=lambda;
 return *this;
}

//replaces vec with (M+regul.Id)^{-1}.vec
//The regularized matrix is built explicitly and solved through its LDLT
//factorization.
void UnitWiseMetric::solveInPlace(Vector& vec,long double regulamount)
{
 const Matrix regularized=metric+regulamount*Matrix::Identity(metric.rows(),metric.cols());
 regularized.ldlt().solveInPlace(vec);
}

//replaces vec with (M+diagregul)^{-1}.vec
//diagregul is viewed as a vector, turned into a dense diagonal matrix and
//added to the metric; the sum is solved through its LDLT factorization.
void UnitWiseMetric::solveInPlace(Vector& vec,const std::vector<long double>&diagregul)
{
 const Vector shift=Vector::Map(diagregul.data(),diagregul.size());
 const Matrix regularized=metric+shift.asDiagonal().toDenseMatrix();
 regularized.ldlt().solveInPlace(vec);
}

//Squared norm of vec in the metric: vec^T.M.vec
long double UnitWiseMetric::squareNorm(const Vector&vec)
{
 const long double quadraticForm=vec.adjoint()*metric*vec;
 return quadraticForm;
}

//Dot product in the metric: vec1^T.M.vec2
long double UnitWiseMetric::dotProd(const Vector& vec1,const Vector& vec2)
{
 const long double bilinearForm=vec1.adjoint()*metric*vec2;
 return bilinearForm;
}

//Squared norm of vec in the inverse metric, i.e. inverseMetricDotProd(vec,vec),
//clamped so that the result is never negative.
long double UnitWiseMetric::inverseMetricSquareNorm(const Vector& vec,long double regulamount)
{
 const long double value=inverseMetricDotProd(vec,vec,regulamount);
 //anything that is not strictly positive (NaN included) is reported as 0
 if(value>0)return value;
 return 0;
}

//dot product in the inverse metric: vec1^T.(M+regul.Id)^{-1}.vec2
long double UnitWiseMetric::inverseMetricDotProd(const Vector&vec1,const Vector&vec2,long double regul)
{
 Vector vec3=vec1;
 solveInPlace(vec3,regul);
 return vec2.adjoint()*vec3;
}


//Read access to entry (i,j) of the stored dense matrix.
long double UnitWiseMetric::operator()(long i,long j)const
{
 const long double entry=metric(i,j);
 return entry;
}

//Mutable access to entry (i,j) of the stored dense matrix.
long double& UnitWiseMetric::operator()(long i,long j)
{
 long double& entry=metric(i,j);
 return entry;
}

//Rank-one update of the stored matrix: M+=factor.vec.vec^T
void UnitWiseMetric::OuterProductUpdate(const Vector&vec,long double factor)
{
	//vec is scaled first, then multiplied by its adjoint
	metric+=(factor*vec)*vec.adjoint();
}

//Scales every entry of the stored matrix by lambda, in place.
//Returns *this to allow chaining.
UnitWiseMetric& UnitWiseMetric::operator*=(long double lambda)
{
 metric*=lambda;
 return *this;
}

//replaces vec with (lambda.Id+regul.Id)^{-1}.vec, i.e. divides each of the
//first `size` components by lambda+regulamount.
void ConstantMetric::solveInPlace(Vector& vec,long double regulamount)
{
 const long double divisor=lambda+regulamount; //same for every component
 for(long k=0;k<size;++k){
	 vec(k)/=divisor;
 }
}

//replaces vec with (lambda.Id+diagregul)^{-1}.vec, i.e. divides component k
//by lambda+diagregul[k] for each of the first `size` components.
void ConstantMetric::solveInPlace(Vector& vec,const std::vector<long double>&diagregul)
{
 for(long k=0;k<size;++k){
	 vec(k)/=lambda+diagregul[k];
 }
}

//Squared norm of vec in the metric lambda.Id: lambda times the squared
//Euclidean norm of vec.
long double ConstantMetric::squareNorm(const Vector&vec)
{
 const long double scaledNorm=lambda*vec.squaredNorm();
 return scaledNorm;
}

//Dot product in the metric lambda.Id: lambda.vec1^T.vec2
long double ConstantMetric::dotProd(const Vector& vec1,const Vector& vec2)
{
 const long double scaledProduct=lambda*vec1.adjoint()*vec2;
 return scaledProduct;
}

//Squared norm of vec in the inverse metric, i.e. inverseMetricDotProd(vec,vec),
//clamped so that the result is never negative.
long double ConstantMetric::inverseMetricSquareNorm(const Vector& vec,long double regulamount)
{
 const long double value=inverseMetricDotProd(vec,vec,regulamount);
 //anything that is not strictly positive (NaN included) is reported as 0
 if(value>0)return value;
 return 0;
}

//dot product in the inverse metric: vec1^T.(M+regul.Id)^{-1}.vec2
long double ConstantMetric::inverseMetricDotProd(const Vector&vec1,const Vector&vec2,long double regul)
{
 Vector vec3=vec1;
 solveInPlace(vec3,regul);
 return vec2.adjoint()*vec3;
}


//Read access to entry (i,j) of lambda.Id: lambda on the diagonal, 0 elsewhere.
long double ConstantMetric::operator()(long i,long j)const
{
 if(i!=j)return 0;
 return lambda;
}

//Mutable access to entry (i,j) of lambda.Id.
//Every diagonal entry refers to the single scalar lambda, so writing through
//the returned reference changes the whole diagonal.
//Off-diagonal entries are not stored: a reference to a scratch value that is
//reset to 0 on every call is returned. Reading it gives 0; a value written
//to it is dropped at the next off-diagonal call and never affects the metric.
long double& ConstantMetric::operator()(long i,long j)
{
 if(i==j)return lambda;
 //sink value, used when trying to set illegal values...
 //It has static storage duration (one instance per thread), so the returned
 //reference stays valid after this function returns; a function-local
 //automatic variable would leave the caller with a dangling reference.
 static thread_local long double offDiagonalSink=0;
 offDiagonalSink=0;
 return offDiagonalSink;
}

//Update hook of the constant metric: increments lambda by 1.
//Neither vec nor factor is consulted.
//NOTE(review): the increment does not depend on factor.vec.vec^T - confirm
//that this is the intended update rule.
void ConstantMetric::OuterProductUpdate(const Vector&vec,long double factor)
{
	(void)vec;
	(void)factor;
	lambda=lambda+1;
}

//Scales the metric by gamma: lambda is multiplied in place.
//Returns *this to allow chaining.
ConstantMetric& ConstantMetric::operator*=(long double gamma)
{
 lambda=lambda*gamma;
 return *this;
}




