#ifndef BPTT_H

#define BPTT_H

#include<vector>
#include<functional>
#include<eigen3/Eigen/Dense>
#include"qdmetric.h"

using std::vector;
using std::function;

template<class Metric>
struct BPTTOptimBase{
	//BPTTOptimBase is the common abstract base of the BPTTOptim<truncate,update,Metric>
	//specializations, which implement Backpropagation through time for a model
	//V(t+1)=F(V(t),theta).
	//In BPTTOptim, "truncate" defines how many steps of Backprop should be used
	//(truncate==-1 selects a dedicated specialization), and "update" whether the
	//parameters are updated at each time step or at each "truncate" time steps.
	//The Metric template parameter is the type stored in transmetric/outmetric.
	//BPTTOptimBase provides the core features (user-supplied callbacks, learning
	//rates, gradient buffers) shared by the different specializations.
	//
	vector<vector<long double*>> transparams;//Please set this so that transparams[i][j] points to the value of the j-th transition parameter of unit i
	vector<vector<long double*>> outparams;//Please set this so that outparams[i][j] points to the value of the j-th output parameter of output component i
	vector<vector<long double>> transdiagregul;//Regularization on transition parameters.
	vector<vector<long double>> outdiagregul;//Regularization on output parameters.
	void PrepareTransition();//Please externally call this function right before each transition of the main system
	virtual void MakeGradStep(bool update)=0;//Make the actual gradient step. Typically call this right when each new observation is available
	long double outrate=1.,transrate=1.;//Learning rates for the output and transition parameters
	long double outmetric_gamma=0.,transmetric_gamma=0.;//decay coefficients for the metrics
	function<function<void()>()> SaveState;//set this function so that calling SaveState() saves the current state of the network and returns a function which, when called, restores the state to the saved one. Examples of such functions are given in rnn.cpp and lrnn.cpp
	function<void(vector<long double>&)> DLossDV;//set this function so that calling DLossDV(g) fills the vector g with g[i]=d loss/d V[i], the derivative of the loss at the current time with respect to the current state. g will be correctly dimensioned but not set to 0
	function<void(vector<Vector>&)> DLossDoutparams;//set this function so that calling DLossDoutparams(g) fills the vector g with g[i][j]=d loss/d outparams[i][j]. g will be correctly dimensioned but not set to 0
	function<void(vector<long double>&)> MultiplyRightDFDV;//set this function so that MultiplyRightDFDV(deltav) computes deltav*dF/dV, with dF/dV evaluated at the current state of the units (deltav is the only argument, so the result presumably overwrites it -- confirm against the implementations)
	function<void(vector<Vector>&)> DFDTheta;//set this function so that DFDTheta(w) fills w with w[i](j)=dF[i]/dTheta[ij] computed at the current state of the system, where Theta[ij] is the j-th param of unit i. w will be correctly dimensioned beforehand.

	BPTTOptimBase()=default;
	//Polymorphic base (MakeGradStep is pure virtual): the destructor must be
	//virtual so deleting a derived optimizer through a base pointer is well defined.
	virtual ~BPTTOptimBase()=default;
	//Declaring the destructor suppresses the implicit move operations; restore
	//them (and the copies) explicitly so copying/moving behaves as before.
	BPTTOptimBase(const BPTTOptimBase&)=default;
	BPTTOptimBase(BPTTOptimBase&&)=default;
	BPTTOptimBase& operator=(const BPTTOptimBase&)=default;
	BPTTOptimBase& operator=(BPTTOptimBase&&)=default;

	long dimV=0;//dimension of the state V; presumably should always be the same as transparams.size() (one entry per unit)
	vector<long double> g;//will store d loss / d V
	vector<Vector> gout;//will store the gradient wrt output parameters
	vector<Vector> wd;//wd[i](j) will be d F[i]/dtheta[i][j]
	vector<Vector> deltatransparams;//Used to accumulate the gradient wrt the transparams
	vector<Vector> deltaoutparams;//Used to accumulate the gradient wrt the outparams

	bool Check();//checks that all vectors, etc., are correctly sized
	vector<Metric> transmetric,outmetric;//Stores the metric used for transition and output parameters.
};

template <long truncate, bool update, class Metric = ConstantMetric>
struct BPTTOptim {
	//Primary template: deliberately empty. For any "truncate", both values of
	//"update" are handled by partial specializations of this template, so this
	//definition itself is never instantiated. It exists to declare the
	//template's parameters and the default Metric (ConstantMetric).
};

template<long truncate,class Metric>
struct BPTTOptim<truncate,true,Metric>:public BPTTOptimBase<Metric>{
	//Truncated BPTT specialization for update==true.
	//truncate==-1 is handled by its own specialization; any other non-positive
	//value would make buff(truncate) convert a negative long to a huge size_t
	//(throwing at run time), so reject it at compile time instead.
	static_assert(truncate>0,"BPTTOptim: truncate must be positive (use -1 for the dedicated specialization)");

	void MakeGradStep(bool update) override;

	//One closure slot per truncation step (presumably holding the restore
	//functions returned by SaveState -- confirm in MakeGradStep). Every slot
	//starts as a no-op so it is always safe to call.
	vector<function<void()>> buff;
	BPTTOptim():buff(truncate,function<void()>([](){})),epoch(0){}
	long epoch;//step counter; starts at 0 in this specialization

};

template<long truncate,class Metric>
struct BPTTOptim<truncate,false,Metric>:public BPTTOptimBase<Metric>{
	//Truncated BPTT specialization for update==false.
	//truncate==-1 is handled by its own specialization; any other non-positive
	//value would make buff(truncate) convert a negative long to a huge size_t
	//(throwing at run time), so reject it at compile time instead.
	static_assert(truncate>0,"BPTTOptim: truncate must be positive (use -1 for the dedicated specialization)");

	void MakeGradStep(bool update) override;

	//One closure slot per truncation step (presumably holding the restore
	//functions returned by SaveState -- confirm in MakeGradStep). Every slot
	//starts as a no-op so it is always safe to call.
	vector<function<void()>> buff;
	BPTTOptim():buff(truncate,function<void()>([](){})),epoch(1){}
	long epoch;//step counter; starts at 1 in this specialization (unlike the update==true one, which starts at 0)
};

template<class Metric>
struct BPTTOptim<-1,true,Metric>:public BPTTOptimBase<Metric>{
	//Specialization selected when truncate==-1 and update==true
	//(presumably backprop without a fixed truncation window -- confirm in MakeGradStep).
	void MakeGradStep(bool update) override;

	vector<function<void()>> buff;//starts empty; no fixed number of slots in this specialization
	BPTTOptim()=default;
};

template<class Metric>
struct BPTTOptim<-1,false,Metric>:public BPTTOptimBase<Metric>{
	//Specialization selected when truncate==-1 and update==false
	//(presumably backprop without a fixed truncation window -- confirm in MakeGradStep).
	void MakeGradStep(bool update) override;

	vector<function<void()>> buff;//starts empty; no fixed number of slots in this specialization
	BPTTOptim()=default;
};
#endif /*end of include guard: BPTT_H */
