
using namespace std;

#include"../headers/network.h"

void NN_SoftMaxOutput::fillProbFromEnergy(const vector<long double>& someenergy,vector<long double>& prob)
{
 // Softmax over the first alph entries:
 //   prob[y] = expl(someenergy[y]) / sum_k expl(someenergy[k]),  y in [0, alph).
 // The largest energy is subtracted before exponentiating, so the biggest term is
 // expl(0)=1. This keeps expl() from overflowing and makes the normaliser >= 1.
 // NOTE(review): both vectors are assumed to hold at least alph entries; not checked here.
 if(alph<=0)return; // empty alphabet: nothing to normalise (and someenergy[0] may not exist)
 long y;
 long double peak=someenergy[0];
 for(y=1;y<alph;y++)if(someenergy[y]>peak)peak=someenergy[y];
 // Exponentiate and accumulate the normaliser in one pass (same summation order as before).
 long double z=0;
 for(y=0;y<alph;y++){
  prob[y]=expl(someenergy[y]-peak);
  z+=prob[y];
 }
 const long double invz=1.l/z;
 for(y=0;y<alph;y++)prob[y]*=invz;
}

void NN_SoftMaxOutput::computePred()
{
 // Affine score per output symbol, then softmax:
 //   energy[sym] = wbias[sym] + sum_k act[k]*w[sym][k]
 //   pred        = softmax(energy)
 // NOTE(review): assumes act already holds the n input activations for the current step.
 for(long sym=0;sym<alph;sym++){
  long double score=wbias[sym];
  for(int k=0;k<n;k++)
   score+=act[k]*w[sym][k];
  energy[sym]=score;
 }
 fillProbFromEnergy(energy,pred);
}

long double NN_SoftMaxOutput::computePred(long x)
{
 // Softmax normalises over the whole alphabet, so one symbol's probability
 // requires computing every symbol's probability first.
 computePred();
 return pred[x];
}

void NN_SoftMaxOutput::Setup()
{
 // Base-class setup runs first; n and alph are read afterwards.
 // NOTE(review): presumably NeuralNetwork::Setup() is what fixes them -- confirm.
 NeuralNetwork::Setup();
 // All output-layer parameters start at zero: an alph x n weight matrix and an alph-long bias.
 const vector<long double> zeroRow(n,0.0L);
 w=vector<vector<long double> >(alph,zeroRow);
 wbias=vector<long double>(alph,0.0L);
 // Scratch buffer used by computePred() for the per-symbol scores.
 energy=vector<long double>(alph,0.0L);
 // Initial snapshot, so restoreWParams() has something valid to restore.
 prev_w=w;
 prev_wbias=wbias;
}

void NN_SoftMaxOutput::saveWParams()
{
 // Snapshot the current output weights and biases (copied by value).
 prev_w=w;
 prev_wbias=wbias;
}

void NN_SoftMaxOutput::restoreWParams()
{
 // Roll the output weights and biases back to the last saveWParams() snapshot.
 // NOTE(review): setOutparams() hands out raw pointers into w/wbias. Whether
 // vector copy-assignment keeps the same storage is implementation-dependent,
 // so confirm those pointers stay valid across this call.
 w=prev_w;
 wbias=prev_wbias;
}

void NN_SoftMaxOutput::DLossDoutparams(vector<Vector>&grad){
	 long y,j;
	 for(y=0;y<alph;y++){
		 grad[y][0]=pred[y];//wbias
		 for(j=0;j<n;j++)
			 grad[y][j+1]=pred[y]*act[j];
	 }
	 y=last_symbol_read;
	 grad[y][0]-=1;
	 for(j=0;j<n;j++)
		 grad[y][j+1]-=act[j];
}

void NN_SoftMaxOutput::DLossDV(vector<long double>&grad){
	 // For each input unit j:
	 //   grad[j] = (sum_sym pred[sym]*w[sym][j] - w[target][j]) * deractfunc(act[j])
	 // where target = last_symbol_read. The bracket is the derivative of
	 // -log pred[target] w.r.t. act[j] given computePred()'s affine+softmax model;
	 // deractfunc chains it back through the activation.
	 // NOTE(review): deractfunc is evaluated at act[j]; whether act holds pre- or
	 // post-activation values is not visible here. pred is assumed current.
	 const long target=last_symbol_read; // loop-invariant; previously re-read per j and shadowed by the inner index
	 for(long j=0;j<n;j++){
		 long double g=-w[target][j];
		 for(long sym=0;sym<alph;sym++)
			 g+=pred[sym]*w[sym][j];
		 grad[j]=g*deractfunc(act[j]);
	 }
}

void NN_SoftMaxOutput::setOutparams(vector<vector<long double*>>&outparams,vector<vector<long double>>&outdiagregul){
 // Builds, for every output symbol x, a flat list of pointers to its parameters plus a
 // matching list of diagonal regularisation coefficients:
 //   slot 0   -> &wbias[x],  1./inftheorysequence(0)
 //   slot i+1 -> &w[x][i],   1./inftheorysequence(i+1.)
 // NOTE(review): the pointers alias w/wbias storage and dangle if those vectors are
 // reallocated (e.g. by Setup()); callers must refresh them afterwards.
 outparams.assign(alph,vector<long double*>());
 outdiagregul.assign(alph,vector<long double>());
 for(long x=0;x<alph;x++){
	 vector<long double*>& slots=outparams[x];
	 vector<long double>& regs=outdiagregul[x];
	 slots.push_back(&wbias[x]);
	 regs.push_back(1./inftheorysequence(0));
	 for(long i=0;i<n;i++){
		 slots.push_back(&w[x][i]);
		 regs.push_back(1./inftheorysequence(i+1.));
	 }
 }
}
