#include<iostream>
#include<fstream>
#include<vector>
#include<cmath>
#include<cfloat>
#include"../headers/rand.h"
#include<map>
#include"../headers/indexes.h"

using namespace std;

// Verbosity switch. The preprocessor-based selection below is disabled, so
// the progress/summary output is always enabled.
//#ifdef VERBOSE
const bool verbose=true;
//#else
//const bool verbose=false;
//#endif

// Alphabet size = number of distinct symbols in the data; getAsciiData sets it
// to the size of the shared dictionary after each file is read.
long alph;
// Raw training data and validation data, byte by byte as read from the files.
vector<unsigned char> rawdatum,rawvdatum;
// Training and validation data transcribed as integers (indices into dict.word).
// main() later appends the validation symbols to the end of datum.
vector<long> datum,vdatum;
// Dictionary between raw bytes and their integer representation.
MyDict<unsigned char> dict;

// Bayesian prior on all suffix trees, defined by a branching probability:
// with 1/2, each node has prior probability 1/2 to be a leaf or to further
// subdivide into alph subnodes.
long double branching=.5;//alternatives tried: 1./alph, or 1/2

// Position currently being coded; set in main() and only read by a
// commented-out regularizer variant in nodepred().
long globpos;

// Dirichlet pseudo-count added to every symbol count at a node, which
// regularizes the estimate for symbols not yet seen in that context.
// A value of .5 would be the Jeffreys prior (= KT estimator). The value
// currently used is 1/alph, identical at every depth.
long double dirich(long depth)
{
 (void)depth;//unused by the current choice; kept for the depth-dependent variants below
 //variants tried: .5 ; 1.l/alph/(1.l+depth) ; 1.l/(alph+depth)
 const double pseudocount=1.0/alph;
 return pseudocount;
}

// Rissanen's universal code length of a positive integer n, in bits:
//   log2(2.86507) + log2(n) + log2(log2(n)) + ...
// where the iterated logarithms are summed for as long as they are positive.
// Returns -1 for n == 0, for which the code length is undefined.
long double RissanenCodelength(unsigned long n)
{
 if(n==0ul)return -1;
 const long double invln2=1./logl(2.);
 long double total=0;
 // each pass replaces term by log2(term) and accumulates it; stops once term <= 1
 for(long double term=n;term>1.l;total+=term)
	 term=logl(term)*invln2;
 const long double normalization=logl(2.86507)*invln2;
 return total+normalization;
}


// Smoothed probability of a symbol observed `count` times out of `totcount`
// observations at a node of the given depth:
//   (count + a) / (totcount + a*alph),   with a = dirich(depth).
long double nodepred(long count,long totcount,long depth)
{
 const long double a=dirich(depth);
 const long double numerator=count+a;
 const long double denominator=totcount+a*alph;
 return numerator/denominator;
 // Variants tried (kept for reference):
 //if(!totcount)return 1./alph;
 //long double reg=1./(totcount+1.);//identical to usual with dirich=1./alph
 //long double reg=1./(totcount+1.)/logl(totcount+1.)*logl(2.);
 //long double reg=pow(.5,RissanenCodelength(totcount+1));
 //long double reg=1./(globpos+1.)/logl(globpos+1.)*logl(2.);
 //if(reg>=1)reg=1;
 //return (count+0.l)/(totcount+0.l)*(1.-reg)+reg/alph;
}

// log2 prior weight for taking (depth1-depth0) branching steps, i.e.
// (depth1-depth0) * log2(branching).
long double logbranching(long depth0, long depth1)
{
 const long levels=depth1-depth0;
 return levels*logl(branching)/M_LN2;
 //variant: return logl((depth0+1.l)/(depth1+1.l))/M_LN2l;
}


class Trie;
// Root of the context tree; allocated in main().
Trie*theroot;
// index[pos] = tree node created (or reused) for position pos by insert().
// NOTE(review): a global named `index` can clash with POSIX index() from
// <strings.h> if that header gets pulled in; consider renaming.
vector<Trie*> index;

// Children of a node, keyed by the symbol preceding the node's context.
typedef map<long,Trie*> ChildList;
// Per-symbol occurrence counts at a node.
typedef map<long,long> CountList;

// Node of the context tree built over datum. A node stands for the context
// datum[pos-depth .. pos-1] (the `depth` symbols preceding position `pos`).
// Its children extend that context one symbol further back and are keyed by
// that earlier symbol (see attach()).
// Each node owns its children: the destructor deletes them recursively.
class Trie{
public:
	Trie* father;        // parent node, NULL for the root
	long pos;            // position in datum whose preceding symbols form this context
	long depth;          // context length
	bool isstop;         // set on the root by main(); not read elsewhere in the visible code
	ChildList child;     // owned children, keyed by preceding symbol
	CountList count;long totcount; // per-symbol counts at this node and their sum
	long double selfll,subtreell;  // log2-likelihoods: own predictions / subtree mixture
	Trie():father(NULL),pos(0),depth(0),isstop(false),totcount(0),selfll(0),subtreell(0){};
	~Trie();
private:
	// Copying is disabled: a memberwise copy would share the owned children,
	// which would then be deleted twice. These are declared but deliberately
	// not defined (pre-C++11 idiom, consistent with the rest of this file).
	Trie(const Trie&);
	Trie& operator=(const Trie&);
};

Trie::~Trie()
{
 ChildList::iterator it;
 for(it=child.begin();it!=child.end();++it)delete it->second;
}

// Number of context columns shown by the Print() debugging helpers.
const int maxprintlength=80;

// Makes a symbol code printable: every value below 32 (control codes, and
// any negative value) is replaced by '\\'; everything else is returned unchanged.
long trans(long c)
{
 return (c>=32)?c:static_cast<long>('\\');
}

void Print(long pos)
{
 long p;
 for(p=pos-maxprintlength;p!=pos;p++){
	 if(p<0)cout<<" ";else cout<<trans(datum[p]);
 }
 cout<<'X'<<endl;
}

// Debug helper: prints the context of `node` (the node->depth symbols before
// node->pos), right-aligned in a field of maxprintlength columns. It is
// followed by 'X' if the context reaches the start of the data
// (pos == depth), or '*' otherwise.
// Symbols are mapped from dictionary indices back to raw bytes via dict.word
// and printed as characters. Control characters are shown as '\' by trans().
void Print(Trie* node)
{
 long p;long pos=node->pos;
 for(p=pos-maxprintlength;p!=pos;p++){
	 if(p<pos-node->depth)cout<<" ";
	 else cout<<static_cast<char>(trans(dict.word[datum[p]]));
 }
 if(node->pos==node->depth)cout<<'X'<<endl;else cout<<"*"<<endl;
}

// Adds one occurrence of symbol c (count[c] and totcount) to `node` and to
// every ancestor up to the root. node must be non-NULL.
void updateCount(Trie*node,long c)
{
 Trie*cur=node;
 for(;;){
	 cur->count[c]+=1;
	 cur->totcount+=1;
	 cur=cur->father;
	 if(cur==NULL)break;
 }
}

// Length of the common backward match between the symbols preceding pos1 and
// those preceding pos2 in datum. It compares datum[pos1-1-k] with
// datum[pos2-1-k] for k = 0,1,... and stops at the first mismatch, at the
// start of the data, or after `max` symbols.
long matchLength(long pos1,long pos2,long max)
{
 long len=0;
 while(len<max){
	 const long a=pos1-1-len;
	 const long b=pos2-1-len;
	 if(a<0||b<0||datum[a]!=datum[b])break;
	 ++len;
 }
 return len;
}

// Walks down the context tree from `node`, following the symbols that precede
// position `pos`, to find the node whose context best matches that of pos.
// On return:
//   bestmatch   = the last node reached, or NULL if `node` is NULL;
//   matchlength = total matching length between bestmatch->pos and pos.
// matchlength is in/out: it is reset to 0 when the walk is at theroot,
// otherwise the caller's value is taken as the length already matched.
void findBestMatch(Trie*node,long pos,Trie*&bestmatch,long&matchlength)
{
 for(Trie*cur=node;;){
	 if(pos<0)cerr<<"WARNING pos<0"<<endl;
	 if(cur==NULL){bestmatch=NULL;matchlength=0;return;}
	 if(cur==theroot)matchlength=0;
	 // extend the match along the edge into cur, at most up to cur's depth
	 const long extra=matchLength(pos-matchlength,cur->pos-matchlength,cur->depth-matchlength);
	 matchlength+=extra;
	 // stop on a mismatch inside the edge, or once the data start is reached
	 if(matchlength!=cur->depth||matchlength==pos){bestmatch=cur;return;}
	 // otherwise descend into the child keyed by the next earlier symbol
	 ChildList::iterator next=cur->child.find(datum[pos-matchlength-1]);
	 if(next==cur->child.end()){bestmatch=cur;return;}
	 cur=next->second;
 }
}

/*void addCounts(CountList& result,const CountList& count1,const CountList& count2)
{
 if(&result!=&count1)result=count1;
 CountList::const_iterator it;
 for(it=count2.begin();it!=count2.end();++it)
	 result[it->first]+=it->second;
}*/

// Puts achild into afather's children list, keyed by the symbol that precedes
// afather's context at achild's position: datum[achild->pos - afather->depth - 1].
// Assumes pos and depth of both nodes are already correct.
// If that position lies before the start of the data, an error is reported and
// nothing is attached. The original still read datum[p] with p<0, which is
// undefined behavior.
void attach(Trie* afather,Trie* achild)
{
 long p=achild->pos-afather->depth-1;
 if(p<0){cerr<<"Wrong attachment"<<endl;return;}
 long c=datum[p];
 afather->child[c]=achild;
}

void recomputesubtreell(Trie*);

// Splits the edge above node1. A new internal node of depth `matchlength`
// takes node1's place under node1's father and gets node1 and node2 as its two
// children.
// The new node inherits node1's counts, totcount and selfll (node2 is assumed
// to have no counts yet). Then the subtree log-likelihoods of node1 (whose
// father changed) and of the new node are recomputed. Ancestors' subtreell are
// not refreshed here (presumably done later by updateLL along the path).
// Returns the new node, or NULL after an error message if node1 has no father.
Trie* join(Trie* node1,Trie* node2,long matchlength)
	//builds a new node having node1 and node2 as subnodes,
	//and node1->father as father
	//ASSUMES node1 IS BESTMATCH FOR node2 AND matchlength IS CORRECT
	//(so c1 != c2 below: the match stops exactly at the differing symbol)
	//ASSUMES counts of node2 are 0
{
 if(!node1->father){cerr<<"joining to a fatherless node, should not occur"<<endl;return NULL;}
 Trie* newtf=new Trie;
 newtf->pos=node1->pos;newtf->depth=matchlength;
 long c1=datum[node1->pos-matchlength-1],c2=datum[node2->pos-matchlength-1];
 newtf->child[c1]=node1;
 newtf->child[c2]=node2;
 newtf->father=node1->father;
 // same key as node1's entry (same pos, same father depth): replaces node1 in the father's map
 attach(newtf->father,newtf);
 node1->father=newtf;node2->father=newtf;
 newtf->count=node1->count;newtf->totcount=node1->totcount;
 newtf->selfll=node1->selfll;
 // node1 first: its prior weight depends on its (new) father's depth, and
 // newtf's subtreell sums its children's subtreell
 recomputesubtreell(node1);recomputesubtreell(newtf);
 return newtf;
}

// Inserts position `pos` into the context tree, given its best match
// (bestmatch, matchlength) as produced by findBestMatch. Where the match ends
// decides what happens:
//  - exactly at bestmatch and at the start of the data: bestmatch is reused;
//  - inside the edge above bestmatch: a new leaf is created and join() splits
//    that edge with a new internal node;
//  - at bestmatch while the data continues: a new leaf becomes bestmatch's child.
// In every case index[pos] is set to the returned node.
Trie* insert(long pos,Trie*bestmatch,long matchlength)
{
 const long nodedepth=bestmatch->depth;
 if(matchlength>nodedepth||matchlength>pos)cerr<<"WARNING bug matchlength"<<endl;
 const bool atnode=(matchlength==nodedepth);
 const bool atstart=(matchlength==pos);
 Trie*result=bestmatch;
 if(!(atnode&&atstart)){
	 result=new Trie;
	 result->pos=pos;
	 result->depth=pos;
	 if(matchlength<nodedepth&&matchlength<pos){
		 join(bestmatch,result,matchlength);
	 }else if(atnode&&matchlength<pos){
		 result->father=bestmatch;
		 attach(bestmatch,result);
	 }else if(atstart&&matchlength<nodedepth){
		 cerr<<"inserting prefix of already existing node, should not occur"<<endl;
	 }
 }
 index[pos]=result;
 return result;
}

// Returns x unchanged unless it is a control character (< 32) other than
// newline, in which case '\\' is returned so the output stays readable.
unsigned char screenfriendlychar(unsigned char x)
{
 const bool iscontrol=(x<32);
 if(!iscontrol||x=='\n')return x;
 return '\\';
}

// Reads the whole stream into therawdatum, one entry per char (get() does not
// skip whitespace). It then translates the bytes into thedatum through the
// shared dictionary `dict` (presumably adding unseen symbols to it) and sets
// alph to the dictionary size.
// In verbose mode it prints the number of characters read, the first 78 of
// them, and the alphabet size.
void getAsciiData(istream& is,vector<unsigned char>& therawdatum,vector<long>&thedatum)
{
 for(char ch;is.get(ch);)therawdatum.push_back(ch);
 dict.AddFromAndTranslate(therawdatum,thedatum);
 alph=dict.word.size();
 if(!verbose)return;
 cout<<"Read "<<therawdatum.size()<<" characters, beginning with:"<<endl;
 const unsigned long shown=78;
 for(unsigned long t=0;t<shown&&t<therawdatum.size();++t)
	 cout<<screenfriendlychar(dict.word[thedatum[t]]);
 cout<<endl<<endl;
 cout<<"for a total of "<<alph<<" distinct characters"<<endl<<endl;
}

// Number of nodes in the subtree rooted at `node`, node itself included.
// Uses an explicit work list instead of recursion.
long countNodes(Trie*node)
{
 std::vector<Trie*> pending(1,node);
 long total=0;
 while(!pending.empty()){
	 Trie*cur=pending.back();
	 pending.pop_back();
	 ++total;
	 for(ChildList::iterator it=cur->child.begin();it!=cur->child.end();++it)
		 pending.push_back(it->second);
 }
 return total;
}

// Probability that node's own counts assign to symbol c: nodepred() applied to
// the stored count of c (0 if c was never seen here), the node's total count
// and its depth. Does not insert into node->count.
long double selfpred(Trie*node,long c)
{
 CountList::const_iterator found=node->count.find(c);
 const long seen=(found==node->count.end())?0:found->second;
 return nodepred(seen,node->totcount,node->depth);
}

//long double updateselfLL(Trie*node,long c)
//{
// do{
//	 node->selfll+=logl(selfpred(node,c))/M_LN2l;
//	 node=node->father;
// }while(node);
//}

// Base-2 logarithm of x, computed as ln(x)/M_LN2 (M_LN2 is a double constant).
inline long double clog2(long double x)
{
 const long double ln2=M_LN2;
 return logl(x)/ln2;
}

// 2^l, except that exponents at or below log2(LDBL_EPSILON)+1 (and NaN) are
// flushed to 0.
inline long double pow2(long double l)
{
 const long double cutoff=logl(LDBL_EPSILON)/logl(2)+1;
 if(!(l>cutoff))return 0;
 return powl(2.,l);
}

// log2(2^l1 + 2^l2), evaluated as big + log2(1 + 2^(small-big)) so the
// exponentials never overflow.
long double addlogs(long double l1,long double l2)
{
 const bool firstlarger=(l1>l2);
 const long double big=firstlarger?l1:l2;
 const long double small=firstlarger?l2:l1;
 return big+clog2(1.l+pow2(small-big));
}

void recomputesubtreell(Trie*node)
{
 ChildList::iterator it;
 long double subll=0;
 if(node->depth==node->pos)subll+=clog2(1.l/alph);//internal node was a leave, was once used to do a prediction
 for(it=node->child.begin();it!=node->child.end();++it)
	 subll+=it->second->subtreell;
// long reldepth;
// if(node->father)reldepth=node->depth-node->father->depth;
// else reldepth=1;
// long double logsubcoeff=reldepth*clog2(branching);
 long double logsubcoeff=logbranching((node->father?node->father->depth+1:0),node->depth+1);
 long double logselfcoeff=clog2(1.l-pow2(logsubcoeff));
 node->subtreell=addlogs(logsubcoeff+subll,logselfcoeff+node->selfll);
 return;
}

long double computeprob(Trie*node,long c,long pos)
{
 long double lselfp=clog2(selfpred(node,c));
 if(pos-node->depth-1<0){
	 return lselfp;
 }
 long c2=datum[pos-node->depth-1];
 ChildList::iterator it=node->child.find(c2);
 if(it==node->child.end()){
	 cerr<<"computeprob: insufficient context, should not happen"<<endl;
	 return lselfp;
 }
 long double lsubp=computeprob(it->second,c,pos);
 long double subll=0;
 if(node->depth==node->pos)subll+=clog2(1.l/alph);//internal node was a leave, was once used to do a prediction
 for(it=node->child.begin();it!=node->child.end();++it)
	 subll+=it->second->subtreell;
 //long reldepth;
 //if(node->father)reldepth=node->depth-node->father->depth;
 //else reldepth=1;
 //long double logsubcoeff=reldepth*clog2(branching);
 long double logsubcoeff=logbranching((node->father?node->father->depth+1:0),node->depth+1);
 long double logselfcoeff=clog2(1.l-pow2(logsubcoeff));
 //cout<<"Comparing sub, self: "<<logsubcoeff+subll<<" "<<logselfcoeff+node->selfll<<endl;
 long double err=node->subtreell-addlogs(logsubcoeff+subll,logselfcoeff+node->selfll);
 if(fabsl(err)>0.000001){
	 cout<<"subtreell error: "<<err<<endl;
	 Print(node);
	 cout<<node->selfll<<endl;
 }
 return addlogs(logsubcoeff+subll+lsubp,logselfcoeff+node->selfll+lselfp)-
	 addlogs(logsubcoeff+subll,logselfcoeff+node->selfll);
}

// Records symbol c at position pos along the context path below `node`.
// Every visited node gets log2(selfpred) added to selfll; selfpred uses the
// counts from before this symbol, since main() calls updateCount afterwards.
// subtreell is then refreshed bottom-up.
// Returns the resulting change of node->subtreell. On the root this is the
// log2-probability assigned to c, which main() checks against computeprob.
long double updateLL(Trie*node,long c,long pos)
{
 const long double gain=clog2(selfpred(node,c));
 node->selfll+=gain;
 const long ctx=pos-node->depth-1;
 Trie*below=NULL;
 if(ctx>=0){
	 ChildList::iterator found=node->child.find(datum[ctx]);
	 if(found!=node->child.end())below=found->second;
	 else cerr<<"updateLL: insufficient context, should not happen"<<endl;
 }
 if(below==NULL){
	 // deepest node on the path: its subtree likelihood grows by its own prediction
	 node->subtreell+=gain;
	 return gain;
 }
 updateLL(below,c,pos);
 const long double before=node->subtreell;
 recomputesubtreell(node);
 return node->subtreell-before;
}

// Entry point. Usage: <program> trainingfile [validationfile]
// Reads the training file, and optionally a validation file whose symbols are
// appended after the training data. For every position in order it then:
//  - inserts the position into the context tree;
//  - computes the model's log2-probability of the actual symbol;
//  - updates the model.
// Finally it reports total bits, bits per character and compression ratio.
// Log-probabilities on the validation part are also summed separately; the
// model keeps learning on it.
// Returns 0 on success, 1 on bad usage or an unreadable input file.
int main(int argc,char**argv)
{
 if(argc<2){
	 cerr<<"Usage: "<<(argc>0?argv[0]:"program")<<" trainingfile [validationfile]"<<endl;
	 return 1;
 }
 randomize();
 const bool valid=(argc>=3);
 if(valid){
	 ifstream f2(argv[2]);
	 if(!f2){cerr<<"Cannot open validation file "<<argv[2]<<endl;return 1;}
	 getAsciiData(f2,rawvdatum,vdatum);f2.close();
 }
 ifstream f(argv[1]);
 if(!f){cerr<<"Cannot open training file "<<argv[1]<<endl;return 1;}
 getAsciiData(f,rawdatum,datum);f.close();
 unsigned long traindatumsize=datum.size();
 datum.insert(datum.end(),vdatum.begin(),vdatum.end());//simpler this way

 index=vector<Trie*>(datum.size());
 theroot=new Trie;
 theroot->isstop=true;
 Trie*bm=NULL,*newt=NULL;long ml=0;
 unsigned long pos;unsigned long prevbmpos=-1;//wraps to the max value: forces a full search at pos 0
 //progress is printed every progressstep positions (only used when datum is non-empty)
 const unsigned long progressstep=(unsigned long)(10*sqrt(datum.size()));
 cout<<"Doing Burrows-Wheeler transform:"<<endl;
 long double ll=0,valid_ll=0,p;
 for(pos=0;pos<datum.size();++pos){
	 if(!(pos%progressstep))cout<<pos/(datum.size()+0.)*100.<<"%    \r"<<flush;
	 if(prevbmpos+1>=pos||datum[pos-1]!=datum[prevbmpos]){
		 findBestMatch(theroot,pos,bm,ml);
	 }else{
		 //SHORTCUT to compute bestmatches (cross-checked against a full search)
		 ml++;
		 bm=index[prevbmpos+1];
		 while(bm->father!=NULL&&bm->father->depth>=ml)bm=bm->father;
		 Trie*bm2;long ml2;
		 findBestMatch(theroot,pos,bm2,ml2);
		 if(bm2!=bm||ml2!=ml)cerr<<"SHORTCUT BUG"<<endl;
	 }
	 newt=insert(pos,bm,ml);//NOTE:invalidates some values of selfll
	 globpos=pos;
	 p=computeprob(theroot,datum[pos],pos);
	 if(pos>=traindatumsize)valid_ll+=p;
	 long double check=ll+p;
	 ll+=updateLL(theroot,datum[pos],pos);
	 if(fabsl(1-check/ll)>1000*LDBL_EPSILON)cout<<"NUMERICAL DISCREPANCY:"<<ll-check<<endl;
	 updateCount(newt,datum[pos]);
	 prevbmpos=bm->pos;
 }
 cout<<"100%     "<<endl<<endl;
 cout<<-ll<<" bits, "<<-ll/(datum.size()+0.)<<" bpc, compression "<<-100.*ll/(8*datum.size())<<" %."<<endl;
 if(valid){
	 cout<<"On validation data:"<<endl;
	 cout<<-valid_ll<<" bits, "<<-valid_ll/(datum.size()-traindatumsize+0.)<<" bpc, compression "<<-100.*valid_ll/(8*(datum.size()-traindatumsize))<<" %."<<endl;
 }

 return 0;
}
