#include "repeat_machine.h"

namespace papi
{

    /*
     * Copy assignment: takes over the probability, the per-machine counts
     * and the random count of rhs.
     *
     * rhs: state to copy from
     * returns: reference to this state
     */
    RepeatMachine::State & RepeatMachine::State::operator=(const RepeatMachine::State &rhs) {
        // memcpy must not be called with overlapping buffers, therefore
        // self-assignment is skipped explicitly (it would be a no-op anyway).
        if(this != &rhs)
        {
            prob = rhs.prob;
            memcpy(counts, rhs.counts, sizeof(counts));
            count_random = rhs.count_random;
        }
        return *this;
    }


    /*
     * Merge state s into this state, where s is reached over an edge with
     * probability edge_prob.
     *
     * The counts of both states are averaged, weighted by their
     * probabilities; the count of the used edge (machine_id, edge_type) is
     * credited to the incoming part only. Afterwards prob holds the sum of
     * the old probability and s.prob*edge_prob.
     *
     * s:          incoming state
     * edge_prob:  probability of the edge leading from s to this state
     * edge_type:  second index into counts for the edge that was taken
     * machine_id: first index into counts for the edge that was taken
     */
    void RepeatMachine::State::merge(const State &s, float edge_prob, int edge_type, int machine_id)
    {
        // This state carries (practically) no probability yet:
        // simply take over s and account for the edge.
        if(prob < 0.000000001 )
        {
            *this = s;
            ++counts[machine_id][edge_type];
            prob*=edge_prob;
            return;
        }

        // Probability arriving over the edge, relative to the own probability
        const float weight = s.prob*edge_prob/prob;

        // Nothing arrives, nothing to merge
        if(weight == 0)
            return;

        // Accumulate the weighted counts of the incoming state
        for(int m=0;m<NUM_MACHINES;++m)
        {
            for(int k=0;k<NUM_COUNTS;++k)
            {
                counts[m][k] += s.counts[m][k]*weight;
            }
        }
        count_random += s.count_random*weight;

        // The edge itself was only used by the incoming part
        counts[machine_id][edge_type] += weight;

        prob += weight*prob;

        // Normalize by the total relative weight (own weight 1 + incoming weight)
        count_random /= (1+weight);
        for(int m=0;m<NUM_MACHINES;++m)
        {
            for(int k=0;k<NUM_COUNTS;++k)
            {
                counts[m][k] /= (1+weight);
            }
        }
    }

    /*
     * Merge state s into this state without any edge in between.
     *
     * The counts of both states are averaged, weighted by their
     * probabilities, and the probabilities are added up.
     *
     * s: incoming state
     */
    void RepeatMachine::State::merge(const State &s)
    {
        // This state carries (practically) no probability yet: take over s
        if(prob < 0.000000001 )
        {
            *this = s;
            return;
        }

        // Probability of s relative to the own probability
        const float weight = s.prob/prob;

        // s has no probability, nothing to merge
        if(weight == 0)
            return;

        // Accumulate the weighted counts of the incoming state
        for(int m=0;m<NUM_MACHINES;++m)
        {
            for(int k=0;k<NUM_COUNTS;++k)
                counts[m][k] += s.counts[m][k]*weight;
        }
        count_random += s.count_random*weight;

        prob += s.prob;

        // Normalize by the total relative weight (own weight 1 + incoming weight)
        count_random /= (1+weight);
        for(int m=0;m<NUM_MACHINES;++m)
        {
            for(int k=0;k<NUM_COUNTS;++k)
                counts[m][k] /= (1+weight);
        }
    }

    /*
     * Put the state back into its empty form: all counts, the random count
     * and the probability are set to zero.
     */
    void RepeatMachine::State::reset()
    {
        memset(counts,0,sizeof(counts));
        count_random = 0.0;
        prob = 0.0;
    }



    /*
     * Construct a repeat machine.
     *
     * The constructor only stores the passed pointers and thresholds,
     * allocates the parameter array prob (NUM_COUNTS entries, not
     * initialized here) and the two region lists, and creates the
     * delimiting state of the state list.
     *
     * type_id:                      index of this machine in the counts of a state
     * word_pos_map:                 stored as is; not used in this constructor
     * minimum_region_size:          margin kept around significant states
     * minimum_hit_length:           regions up to this age keep growing unconditionally
     * minimum_relative_probability: threshold for a state to be significant
     * prob_char:                    character probabilities (order 0)
     * prob_markov:                  context dependent character probabilities
     * markov_order:                 length of the context used with prob_markov
     */
    RepeatMachine::RepeatMachine(int type_id,
                                 unsigned_cstring_fixed_length_list_long_long_hash_map *word_pos_map,
                                 long minimum_region_size,
                                 long minimum_hit_length,
                                 float minimum_relative_probability,
                                 float *prob_char,
                                 unsigned_cstring_fixed_length_ptr_float_hash_map *prob_markov,
                                 int markov_order):
        type_id(type_id),
        prob(new float[NUM_COUNTS]),
        active_regions(new list<Region>()),
        active_regions_save(new list<Region>()),
        num_unused_states(0),
        markov_order(markov_order),
        prob_markov(prob_markov),
        prob_char(prob_char),
        minimum_region_size(minimum_region_size),
        minimum_hit_length(minimum_hit_length),
        minimum_relative_probability(minimum_relative_probability),
        word_pos_map(word_pos_map)
    {
        // The state list starts out with a single delimiting state.
        // Unused states are kept behind it (see insertStates/removeStates).
        unused_states = states.insert(states.end(), State());
        unused_states->prob = -1;
    }

    /*
     * Release the region lists and the parameter array allocated in the
     * constructor.
     */
    RepeatMachine::~RepeatMachine()
    {
        // Deleting a null pointer is a no-op, so no checks are required.
        delete active_regions;
        delete active_regions_save;

        // prob is allocated with new float[NUM_COUNTS] and therefore has to
        // be released with the array form of delete.
        delete[] prob;
    }



    /*
     * Finish the last row: merge every state of every active region into
     * baseState and reset the machine (regions and state list).
     *
     * baseState: state receiving the probability and counts of all active states
     */
    void RepeatMachine::finishLastRow(State & baseState)
    {
        list<State>::iterator it_state = states.begin();

        for(list<Region>::iterator it_region = active_regions->begin();it_region!=active_regions->end();++it_region)
        {
            // Region bounds are 64 bit positions; keep them in long long
            // so that large positions are not truncated.
            const long long start = it_region->start;
            const long long end = it_region->end;

            // All repeats end (end probability = 1.0)
            for(long long i=start;i<=end;++i,++it_state)
            {
                baseState.merge(*it_state);
            }
        }

        // Clean up
        active_regions->clear();
        active_regions_save->clear();
        new_hit_regions.clear();
        states.clear();

        // Recreate the delimiting state exactly as the constructor does
        states.push_back(State());
        unused_states = states.begin();
        unused_states->prob = -1;
        num_unused_states = states.size()-1;
    }

    /*
     * Re-estimate the parameters of this machine (prob[START], prob[END],
     * prob[CONTINUE], prob[MATCH], prob[CHANGE], prob[INDEL]) from the counts
     * collected in baseState.
     *
     * NOTE(review): none of the divisions is protected against a zero
     * denominator; presumably the caller guarantees non-empty counts.
     *
     * baseState: state holding the accumulated counts
     */
    void RepeatMachine::reestimateParameters(State & baseState)
    {
        // START: share of this machine among the starts of all machines
        // plus the random count
        float start_total = 0;
        for(int machine=0;machine<NUM_MACHINES;++machine)
            start_total += baseState.counts[machine][START];
        start_total += baseState.count_random;

        prob[START] = baseState.counts[type_id][START]/start_total;

        // END / CONTINUE: decided at every position inside a repeat
        float length_total = 0;
        length_total += baseState.counts[type_id][END];
        length_total += baseState.counts[type_id][CONTINUE];

        prob[END] = baseState.counts[type_id][END]/length_total;
        prob[CONTINUE] = 1 - prob[END];

        // MATCH / CHANGE / INDEL: the edit operations inside a repeat
        float edit_total = 0;
        edit_total += baseState.counts[type_id][MATCH];
        edit_total += baseState.counts[type_id][CHANGE];
        edit_total += baseState.counts[type_id][INDEL];

        prob[MATCH] = baseState.counts[type_id][MATCH]/edit_total;
        prob[CHANGE] = baseState.counts[type_id][CHANGE]/edit_total;

        // The INDEL count covers both insert and delete and therefore
        // must be divided by two
        prob[INDEL] = baseState.counts[type_id][INDEL]/edit_total/2.0;
    }

    void RepeatMachine::updateOldRegionsDir(bool dir)
    {
        //Direction of the repeat
        short fwd_ext = (dir==FORWARD)?1:0;
        short rev_ext = (dir==FORWARD)?0:-1;
        
        list<State>::iterator it_state = states.begin();
        for( list<Region>::iterator it=active_regions->begin();it!=active_regions->end();++it)
        {
            // Young regions continue to grow. The minimum age is here chosen as the size of the exact matching
            if(it->age<=minimum_hit_length)
            {
                active_regions_save->push_back(Region(max(it->start+rev_ext,-1ll),it->end+fwd_ext,it->age));
            }
            else
            {
                long long start = it->start;
                long long end = it->end;
                
                //-1 ist a valid index, therefore use -2 as a flag that no region was created
                long long new_start = -2,new_end = -2;
                for(long long i=start;i<=end;++i)
                {
                    //significant state
                    if(it_state->prob>=minimum_relative_probability)
                    {
                        //activate region around state and let region grow
                        if(new_start==-2)
                        {
                            //new_start can be -1
                            // Index -1 is a helper index making sure at another point
                            // that a matching at position 0 automatically leads to the base state.
                            new_start = max( max(start+rev_ext,-1ll),i-minimum_region_size+rev_ext); 
                            new_end = min(end+fwd_ext,i+minimum_region_size+fwd_ext);
                        }
                        // Extend existing region
                        else {
                            new_end = min(end+fwd_ext,i+minimum_region_size+fwd_ext);
                        }
                    }
                    
                    // If state is not significant and the distance to the next significant state is big
                    // cose this region
                    else if(new_start!=-2 && i>new_end+minimum_region_size-rev_ext){
                        active_regions_save->push_back(Region(new_start,new_end,it->age));
                        new_start = -2;
                    }
                    ++it_state;
                }
                // Close region
                if(new_start!=-2)
                {
                    active_regions_save->push_back(Region(new_start,new_end,it->age));
                }
            }
        }
        
        // Save old regions in order to update the state list later on
        list<Region> *swap;
        swap = active_regions;
        active_regions = active_regions_save;
        active_regions_save = swap;
    }

    /*
     * Merge the regions of new hits into the active regions and fuse
     * regions which overlap or directly follow each other.
     *
     * A fused region spans both regions and gets the minimum of both ages.
     * Note that list::merge expects both lists to be sorted already.
     */
    void RepeatMachine::mergeRegions()
    {
        // merge old and new regions lists
        active_regions->merge(new_hit_regions);
        new_hit_regions.clear();

        // Without regions there is nothing to fuse; the iterator setup
        // below must not be executed on an empty list (it would increment end()).
        if(active_regions->empty())
            return;

        list<Region>::iterator it = active_regions->begin();
        list<Region>::iterator it_next = it;
        ++it_next;

        // merge overlapping regions
        // age is the minimum of overlapping regions
        while(it_next!=active_regions->end())
        {
            if(it->end+1<it_next->start)
            {
                // there is a gap between both regions, go on with the next pair
                it = it_next++;
            }
            else {
                it->age = min(it->age,it_next->age);
                it->end = max(it->end,it_next->end);
                it_next = active_regions->erase(it_next);
            }
        }
    }

    /*
     * update state list to the current active regions
     */
    void RepeatMachine::updateStateList()
    {
        
        long long start_old =-2, end_old=-2;
        long long start_new =-2, end_new=-2;
        
        list<State>::iterator it_state = states.begin();
        list<Region>::iterator it_old = active_regions_save->begin();
        list<Region>::iterator it_new = active_regions->begin();
        
        bool done = false;
        
        
        if(it_new != active_regions->end())
        {
            start_new = it_new->start;
            end_new = it_new->end;
            ++it_new;
        }
        else {
            done = true;
        }
        
        if(it_old != active_regions_save->end())
        {
            start_old = it_old->start;
            end_old = it_old->end;
            ++it_old;
        }
        else {
            done = true;
        }
        
        
        while(!done)
        {
            
            if( end_new < start_old)
            {
                insertStates(it_state,end_new-start_new+1);
                if(it_new == active_regions->end())
                {
                    start_new = -2;
                    done = true;
                }
                else
                {
                    start_new = it_new->start;
                    end_new = it_new->end;
                    ++it_new;
                }
            }
            else if( end_old < start_new )
            {
                it_state = removeStates(it_state,end_old-start_old+1);
                if(it_old == active_regions_save->end())
                {
                    start_old = -2;
                    done = true;
                }
                else
                {
                    start_old = it_old->start;
                    end_old = it_old->end;
                    ++it_old;
                }
            }
            else
            {
                int inCommon;
                if( start_new < start_old)
                {
                    insertStates(it_state,start_old-start_new);
                    inCommon = -start_old;
                }
                else if(start_new > start_old)
                {
                    it_state = removeStates(it_state,start_new-start_old);
                    inCommon = -start_new;
                }
                else {
                    inCommon = -start_new;
                }
                
                
                if( end_new < end_old)
                {
                    inCommon += end_new+1;
                    start_old = end_new+1;
                    if(it_new == active_regions->end())
                    {
                        start_new = -2;
                        done = true;
                    }
                    else
                    {
                        start_new = it_new->start;
                        end_new = it_new->end;
                        ++it_new;
                    }
                }
                else if(end_new > end_old)
                {
                    inCommon += end_old+1;
                    start_new = end_old+1;
                    if(it_old == active_regions_save->end())
                    {
                        start_old = -2;
                        done = true;
                    }
                    else
                    {
                        start_old = it_old->start;
                        end_old = it_old->end;
                        ++it_old;
                    }
                }
                else {
                    inCommon += end_new+1;
                    if(it_new == active_regions->end())
                    {
                        start_new = -2;
                        done = true;
                    }
                    else
                    {
                        start_new = it_new->start;
                        end_new = it_new->end;
                        ++it_new;
                    }
                    if(it_old == active_regions_save->end())
                    {
                        start_old = -2;
                        done = true;
                    }
                    else
                    {
                        start_old = it_old->start;
                        end_old = it_old->end;
                        ++it_old;
                    }
                }
                
                advance(it_state, inCommon);
            }
        }
        
        if(start_new!=-2)
        {
            insertStates(it_state,end_new-start_new+1);
            while(it_new!=active_regions->end())
            {
                start_new = it_new->start;
                end_new = it_new->end;
                ++it_new;    
                insertStates(it_state,end_new-start_new+1);
            }
            
        }
        else if(start_old!=-2)
        {
            it_state = removeStates(it_state,end_old-start_old+1);
            while(it_old!=active_regions_save->end())
            {
                start_old = it_old->start;
                end_old = it_old->end;
                ++it_old;
                it_state = removeStates(it_state,end_old-start_old+1);
            }
            
        }
        active_regions_save->clear();
        
    }


    /*
     * Finish the current row for all active states.
     *
     * For every state of a region except the one at the growing border:
     *  - the end-of-repeat transition is merged into baseState,
     *  - the state itself continues (prob[CONTINUE]),
     *  - except for the last of them, a delete transition is merged into
     *    the neighbouring state in direction dir.
     * The state at the growing border of a region is skipped. In backward
     * direction it is merged into baseState if the region starts at the
     * helper index -1.
     *
     * baseState: state receiving the ending repeats
     * dir:       direction of the repeat (FORWARD or not)
     */
    void RepeatMachine::finishRowDir(State & baseState, bool dir)
    {
        if(dir == FORWARD)
        {
            // The list always contains the delimiting state, so begin() is valid.
            list<State>::iterator it2_next = states.begin();
            list<State>::iterator it2 = it2_next++;

            for(list<Region>::iterator it = active_regions->begin(); it!=active_regions->end();++it)
            {
                long long start = it->start;
                long long end = it->end;

                for(long long i=start;i<end-1; ++i,it2=it2_next++)
                {
                    // end of repeat transitions
                    baseState.merge(*it2,prob[END],END,type_id);

                    // repeat continues
                    it2->prob*=prob[CONTINUE];
                    ++it2->counts[type_id][CONTINUE];

                    // delete transitions
                    it2_next->merge(*it2,prob[INDEL],INDEL,type_id);
                }
                // transition other than delete
                baseState.merge(*it2,prob[END],END,type_id);
                it2->prob*=prob[CONTINUE];
                ++it2->counts[type_id][CONTINUE];
                it2=it2_next++;

                // skip state which was added when the region grew
                it2=it2_next++;
            }
        }
        // Analog to FORWARD, walking backwards from the last active state
        else {
            list<State>::reverse_iterator it2_next = list<State>::reverse_iterator (unused_states);
            list<State>::reverse_iterator it2 = it2_next;
            // Without active states it2_next already equals rend() and must
            // not be incremented any further.
            if(it2_next != states.rend())
                ++it2_next;

            for(list<Region>::reverse_iterator it = active_regions->rbegin(); it!=active_regions->rend();++it)
            {
                long long start = it->start;
                long long end = it->end;

                for(long long i=end;i>start+1; --i,it2=it2_next++)
                {
                    baseState.merge(*it2,prob[END],END,type_id);
                    it2->prob*=prob[CONTINUE];
                    ++it2->counts[type_id][CONTINUE];
                    it2_next->merge(*it2,prob[INDEL],INDEL,type_id);
                }
                baseState.merge(*it2,prob[END],END,type_id);
                it2->prob*=prob[CONTINUE];
                ++it2->counts[type_id][CONTINUE];
                it2=it2_next++;

                // Index -1 is the helper index: a state there leads to the base state
                if(start==-1)
                {
                    baseState.merge(*it2);
                }

                // Skip the state which was added when the region grew. For the
                // front region it2_next has reached rend() here and stays there.
                it2 = it2_next;
                if(it2_next != states.rend())
                    ++it2_next;
            }
        }
    }


    /*
     * Merge newly starting repeats into the active states.
     *
     * The start state is a copy of baseState whose probability is scaled by
     * prob[START]/pos and whose START count is incremented. Walking through
     * every region in direction dir, a carry state collects the start state
     * at each position and is merged into the state of that position;
     * afterwards the carry is moved on by a delete transition (prob[INDEL]).
     * The state at the growing border of each region is skipped. The carry
     * is not reset between regions.
     *
     * NOTE(review): pos is used as divisor; presumably callers pass pos > 0.
     *
     * baseState: state new repeats start from
     * pos:       divisor applied to the start probability
     * dir:       direction of the repeat (FORWARD or not)
     */
    void RepeatMachine::mergeWithNewStartStatesDir(State & baseState, long long pos, bool dir)
    {
        State new_start = baseState;
        new_start.prob*=prob[START]/(float)pos;
        ++new_start.counts[type_id][START];

        // Start explicitly from an empty carry instead of relying on what
        // default construction leaves in the state.
        State carry_state;
        carry_state.reset();

        if(dir==FORWARD)
        {
            // The list always contains the delimiting state, so begin() is valid.
            list<State>::iterator it2_next = states.begin();
            list<State>::iterator it2 = it2_next++;

            for(list<Region>::iterator it = active_regions->begin(); it!=active_regions->end();++it)
            {
                long long start = it->start;
                long long end = it->end;

                for(long long i=start;i<=end-1; ++i,it2=it2_next++)
                {
                    // New repeat beginning at this position
                    carry_state.merge(new_start);
                    // New repeat which started before but consisted only of delete operations so far
                    it2->merge(carry_state);

                    // delete transition
                    carry_state.prob*=prob[INDEL];
                    ++carry_state.counts[type_id][INDEL];
                }
                // skip the state at the growing border of the region
                it2=it2_next++;
            }
        }
        // Analog to FORWARD, walking backwards from the last active state
        else
        {
            list<State>::reverse_iterator it2_next = list<State>::reverse_iterator (unused_states);
            list<State>::reverse_iterator it2 = it2_next;
            // Without active states it2_next already equals rend() and must
            // not be incremented any further.
            if(it2_next != states.rend())
                ++it2_next;

            for(list<Region>::reverse_iterator it = active_regions->rbegin(); it!=active_regions->rend();++it)
            {
                long long start = it->start;
                long long end = it->end;

                for(long long i=end;i>=start+1; --i,it2=it2_next++)
                {
                    carry_state.merge(new_start);
                    it2->merge(carry_state);
                    carry_state.prob*=prob[INDEL];
                    ++carry_state.counts[type_id][INDEL];
                }

                // Skip the state at the growing border of the region. For the
                // front region it2_next has reached rend() here and stays there.
                it2 = it2_next;
                if(it2_next != states.rend())
                    ++it2_next;
            }
        }
    }

    /*
     * Insert num empty states in front of pos.
     *
     * Unused states, which are kept behind the delimiting state
     * (unused_states), are reset and reused first. If there are not enough
     * of them, all are reused and the missing states are newly created.
     *
     * pos: position in the state list the states are inserted in front of
     * num: number of states to insert
     */
    void RepeatMachine::insertStates(list<State>::iterator pos,int num)
    {
        // The first unused state follows directly behind the delimiting state
        list<State>::iterator pool_begin = unused_states;
        ++pool_begin;

        const bool pool_suffices = (num<=num_unused_states);

        // Range of unused states which is reused: either the first num of
        // them or all which are available
        list<State>::iterator pool_end = states.end();
        if(pool_suffices)
        {
            pool_end = pool_begin;
            advance(pool_end,num);
        }

        // reset states
        for(list<State>::iterator reused = pool_begin; reused!=pool_end; ++reused)
        {
            reused->reset();
        }

        states.splice(pos,states,pool_begin,pool_end);

        if(pool_suffices)
        {
            num_unused_states-=num;
        }
        else
        {
            // create the states which could not be taken from the unused ones
            states.insert(pos,num-num_unused_states, State());
            num_unused_states = 0;
        }
    }
    /*
     * Take num states beginning at pos out of the active part of the state
     * list. The states are not destroyed but moved to the end of the list
     * where they wait for being reused.
     *
     * pos: first state to remove
     * num: number of states to remove
     * returns: iterator to the state following the removed ones
     */
    list< RepeatMachine::State>::iterator RepeatMachine::removeStates(list<State>::iterator pos,int num)
    {
        // One past the last state to remove
        list<State>::iterator range_end = pos;
        advance(range_end, num);

        num_unused_states+=num;

        // Move unused states to the end
        states.splice(states.end(),states,pos,range_end);
        return range_end;
    }

    void RepeatMachine::updateAge()
    {
        for(list<Region>::iterator it_region = active_regions->begin(); it_region!=active_regions->end(); ++it_region)
            ++it_region->age;
    }




    void RepeatMachine::normalize(float baseValue)
    {
        list<State>::iterator it_state = states.begin();
        for(list<Region>::iterator it_region = active_regions->begin(); it_region!=active_regions->end();++it_region)
        {
            long long start = it_region->start;
            long long end = it_region->end;
            
            for( long long i =start; i<=end;++i,++it_state)
            {
                it_state->prob/=baseValue;
            }
        }
    }



    /*
     * Maintain one region covering all positions up to pos.
     *
     * For pos <= 1 the region is created together with its two states.
     * For bigger positions the end of the front region is moved on and one
     * state is appended in front of the delimiting state.
     *
     * pos: current position
     * dir: direction of the repeat (FORWARD or not)
     */
    void RepeatMachine::addCompleteRegionDir(long long pos,bool dir)
    {
        const bool forward = (dir == FORWARD);

        if(pos<=1)
        {
            // Create the region. In backward direction it starts at the
            // helper index -1.
            active_regions->push_back(forward ? Region(0,1,0) : Region(-1,0,0));
            insertStates(states.begin(), 2);
            return;
        }

        // Let the existing region grow by one position
        active_regions->front().end = forward ? pos : pos-1;
        insertStates(unused_states, 1);
    }



    float RepeatMachine::getProbCharNormalized(const unsigned char* buffer, long long pos, int c)
    {

        if(markov_order == 0 || pos < (long long)markov_order)
        {
            if(c == -1)
                return prob_char[buffer[pos]];
            
            else {
                return prob_char[buffer[pos]]/(1-prob_char[c]);
            }
            
        }
        else {
            unsigned_cstring_fixed_length_ptr_float_hash_map::iterator it = prob_markov->find(buffer+pos-markov_order);
            if(it==prob_markov->end())
            {
                if(c == -1)
                    return prob_char[buffer[pos]];
                
                else {
                    return prob_char[buffer[pos]]/(1-prob_char[c]);
                }
            }
            else {
                if(c == -1)
                    return (it->second)[buffer[pos]];
                else {
                    return (it->second)[buffer[pos]]/(1-(it->second)[c]);
                }
                
            }
            
        }
        
    }

}

