#include "correlation.h"

using namespace std;

namespace papi
{

    /// Creates a correlation module identified by @p name.
    ///
    /// Every numeric setting starts at zero, every flag at false and every
    /// buffer pointer at null, so the destructor is safe to run even when
    /// init() or process() were never called or allocated nothing.
    ///
    /// @param name  identifier stored in module_id and returned by getId()
    Correlation::Correlation( char const* name):
        mutual_information_step(0),mutual_information_max(0),mutual_information_min(0),
        autocorrelation_dar_max(0),size_alphabet(0),
        disable_ac(false),disable_mi(false)
        
    {
        module_id = name;

        // Assigned in the body rather than in the initializer list so the
        // result does not depend on the member declaration order in the header.
        mutual_information_min_rest = 0;
        disable_character_distribution = false;
        ind_map = 0;

        // The destructor tests and deletes these pointers; they must never be
        // left indeterminate.
        count_char = 0;
        count_dist_k = 0;
        mutual_information = 0;
        autocorrelation_dar = 0;
        prob = 0;
    }

    /// Releases the counter and result buffers owned by this object.
    Correlation::~Correlation()
    {
        // delete[] on a null pointer is a no-op, so the flat arrays need no guard.
        delete [] count_char;

        // The nested table has to be walked, so it is only touched when present.
        if(count_dist_k)
        {
            for(int i=0;i<size_alphabet;++i)
            {
                for(int j=0;j<size_alphabet;++j)
                    delete [] count_dist_k[i][j];
                // Each row is an array of pointers created with new[], so it
                // must be released with delete[] (plain delete is undefined here).
                delete [] count_dist_k[i];
            }
            delete [] count_dist_k;
        }

        delete [] mutual_information;
        delete [] autocorrelation_dar;
        delete [] prob;
    }

    /// Reads the module settings and allocates the counter arrays filled by process().
    ///
    /// Layout of count_dist_k[a][b][bucket]:
    ///  - buckets 0 .. autocorrelation_dar_max-1 hold distances 1 .. autocorrelation_dar_max
    ///    (only counted when autocorrelation is enabled);
    ///  - the following buckets hold the mutual-information distances
    ///    mutual_information_min_rest, +step, ... up to mutual_information_max,
    ///    i.e. the distances that are not already covered by the first range.
    ///
    /// When all three submodules are disabled only a message is written to cerr
    /// and nothing is allocated.
    ///
    /// NOTE: arrays allocated by an earlier call are not released here.
    ///
    /// @param settings     source of the "disable_*", "mutual_information_*" and
    ///                     "autocorrelation_dar_max" options
    /// @param file_length  not used by this function
    /// @param ind_map      must be non-null; its size member fixes the alphabet
    ///                     size and the pointer is kept for later use
    void Correlation::init(AnalyzeSetting & settings, long long file_length, const IndexMap<char> *ind_map)
    {
        this->size_alphabet = ind_map->size;
        this->ind_map = ind_map;
        
        disable_ac = settings.getBool("disable_autocorrelation_dar",false);
        disable_mi = settings.getBool("disable_mutual_information",false);
        disable_character_distribution = settings.getBool("disable_character_distribution",false);


        if(disable_ac && disable_mi && disable_character_distribution)
        {
            cerr<<"Correlation Module: All submodules disabled"<<endl;
        }
        else 
        {
            mutual_information_step = settings.getInt("mutual_information_step",1);
            mutual_information_max = settings.getInt("mutual_information_max",1);
            mutual_information_min = settings.getInt("mutual_information_min",1);
            autocorrelation_dar_max = settings.getInt("autocorrelation_dar_max",1);

            // A step below 1 would divide by zero in the bucket computation
            // below and would never advance the distance loop in process().
            if(mutual_information_step < 1)
            {
                cerr<<"Correlation Module: mutual_information_step must be at least 1, using 1"<<endl;
                mutual_information_step = 1;
            }
            // Distances start at 1; a smaller minimum would make process()
            // read bucket index -1.
            if(mutual_information_min < 1)
            {
                cerr<<"Correlation Module: mutual_information_min must be at least 1, using 1"<<endl;
                mutual_information_min = 1;
            }
            // A negative maximum would yield a negative array size.
            if(static_cast<long long>(autocorrelation_dar_max) < 0)
            {
                cerr<<"Correlation Module: autocorrelation_dar_max must not be negative, using 0"<<endl;
                autocorrelation_dar_max = 0;
            }

            if(disable_mi)
                mutual_information_max = mutual_information_min = mutual_information_min_rest = 0;
            if(disable_ac)
                autocorrelation_dar_max = 0;
            
            
            count_char =  new long long[size_alphabet];
            memset(count_char,0,sizeof(long long)*size_alphabet);

            /// Determine size of array count_dist_k[a][b] by counting distance values
            int dist_k_buckets = 0;
            if(!disable_ac)
                dist_k_buckets += autocorrelation_dar_max;
            if(!disable_mi)
            {
                // First mutual-information distance of the form min + m*step that
                // lies beyond the autocorrelation range (those below are shared).
                if(mutual_information_min <= autocorrelation_dar_max)
                    mutual_information_min_rest = ((autocorrelation_dar_max-mutual_information_min)/mutual_information_step)*mutual_information_step+mutual_information_step+mutual_information_min;
                else
                    mutual_information_min_rest = mutual_information_min;
                
                if(mutual_information_min_rest<=mutual_information_max)
                    dist_k_buckets += (mutual_information_max-mutual_information_min_rest)/mutual_information_step+1;
            }
            
            count_dist_k = new long long**[size_alphabet];
            for(int i=0;i<size_alphabet;++i)
            {
                count_dist_k[i] = new long long*[size_alphabet];
                for(int j=0;j<size_alphabet;++j)
                {
                    count_dist_k[i][j] = new long long[dist_k_buckets];
                    memset(count_dist_k[i][j],0,sizeof(long long)*dist_k_buckets);
                }
            }

        }
    }

    void Correlation::process(long long bufSize,unsigned char const* buffer,string dir, bool write_stdout, long file_id,string file_path,long long file_length)
    {
        cerr<<"Correlation module started"<<endl;
        for(long long i=0; i<(long long)bufSize;++i)
        {
            int c_index = buffer[i];

            
            ++count_char[c_index];
            
            // Increase first autocorrelation_dar_max counters of count_dist_k[a][b]
        
            if(!disable_ac)
            {
                long long max_j = min((long long)bufSize-i-1,(long long)autocorrelation_dar_max);
                for(long long j=1;j<=max_j;++j)
                    ++count_dist_k[c_index][buffer[i+j]][j-1];
            }                    
            
            // Increase the rest of the counters
            if(!disable_mi)
            {
                long long max_j = min((long long)bufSize-i-1,(long long)mutual_information_max);
                for(long long k=autocorrelation_dar_max,j=mutual_information_min_rest; j<=max_j ; j+=mutual_information_step,++k)
                    ++count_dist_k[c_index][buffer[i+j]][k];
            }

            
        }


        prob = new long double[size_alphabet];
        for(int i=0;i<size_alphabet;++i){
            prob[i] = (long double)count_char[i]/bufSize;
        }
        
        /*
         Mutual information
         */
        int mutual_information_num = 0;
        if(!disable_mi)
        {
            mutual_information_num = (mutual_information_max-mutual_information_min)/mutual_information_step+1;
            mutual_information = new long double[mutual_information_num];
            memset(mutual_information,0,sizeof(long double)*(mutual_information_num));
            
            
            long double **prob_prod = new long double*[size_alphabet];
            for(int i=0;i<size_alphabet;++i)
                prob_prod[i] = new long double[size_alphabet];
            for(int i=0;i<size_alphabet;++i)
                for(int j=0;j<size_alphabet;++j)
                {
                    prob_prod[i][j] = prob[i]*prob[j];
                }
            
            
            // Total sum of pairs with distance k = length - k
            long long n = bufSize-mutual_information_min;        
            int mutual_information_index = 0;
            
            for(int k=mutual_information_min-1;k<autocorrelation_dar_max;k+=mutual_information_step,++mutual_information_index)
            {
                for(int i=0;i<size_alphabet;++i)
                    for(int j=0;j<size_alphabet;++j)
                    {
                        if(count_dist_k[i][j][k]>0)
                            mutual_information[mutual_information_index]+= (long double)count_dist_k[i][j][k] * ( log((long double) count_dist_k[i][j][k]/prob_prod[i][j]/n))/n;
                    }
                mutual_information[mutual_information_index]/=log((long double)2);
                n-=mutual_information_step;
            }
            
            for(int k=autocorrelation_dar_max; mutual_information_index<mutual_information_num;++k,++mutual_information_index)
            {
                for(int i=0;i<size_alphabet;++i)
                    for(int j=0;j<size_alphabet;++j)
                    {
                        if(count_dist_k[i][j][k]>0)
                            mutual_information[mutual_information_index]+= (long double)count_dist_k[i][j][k] * ( log((long double) count_dist_k[i][j][k]/prob_prod[i][j]/n))/n;
                    }
                mutual_information[mutual_information_index]/=log((long double)2);
                n-=mutual_information_step;
            }
            if(prob_prod)
            {
                for(int i=0;i<size_alphabet;++i)
                    delete [] prob_prod[i];
                delete [] prob_prod;
            }
            
            
        }    
        
        /*
         autocorrelation_dar
         */
        
        if(!disable_ac)
        {
            autocorrelation_dar = new long double[autocorrelation_dar_max+1];
            memset(autocorrelation_dar,0,sizeof(long double)*(autocorrelation_dar_max+1));
            autocorrelation_dar[0] = 1;
            
            for(int k=1;k<=(int)autocorrelation_dar_max;++k)
            {
                long double sum1=0;
                for(int i=0;i<size_alphabet;++i)
                {
                    long double sum2=0;
                    for(int j=0;j<size_alphabet;++j)
                    {
                        if(i!=j)
                        {
                            sum2+=count_dist_k[i][j][k-1];
                        }
                    }
                    sum1+=sum2/(bufSize-count_char[i]);
                    
                }
                if(k<(long)bufSize)
                    autocorrelation_dar[k] = max((long double)0.0,1-(sum1*bufSize/(bufSize-k)));
            }
        }    
        
        
        if(!disable_character_distribution)
        {
            printCsvCharacterDistribution(dir, write_stdout, getId(),
                                         file_id, file_path, file_length,
                                         ind_map, prob);
        }
        
        if(!disable_mi)
        {
            printCsvMutualInformation(dir, write_stdout, getId(),
                                      file_id, file_path, file_length,
                                      mutual_information, mutual_information_min,
                                      mutual_information_max, mutual_information_step);
        }
        
        
        if(!disable_ac)
        {
            printCsvAutocorrelationDar(dir, write_stdout, getId(), 
                                      file_id, file_path, file_length, 
                                       autocorrelation_dar, autocorrelation_dar_max);
        }
        
       
        cerr<<"Correlation module finished"<<endl;

    }

    /// Returns a copy of this module's identifier (module_id).
    const string Correlation::getId()
    {
        const string id(module_id);
        return id;
    }

}

