#ifndef RANKONE_H
#define RANKONE_H
#include<vector>
#include<functional>
#include<eigen3/Eigen/Dense>

using std::vector;
using std::function;

#include"qdmetric.h"

template<class Metric=ConstantMetric>
struct OnlineRk1Optim{
	//Model is
	// V(t+1)=F_theta(V(t),x(t))
	// The loss at time t is a function of V(t) and the value y(t) to be
	// predicted.

	//This class provides an online learning algorithm for theta. Do not
	//change the parameter set or the state dimension while running, or
	//strange things may happen.

	//Must be set externally at startup:
	// - transparams, outparams (pointers to the parameter values),
	// - DLossDV, DLossDoutparams, ApplyDFDV, DFDTheta (derivative callbacks).
	//Optional callbacks that may be left unset: UpdateAct, D2LossDoutput2.
	//Call PrepareTransition() right before each transition of the
	//dynamical system.
	//Call MakeGradStep() typically right after each new observation becomes
	//available. It modifies the parameter values pointed to by transparams
	//and outparams (presumably only when its `update` argument is true --
	//TODO confirm against the definition).

	vector<vector<long double*> > transparams;//Please set this so that transparams[i][j] points to the value of the j-th transition parameter of unit i
	vector<vector<long double*> > outparams;//Please set this so that outparams[i][j] points to the value of the j-th output parameter of output component i

	void PrepareTransition();//Please externally call this function right before each transition of the main system
	void MakeGradStep(bool update);//Make the actual gradient step. Typically call this right when each new observation is available
	long double outrate=1.,transrate=1.;//Learning rates for the output and transition parameters
	long double outmetric_gamma=0.,transmetric_gamma=0.;//Decay coefficients for the metrics
	vector<vector<long double> > outdiagregul,transdiagregul; //Regularization coefficients (initial inverse covariance) for each parameter. Ideally, for each unit their inverses should have a finite sum


	function <void(vector<long double>&) > DLossDV;//Set this function so that calling DLossDV(g) fills the vector g with g[i]=d loss/d V[i], the derivative of the loss at the current time with respect to the current state. g will be correctly dimensioned but not set to 0
	function <void(vector<Vector>&) > DLossDoutparams;//Set this function so that calling DLossDoutparams(g) fills the vector g with g[i][j]=d loss/d outparams[i][j]. g will be correctly dimensioned but not set to 0
	function<void(vector<long double>&) > ApplyDFDV;//Set this function so that ApplyDFDV(deltav) computes dF/dV at the current state of the system, applies it to deltav, and places the result in deltav again
	function<void(vector<Vector>&)> DFDTheta;//Set this function so that DFDTheta(w) fills w with w[i](j)=dF[i]/dTheta[ij] computed at the current state of the system, where Theta[ij] is the j-th param of unit i. w will be correctly dimensioned beforehand
	function<void(const vector<long double>& dv)> UpdateAct;//May be set to a function that updates the internal state of the network as v <- v+dv. If set, each time the parameters are changed, the current state will be changed as well, according to an estimate of how using the new parameters would have changed the trajectory
	function<void(vector<long double>& h)> D2LossDoutput2; //May be set so that h[i] contains the Fisher information (a scalar per output unit) or minus the Hessian of the current loss wrt the i-th output unit (e.g. p_i(1-p_i) for softmax output). This is used as second-order information. May be left unset for Euclidean gradient descent. h will be correctly dimensioned beforehand

	bool verbose=false;//Set this to get messages on cout
	OnlineRk1Optim()=default;


	//All the rest is internal stuff; access is provided for convenience.

	long dimV=0;//State dimension: one component of V per unit, so this should always equal transparams.size(). Initialized to 0 so it is never read indeterminate before being set
	vector<long double> vbar,vd,g,vup;
	vector<Vector > wbar,wd;
	//vbar,wbar store the rank-one part; vd,wd store the (block-)diagonal part.
	//NOTE(review): the roles of g and vup are not visible from this header -- TODO document
	vector<Vector> gout;//Will store the gradient wrt the output parameters

	void Reduce();//Reduces the current gradient information to something rank-one

	bool Check();//Checks that all vectors, etc., are correctly sized

	vector<Metric> transmetric,outmetric;
	vector<long double> vmetric;

	void ComputeVMetric();
	void EqualizeNorm(vector<long double>&,vector<Vector>&);
	void EqualizeNorm(long,long double&,Vector&);//EqualizeNorm(i,vi,wi)
};

#endif
