#include "approximate_repeats_allison.h"

using namespace std;

namespace papi
{

    ApproximateRepeats::ApproximateRepeats( char const* name):
    module_id(name)
    {
        // Column labels for the per-machine transition probabilities.
        string_params[START]=START_STRING;
        string_params[END]=END_STRING;
        string_params[CONTINUE]=CONTINUE_STRING;
        string_params[MATCH]=MATCH_STRING;
        string_params[CHANGE]=CHANGE_STRING;
        string_params[INDEL]=INDEL_STRING;

        // Give every pointer the destructor inspects a defined value. Without this,
        // destroying an object whose init() never ran would free garbage pointers.
        // The same applies to prob_markov when init() skipped it (it is only created
        // for markov_order > 0).
        prob_char = NULL;
        word_pos_map = NULL;
        mirrored_last_k = NULL;
        complement = NULL;
        inverted_last_k = NULL;
        prob_markov = NULL;

        // markov_order is only read from the settings in init(). Use a defined value
        // here instead of sizing tail from an uninitialized member.
        markov_order = 0;
        tail = new unsigned char[2*markov_order];
    }

    ApproximateRepeats::~ApproximateRepeats()
    {
        // Releases everything allocated in the constructor, init() and preprocessing().
        // delete / delete[] on a null pointer is a no-op, so no guards are needed.
        delete [] tail;
        delete [] prob_char;
        delete word_pos_map;
        delete mirrored_last_k;
        // complement is created with new unsigned char[...] in init(), so it must be
        // released with the array form.
        delete [] complement;
        delete inverted_last_k;
        if(prob_markov)
        {
            // Each row of the conditional table is created with new float[...] in
            // preprocessing(), so it must be released with the array form.
            for(unsigned_cstring_fixed_length_ptr_float_hash_map::iterator it = prob_markov->begin();it!=prob_markov->end();++it)
                delete [] it->second;
            delete prob_markov;
        }
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin();it!=repeatMachines.end();++it)
            delete *it;
    }

    void ApproximateRepeats::init(AnalyzeSetting & settings, long long file_length, const IndexMap<char> *ind_map)
    {
        // Configures the module from the settings. It builds the complement table,
        // allocates the statistics buffers filled later by preprocessing(), and creates
        // one repeat machine per enabled repeat type with its initial transition
        // probabilities. prob_no_start becomes 1 minus the sum of their START probabilities.
        this->ind_map = ind_map;
        
        
        disable_inverted_repeat = settings.getBool("disable_inverted_repeat",false);
        disable_mirror_repeat = settings.getBool("disable_mirror_repeat",false);
        disable_direct_repeat = settings.getBool("disable_direct_repeat",false);
        disable_qgram_distribution = settings.getBool("disable_qgram_distribution",false);
        disable_character_distribution = settings.getBool("disable_character_distribution",false);
        markov_order = settings.getInt("markov_order",0);
        if(markov_order<0)
        {
            cerr << "ERROR: Approximate Repeat: markov_order must not be negative!"<<endl;
            exit(EXIT_FAILURE);
        }
        iterations = settings.getInt("number_of_iterations",2);
        float minimum_relative_probability = (float)settings.getDouble("minimum_relative_probability",0);
        int minimum_region_size = settings.getInt("minimum_region_size",1);
        minimum_hit_length = settings.getInt("minimum_hit_length", 0);
        bool ignore_case = settings.getBool("ignore_case",false);

        // tail holds markov_order trailing symbols followed by markov_order leading
        // symbols of the buffer (filled in preprocessing()). Its size depends on
        // markov_order, which is only known now, so (re)allocate it here.
        delete [] tail;
        tail = new unsigned char[2*markov_order];
        
        // complement[i] = alphabet index of the complement of symbol i.
        // The complement character is looked up in the settings under the symbol
        // itself (0 = none configured). Symbols without a complement map to
        // themselves; complements outside the alphabet map to ind_map->size.
        complement = new unsigned char[ind_map->size];
        
        for(int i=0;i<ind_map->size;++i)
        {
            char c = ind_map->getValue(i);
            int cmpt = settings.getChar(string(&c,1),0);
            if(ignore_case && cmpt == 0)
            {
                // <cctype> functions require a value representable as unsigned char.
                c = (char)tolower((unsigned char)c);
                cmpt = settings.getChar(string(&c,1),0);
            }
            if(cmpt == 0)
                complement[i] = (unsigned char)i;
            else 
            {
                int cmpt_ind;
                if(ignore_case)
                    cmpt_ind = ind_map->getIndex(toupper((unsigned char)cmpt));
                else
                    cmpt_ind = ind_map->getIndex(cmpt);
                if(cmpt_ind == -1)
                    complement[i] = (unsigned char)ind_map->size;
                else
                    complement[i] = (unsigned char)cmpt_ind;
            }
            
        }
        
        // Symbol counts, turned into relative frequencies by preprocessing().
        prob_char = new float[ind_map->size];
        memset(prob_char,0,sizeof(float)*ind_map->size);
        
        
        mirrored_last_k = new StackFixedSizeRandomAccess<unsigned char>(minimum_hit_length,'\0');
        inverted_last_k = new StackFixedSizeRandomAccess<unsigned char>(minimum_hit_length,'\0');
        
        word_pos_map = new unsigned_cstring_fixed_length_list_long_long_hash_map(0,hash_djb2::fixed_length(minimum_hit_length),eq::fixed_length(minimum_hit_length));
        
        // The conditional table only exists for markov_order > 0. Otherwise keep the
        // pointer NULL so the machines and the destructor see a defined value.
        prob_markov = NULL;
        if(markov_order>0)
            prob_markov = new unsigned_cstring_fixed_length_ptr_float_hash_map(0,hash_djb2::fixed_length(markov_order),eq::fixed_length(markov_order));
        
        prob_no_start = 1.0f;
        if(!disable_direct_repeat)
        {
            DirectRepeatMachine *m = new DirectRepeatMachine(DIRECT_REPEAT, word_pos_map, minimum_region_size, minimum_hit_length,minimum_relative_probability,prob_char,prob_markov, markov_order);
            m->prob[MATCH] = (float)settings.getDouble("direct_repeat_init_prob_match",0.925);
            m->prob[CHANGE] = (float)settings.getDouble("direct_repeat_init_prob_change",0.025);
            m->prob[INDEL] = (float)settings.getDouble("direct_repeat_init_prob_indel",0.025);
            m->prob[START] = (float)settings.getDouble("direct_repeat_init_prob_start",0.05); 
            m->prob[END] = (float)settings.getDouble("direct_repeat_init_prob_end",0.05); 
            repeatMachines.push_back(m);
            m->prob[CONTINUE] = 1-m->prob[END];
            prob_no_start-= m->prob[START];
        }
        if(!disable_mirror_repeat)
        {
            MirrorRepeatMachine *m = new MirrorRepeatMachine(mirrored_last_k,MIRROR_REPEAT,word_pos_map, minimum_region_size, minimum_hit_length,minimum_relative_probability,prob_char,prob_markov, markov_order);
            m->prob[MATCH] = (float)settings.getDouble("mirror_repeat_init_prob_match",0.925);
            m->prob[CHANGE] = (float)settings.getDouble("mirror_repeat_init_prob_change",0.025);
            m->prob[INDEL] = (float)settings.getDouble("mirror_repeat_init_prob_indel",0.025);
            m->prob[START] = (float)settings.getDouble("mirror_repeat_init_prob_start",0.05); 
            m->prob[END] = (float)settings.getDouble("mirror_repeat_init_prob_end",0.05);
            repeatMachines.push_back(m);
            m->prob[CONTINUE] = 1-m->prob[END];
            prob_no_start-= m->prob[START];
        }
        if(!disable_inverted_repeat)
        {
            InvertedRepeatMachine *m = new InvertedRepeatMachine(inverted_last_k,complement,INVERTED_REPEAT, word_pos_map, minimum_region_size, minimum_hit_length,minimum_relative_probability,prob_char,prob_markov, markov_order);
            m->prob[MATCH] = (float)settings.getDouble("inverted_repeat_init_prob_match",0.925);
            m->prob[CHANGE] = (float)settings.getDouble("inverted_repeat_init_prob_change",0.025);
            m->prob[INDEL] = (float)settings.getDouble("inverted_repeat_init_prob_indel",0.025);
            m->prob[START] = (float)settings.getDouble("inverted_repeat_init_prob_start",0.05); 
            m->prob[END] = (float)settings.getDouble("inverted_repeat_init_prob_end",0.05);
            repeatMachines.push_back(m);
            m->prob[CONTINUE] = 1-m->prob[END];
            prob_no_start-= m->prob[START];
        }
        
    }



    void ApproximateRepeats::preprocessing(long long bufSize, string directory, bool write_stdout, long file_id, string file_path, long long file_length)
    {
        // Gathers the statistics the repeat machines rely on:
        //  - absolute, then relative, frequency of every symbol (prob_char);
        //  - the positions of every word of length minimum_hit_length (word_pos_map);
        //  - for markov_order > 0, the next-symbol distribution of every context of
        //    length markov_order (prob_markov). It is counted circularly so the end of
        //    the buffer wraps around to its start, and is optionally written out as a
        //    q-gram CSV before each row is normalized.

        if(bufSize< (long long)minimum_hit_length)
        {
            cerr << "ERROR: Approximate Repeat: File ist shorter than minimum hit length!"<<endl;
            exit(EXIT_FAILURE);
        }
        // The wrap-around contexts below read markov_order symbols from each end of
        // the buffer, which requires at least that many symbols.
        if(markov_order>0 && bufSize< (long long)markov_order)
        {
            cerr << "ERROR: Approximate Repeat: File is shorter than markov order!"<<endl;
            exit(EXIT_FAILURE);
        }

        // tail = last markov_order symbols followed by the first markov_order symbols.
        // Each tail+i (i < markov_order) is a context that wraps around the end of the
        // buffer, and tail[i+markov_order] is the symbol that follows it.
        for(int i=0;i<markov_order;++i)
            tail[i] = buffer[bufSize-markov_order+i];
        for(int i=0;i<markov_order;++i)
            tail[markov_order+i] = buffer[i];

        
        for(long long pos = 0; pos <bufSize; ++pos)
        {
            //absolute frequency of characters
            ++prob_char[buffer[pos]];
            
            // store qgram positions
            if(minimum_hit_length>0 && pos<=bufSize-minimum_hit_length)
            {
                list<long long>* wp = &(*word_pos_map)[buffer+pos];
                wp->push_back(pos);
            }    
            
            // conditional probability: count the symbol following each in-buffer context
            if(markov_order>0 && pos<=bufSize-markov_order-1)
            {
                float ** p = &(*prob_markov)[buffer+pos];
                if(*p == NULL)
                {
                    *p = new float[ind_map->size];
                    memset(*p, 0, ind_map->size*sizeof(float));
                }
                ++(*p)[buffer[pos+markov_order]];
            }
            
        }
        
        // underlying markov chain
        if(markov_order>0) {
            // add the markov_order contexts that wrap around the end of the buffer,
            // so every position contributes exactly one transition (bufSize in total)
            for(int i=0;i<markov_order;++i) {
                float ** p = &(*prob_markov)[tail+i];
                if(*p == NULL)
                {
                    *p = new float[ind_map->size];
                    memset(*p, 0, ind_map->size*sizeof(float));
                }
                ++(*p)[tail[i+markov_order]];

            }
            

            ofstream oFile;
            CsvOutStream csv(POINT);
            if(!disable_qgram_distribution) {
                printCsvQgramDistributionHeader(csv, oFile, 
                                                directory, write_stdout, getId(),
                                                file_id, file_path, file_length,
                                                markov_order+1, ind_map->size);
            }
            for(unsigned_cstring_fixed_length_ptr_float_hash_map::iterator it = prob_markov->begin();it!=prob_markov->end();++it)
            {
                if(!disable_qgram_distribution)
                {
                    // one row per (context + next symbol): q-gram and its relative frequency
                    for(int j=0;j<ind_map->size;++j)
                    {
                        string s;
                        for(int i=0;i<markov_order;++i)
                            s.push_back(ind_map->getValue(it->first[i]));
                        s.push_back(ind_map->getValue(j));
                        csv.addCell(s);
                        csv.addCell(it->second[j]/((float)bufSize));
                        csv.newline();
                    }
                }
                // turn the counts of this context into a conditional distribution
                float sum = 0;
                for(int i=0;i<ind_map->size;++i)
                    sum+=(it->second)[i];
                for(int i=0;i<ind_map->size;++i)
                    (it->second)[i]/=sum;
            }
            if(!disable_qgram_distribution) {
                csv.newline();
                if(oFile.is_open()) {
                    oFile.close();
                }
            }
        }
        //relative frequencies
        for(int i=0;i<ind_map->size;++i)
            prob_char[i]/=bufSize;
        
        
    } 
    
    void ApproximateRepeats::initOutputIterations(string & directory, long file_id, string & file_path, long long file_length, 
                                                  ofstream & oFile_direct_repeat, ofstream & oFile_mirror_repeat, ofstream & oFile_inverted_repeat,
                                                  CsvOutStream & csv_direct_repeat, CsvOutStream & csv_mirror_repeat, CsvOutStream & csv_inverted_repeat)
    {
        // write headers
        if(!disable_direct_repeat)
        {
            printCsvRepeatsIterationsHeader(csv_direct_repeat, oFile_direct_repeat, directory, 
                                            false, getId(), file_id, file_path, file_length,
                                            "direct_repeat", string_params, NUM_COUNTS);
            
        }
        if(!disable_mirror_repeat)
        {
            printCsvRepeatsIterationsHeader(csv_mirror_repeat, oFile_mirror_repeat, directory, 
                                            false, getId(), file_id, file_path, file_length,
                                            "mirror_repeat", string_params, NUM_COUNTS);
        }
        if(!disable_inverted_repeat)
        {
            printCsvRepeatsIterationsHeader(csv_inverted_repeat, oFile_inverted_repeat, directory, 
                                            false, getId(), file_id, file_path, file_length,
                                            "inverted_repeat", string_params, NUM_COUNTS);
        }    
        
        //write start parameters
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
        {
            if((*it)->type_id == DIRECT_REPEAT && !disable_direct_repeat)
            {
                printCsvRepeatsIteration(csv_direct_repeat, (*it)->prob, NUM_COUNTS,0);
            }
            if((*it)->type_id == MIRROR_REPEAT && !disable_mirror_repeat)
            {
                printCsvRepeatsIteration(csv_mirror_repeat, (*it)->prob, NUM_COUNTS,0);
            }
            if((*it)->type_id == INVERTED_REPEAT && !disable_inverted_repeat)
            {
                printCsvRepeatsIteration(csv_inverted_repeat, (*it)->prob, NUM_COUNTS,0);
            }
        }
    }
        
    void ApproximateRepeats::writeIteration(int n,CsvOutStream & csv_direct_repeat, CsvOutStream & csv_mirror_repeat, CsvOutStream & csv_inverted_repeat)
    {
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
        {
            if((*it)->type_id == DIRECT_REPEAT && !disable_direct_repeat)
            {
                printCsvRepeatsIteration(csv_direct_repeat, (*it)->prob, NUM_COUNTS,n);
            }
            if((*it)->type_id == MIRROR_REPEAT && !disable_mirror_repeat)
            {
                printCsvRepeatsIteration(csv_mirror_repeat, (*it)->prob, NUM_COUNTS,n);
            }
            if((*it)->type_id == INVERTED_REPEAT && !disable_inverted_repeat)
            {
                printCsvRepeatsIteration(csv_inverted_repeat, (*it)->prob, NUM_COUNTS,n);
            }
        }
        
    }

    void ApproximateRepeats::writeOutput(string &directory, bool write_stdout, long file_id, string file_path, long long file_length)
    {
        // Writes the final results: the character distribution (unless disabled) and
        // the estimated parameters of every enabled repeat machine. The inverted-repeat
        // output additionally receives the alphabet and complement table.
        // (The former local ofstream/CsvOutStream objects were never opened or written,
        // so closing them was a no-op; they have been removed.)

        if(!disable_character_distribution) {
            
            printCsvCharacterDistribution(directory, write_stdout, getId(),
                                         file_id, file_path, file_length, 
                                         ind_map, prob_char);
        }
        
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
        {
            
            if((*it)->type_id == DIRECT_REPEAT && !disable_direct_repeat)
            {
                printCsvRepeats(directory, write_stdout, getId(), 
                                      file_id, file_path, file_length, "direct_repeat", 
                                      string_params, (*it)->prob, NUM_COUNTS,
                                      NULL, NULL);
            }
            if((*it)->type_id ==  MIRROR_REPEAT && !disable_mirror_repeat)
            {
                printCsvRepeats(directory, write_stdout, getId(), 
                                      file_id, file_path, file_length, "mirror_repeat", 
                                      string_params, (*it)->prob, NUM_COUNTS,
                                      NULL, NULL);
            }
            if((*it)->type_id == INVERTED_REPEAT && !disable_inverted_repeat)
            {
                printCsvRepeats(directory, write_stdout, getId(), 
                                      file_id, file_path, file_length, "inverted_repeat",
                                      string_params, (*it)->prob, NUM_COUNTS, 
                                      ind_map, complement);
            }    
        }
    }


    void ApproximateRepeats::initIteration()
    {
        // Prepares one EM pass. It pre-fills the last-k windows used by the mirror
        // and inverted machines with buffer[0 .. minimum_hit_length-2] (complemented
        // for the inverted one), then resets the base state to probability 1.
        // NOTE(review): process() starts processCharacter at pos=1, which pushes
        // buffer[minimum_hit_length], so buffer[minimum_hit_length-1] never enters the
        // windows. Confirm this offset is intended.
        const bool fillMirror = !disable_mirror_repeat;
        const bool fillInverted = !disable_inverted_repeat;
        if(fillMirror || fillInverted)
        {
            const long long prefill = max((long long)minimum_hit_length,(long long)1) - (long long)1;
            for(long long idx = 0; idx < prefill; ++idx)
            {
                if(fillMirror)
                    mirrored_last_k->push_front(buffer[idx]);
                if(fillInverted)
                    inverted_last_k->push_front(complement[buffer[idx]]);
            }
        }

        baseState.reset();
        baseState.prob=1;
    }

    void ApproximateRepeats::finishIteration()
    {
        prob_no_start = 1;
        
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
        {
            (*it)->finishLastRow(baseState);
            (*it)->reestimateParameters(baseState);
            prob_no_start-=(*it)->prob[START];
        }
    }


    void ApproximateRepeats::processCharacter(long long pos, long long bufSize, bool noHit)
    {
        
        if(!disable_mirror_repeat && minimum_hit_length>0)
            mirrored_last_k->push_front(buffer[pos+minimum_hit_length-1]);
        if(!disable_inverted_repeat && minimum_hit_length>0)
            inverted_last_k->push_front(complement[buffer[pos+minimum_hit_length-1]]);
        
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
        {
            // algorithm with speed ups
            if(minimum_hit_length>0)
            {
                //activate regions ar exact matchings with at least minimum hit length
                if((noHit==false) && pos + minimum_hit_length <=bufSize)
                    (*it)->determineNewRegions(buffer,pos);
                
                //Update old regions and sort out insignificant regions
                (*it)->updateOldRegions();
                
                // Merge old and new regions
                (*it)->mergeRegions();
                
                // update state list of active regions
                (*it)->updateStateList();

                // let regions age
                (*it)->updateAge();

            }
            // full algorithm
            else
            {
                (*it)->addCompleteRegion(pos);
            }
            
            //finish calculations of current row
            (*it)->finishRow(baseState);
            
            // insert new repeat start states and merge them
            (*it)->mergeWithNewStartStates(baseState,pos);
            
        }
        
        // update base state (to next row)
        baseState.prob*= prob_no_start*getProbCharNormalized(buffer,pos,-1);
        ++baseState.count_random;
        
        // transitions to the next row
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
            (*it)->calcNextRowAndUpdateBaseState(baseState,buffer,pos);
        
        // normalize probabilities relative to base state
        for(list<RepeatMachine*>::iterator it = repeatMachines.begin(); it!=repeatMachines.end(); ++it)
            (*it)->normalize(baseState.prob);
        baseState.prob = 1.0;

    }
        
    void ApproximateRepeats::process(long long bufSize,unsigned char const* buf,std::string directory,bool write_stdout, long file_id,std::string file_path,long long file_length)
    {
        // Module entry point. It gathers statistics, runs `iterations` EM passes over
        // positions 1..bufSize-1 (optionally logging each pass when a directory is
        // given), and writes the final results.
        cerr << "ApproximateRepeats started"<<endl;

        buffer = buf;

        // Iteration-log streams; handed to initOutputIterations only when a directory is given.
        ofstream oFile_direct_repeat,oFile_inverted_repeat,oFile_mirror_repeat;
        CsvOutStream csv_direct_repeat(POINT),csv_mirror_repeat(POINT),csv_inverted_repeat(POINT);

        preprocessing(bufSize,directory,write_stdout,file_id,file_path,file_length);

        if(!directory.empty())
            initOutputIterations(directory,file_id,file_path, file_length,
                                 oFile_direct_repeat, oFile_mirror_repeat, oFile_inverted_repeat,
                                 csv_direct_repeat,csv_mirror_repeat,csv_inverted_repeat);

        // Positions up to lastSeedPos are processed with noHit == false, the rest with
        // noHit == true.
        const long long hitLen = (long long)minimum_hit_length;
        const long long lastSeedPos = min(max(bufSize,(long long)1)-(long long)1,
                                          max(bufSize,hitLen)-hitLen);

        for(int n=1;n<=iterations;++n)
        {
            initIteration();

            long long pos = 1;
            while(pos <= lastSeedPos)
            {
                processCharacter(pos, bufSize,false);
                ++pos;
            }
            while(pos < bufSize)
            {
                processCharacter(pos, bufSize,true);
                ++pos;
            }

            finishIteration();

            if(!directory.empty())
                writeIteration(n,csv_direct_repeat,csv_mirror_repeat,csv_inverted_repeat);
        }

        writeOutput(directory,write_stdout,file_id,file_path,file_length);

        if(oFile_direct_repeat.is_open())
            oFile_direct_repeat.close();
        if(oFile_mirror_repeat.is_open())
            oFile_mirror_repeat.close();
        if(oFile_inverted_repeat.is_open())
            oFile_inverted_repeat.close();

        cerr << "ApproximateRepeats finished" << endl;
    }

    const string ApproximateRepeats::getId()
    {
        // The identifier given to the constructor (stored in module_id).
        return string(module_id);
    }

    float ApproximateRepeats::getProbCharNormalized(const unsigned char* buffer, long long pos, int c)
    {
        if(markov_order == 0 || pos < (long long) markov_order)
        {
            if(c == -1)
                return prob_char[buffer[pos]];
            
            else {
                return prob_char[buffer[pos]]/(1-prob_char[c]);
            }
            
        }
        else {
            unsigned_cstring_fixed_length_ptr_float_hash_map::iterator it = prob_markov->find(buffer+pos-markov_order);
            if(it==prob_markov->end())
            {
                if(c == -1)
                    return prob_char[buffer[pos]];
                
                else {
                    return prob_char[buffer[pos]]/(1-prob_char[c]);
                }
            }
            else {
                if(c == -1)
                    return (it->second)[buffer[pos]];
                else {
                    return (it->second)[buffer[pos]]/(1-(it->second)[c]);
                }
                
            }
            
        }
        
    }

}

