#include <math.h>
#include "mex.h"

/* Loop-header shorthands: each expands to the "init; condition; increment"
 * part of a for statement, e.g. for(all_n) iterates n = 0 .. nNodes-1.
 * They use the variables n/state/state1/state2/f and the bounds
 * nNodes/nStates/nFeatureTypes from the enclosing scope. */
#define all_n n = 0; n < nNodes; n++
#define all_state state = 0; state < nStates; state++
#define all_state1 state1 = 0; state1 < nStates; state1++
#define all_state2 state2 = 0; state2 < nStates; state2++
#define all_features f = 0; f < nFeatureTypes; f++
/* Column-major (MATLAB-layout) indexing into flat arrays: the first index
 * varies fastest.  Arrays indexed by (position, state) use nNodes as the
 * leading dimension; arrays indexed by (state, state[, ...]) use nStates.
 * Every argument is parenthesized so that expression arguments (e.g. a
 * conditional expression) are evaluated as a unit. */
#define nodePot(a, b) nodePot[(a) + nNodes*(b)]
#define edgePot(a, b) edgePot[(a) + nStates*(b)]
#define nodeBel(a, b) nodeBel[(a) + nNodes*(b)]
#define edgeBel(a, b, c) edgeBel[(a) + nStates*((b) + nStates*(c))]
#define alpha(a, b) alpha[(a) + nNodes*(b)]
#define beta(a, b) beta[(a) + nNodes*(b)]
#define tmp(a, b) tmp[(a) + nStates*(b)]
#define v(a, b) v[(a) + nStates*(b)]

/* Larger of two ints; when they are equal the first argument is returned. */
int maxVal(int x, int y) {
    return (y > x) ? y : x;
}

/* Look up field `name` of element 1 of the struct array `data`, raising a
 * MATLAB error when the field is absent (mxGetField returns NULL then). */
static const mxArray *getRequiredField(const mxArray *data, const char *name)
{
    const mxArray *field = mxGetField(data, 0, name);
    if (field == NULL)
        mexErrMsgIdAndTxt("crfChain_testErrC:missingField",
                          "data.%s is missing: crfChain_testErrC(wv,data)", name);
    return field;
}

/* Like getRequiredField, but also require a 32-bit integer class so that
 * mxGetData(field) may be read through an int pointer.
 * NOTE(review): assumes int is 32 bits, as the original (int*) casts did. */
static const mxArray *getIntField(const mxArray *data, const char *name)
{
    const mxArray *field = getRequiredField(data, name);
    if (!mxIsInt32(field) && !mxIsUint32(field))
        mexErrMsgIdAndTxt("crfChain_testErrC:badType",
                          "data.%s must be an int32 array: crfChain_testErrC(wv,data)", name);
    return field;
}

/* Compare a decoded (0-based) label sequence with the 1-based true labels of
 * one sentence.  errs[0] is incremented once per mislabelled token and
 * errs[1] once if the sentence contains at least one mislabelled token. */
static void countSentenceErrors(const int *yHat, const int *yTrue, int nNodes, double *errs)
{
    int n, sentenceError = 0;
    for (n = 0; n < nNodes; n++) {
        if (yHat[n] != yTrue[n] - 1) {
            errs[0] += 1;
            sentenceError = 1;
        }
    }
    if (sentenceError)
        errs[1] += 1;
}

/*
 * crfChain_testErrC(wv, data): error rates of a chain CRF on labelled sentences.
 *
 * Inputs
 *   prhs[0]  wv    real double vector laid out as [w(:); v_start; v_end; v(:)]
 *                  with w nFeaturesTotal x nStates, v_start and v_end of length
 *                  nStates, and v nStates x nStates (edge weights).
 *   prhs[1]  data  struct with fields
 *              X                 int32, nWords x nFeatureTypes; X(i,f) is the
 *                                1-based value of feature type f at word i,
 *                                0 meaning "no feature"
 *              y                 int32, 1-based true label of each word
 *              nStates           number of labels
 *              sentences         int32, nSentences x 2, 1-based [first last]
 *                                word index of each sentence
 *              maxSentenceLength size of the per-sentence work buffers; every
 *                                sentence must fit in it
 *              featureStart      int32, nFeatureTypes+1 entries; 1-based row of
 *                                w where each feature type's block starts (the
 *                                last entry minus 1 is nFeaturesTotal)
 *              nFeatures         only its length (nFeatureTypes) is used here
 *
 * Output
 *   plhs[0]  1x4 double: [token error, sentence error] of per-token marginal
 *            decoding (argmax of forward-backward beliefs), then
 *            [token error, sentence error] of Viterbi decoding.  Token counts are
 *            divided by the total number of words in all sentences and sentence
 *            counts by nSentences; all four are 0 when there are no sentences.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    int n, s, state, state1, state2, nStates, nNodes, *sentences, nSentences, maxSentenceLength;
    int f, feature, featureInd, nFeatureTypes, nFeaturesTotal, *featureStart;
    int *X, *y, nWords, nXCols, nLabels, start, best;
    int *yHat, *maximizers;
    double *nErrs, nTokens = 0, *wv, *nodePot, *edgePot, *alpha, *beta, *w, *v_start, *v_end, *v, Z;
    double bestVal, score, scale;
    const mxArray *mxA;
    size_t nParams;
    
    /* Validate arguments: every index below is used unchecked afterwards, and a
     * bad one would read or write outside MATLAB-owned buffers. */
    if (nrhs < 2)
        mexErrMsgTxt("Not enough input arguments: crfChain_testErrC(wv,data)");
    if (!mxIsStruct(prhs[1]))
        mexErrMsgTxt("data argument needs to be a struct: crfChain_testErrC(wv,data)");
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]))
        mexErrMsgTxt("wv argument needs to be a real double vector: crfChain_testErrC(wv,data)");
    
    /* Process data struct (integer arrays are read with mxGetData, not mxGetPr,
     * which is meant for double data) */
    mxA = getIntField(prhs[1], "X");
    X = (int*)mxGetData(mxA);
    nWords = (int)mxGetM(mxA);
    nXCols = (int)mxGetN(mxA);
    
    mxA = getIntField(prhs[1], "y");
    y = (int*)mxGetData(mxA);
    nLabels = (int)mxGetNumberOfElements(mxA);
    
    nStates = (int)mxGetScalar(getRequiredField(prhs[1], "nStates"));
    if (nStates < 1)
        mexErrMsgTxt("data.nStates must be at least 1");
    
    mxA = getIntField(prhs[1], "sentences");
    sentences = (int*)mxGetData(mxA);
    nSentences = (int)mxGetM(mxA);
    if (nSentences > 0 && mxGetN(mxA) < 2)
        mexErrMsgTxt("data.sentences must be an nSentences-by-2 array");
    
    maxSentenceLength = (int)mxGetScalar(getRequiredField(prhs[1], "maxSentenceLength"));
    if (maxSentenceLength < 0)
        mexErrMsgTxt("data.maxSentenceLength must be nonnegative");
    
    mxA = getRequiredField(prhs[1], "nFeatures");
    nFeatureTypes = maxVal((int)mxGetM(mxA), (int)mxGetN(mxA));
    if (nWords > 0 && nXCols < nFeatureTypes)
        mexErrMsgTxt("data.X must have one column per feature type");
    
    mxA = getIntField(prhs[1], "featureStart");
    featureStart = (int*)mxGetData(mxA);
    if (mxGetNumberOfElements(mxA) < (size_t)nFeatureTypes + 1)
        mexErrMsgTxt("data.featureStart must have numel(data.nFeatures)+1 entries");
    nFeaturesTotal = featureStart[nFeatureTypes]-1;
    if (nFeaturesTotal < 0)
        mexErrMsgTxt("data.featureStart must contain 1-based offsets");
    
    /* wv must hold w, v_start, v_end and v back to back */
    nParams = ((size_t)nFeaturesTotal + 2 + (size_t)nStates) * (size_t)nStates;
    if (mxGetNumberOfElements(prhs[0]) < nParams)
        mexErrMsgTxt("wv is too short for data.featureStart and data.nStates");
    
    /* Per-sentence work buffers, sized for the longest allowed sentence */
    nodePot = mxCalloc(maxSentenceLength*nStates, sizeof(double));
    edgePot = mxCalloc(nStates*nStates, sizeof(double));
    alpha = mxCalloc(maxSentenceLength*nStates, sizeof(double));
    beta = mxCalloc(maxSentenceLength*nStates, sizeof(double));
    maximizers = mxCalloc(maxSentenceLength*nStates, sizeof(int));
    yHat = mxCalloc(maxSentenceLength, sizeof(int));
    
    wv = mxGetPr(prhs[0]);
    w = wv;
    v_start = &wv[nFeaturesTotal*nStates];
    v_end = v_start + nStates;
    v = v_end + nStates;
    
    plhs[0] = mxCreateDoubleMatrix(1, 4, mxREAL);
    nErrs = mxGetPr(plhs[0]);
    
    /* Edge Potentials are constant */
    for(all_state2) {
        for(all_state1)
            edgePot(state1, state2) = exp(v(state1, state2));
    }
    
    for(s = 0; s < nSentences; s++) {
        start = sentences[s]-1;   /* 0-based index of the sentence's first word */
        nNodes = sentences[s + nSentences]-sentences[s]+1;
        /* Guard the work buffers (sized by maxSentenceLength) and X/y */
        if (start < 0 || nNodes < 1 || start + nNodes > nWords || start + nNodes > nLabels ||
            nNodes > maxSentenceLength)
            mexErrMsgTxt("data.sentences has a sentence outside data.X/data.y or longer than data.maxSentenceLength");
        nTokens += nNodes;
        
        /* Compute log node potentials from the active features of each word */
        for(all_n) {
            for(all_state)
                nodePot(n, state) = 0;
            for(all_features) {
                feature = X[start + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    if (featureInd < 0 || featureInd >= nFeaturesTotal)
                        mexErrMsgTxt("data.X contains a feature value outside the range of w");
                    for(all_state)
                        nodePot(n, state) += w[featureInd + nFeaturesTotal*state];
                }
            }
        }
        /* Add beginning/end of sentence modification */
        for(all_state) {
            nodePot(0, state) += v_start[state];
            nodePot(nNodes-1, state) += v_end[state];
        }
        
        /* Exponentiate Potentials */
        for(all_n) {
            for(all_state)
                nodePot(n, state) = exp(nodePot(n, state));
        }
        
        /* ---- Per-token marginal decoding (sum-product forward-backward) ---- */
        
        /* Forward pass, rescaling each position to sum to 1 */
        scale = 0;
        for(all_state) {
            alpha(0, state) = nodePot(0, state);
            scale += alpha(0, state);
        }
        for(all_state)
            alpha(0, state) /= scale;
        for(n = 1; n < nNodes; n++) {
            scale = 0;
            for(all_state2) {
                score = 0;
                for(all_state1)
                    score += alpha(n-1, state1) * edgePot(state1, state2);
                alpha(n, state2) = nodePot(n, state2) * score;
                scale += alpha(n, state2);
            }
            for(all_state)
                alpha(n, state) /= scale;
        }
        
        /* Backward pass, also rescaled per position */
        for(all_state)
            beta(nNodes-1, state) = 1;
        for(n = nNodes-2; n >= 0; n--) {
            Z = 0;
            for(all_state1) {
                score = 0;
                for(all_state2)
                    score += nodePot(n+1, state2)*edgePot(state1, state2)*beta(n+1, state2);
                beta(n, state1) = score;
                Z += score;
            }
            for(all_state)
                beta(n, state) /= Z;
        }
        
        /* Label each token with the argmax of its (unnormalized) belief */
        for(all_n) {
            best = 0;
            bestVal = alpha(n, 0)*beta(n, 0);
            for(state = 1; state < nStates; state++) {
                if (alpha(n, state)*beta(n, state) > bestVal) {
                    bestVal = alpha(n, state)*beta(n, state);
                    best = state;
                }
            }
            yHat[n] = best;
        }
        countSentenceErrors(yHat, &y[start], nNodes, &nErrs[0]);
        
        /* ---- Viterbi decoding (max-product with backpointers) ---- */
        
        scale = 0;
        for(all_state) {
            alpha(0, state) = nodePot(0, state);
            scale += alpha(0, state);
        }
        for(all_state)
            alpha(0, state) /= scale;
        
        for(n = 1; n < nNodes; n++) {
            /* Reset per position: a running sum over all previous positions
             * makes the scores shrink geometrically and underflow on long
             * sentences. */
            scale = 0;
            for(all_state2) {
                /* Start the argmax from state 0 (not from a threshold of 0) so
                 * a backpointer is always written, even if every score is 0. */
                best = 0;
                bestVal = alpha(n-1, 0)*edgePot(0, state2);
                for(state1 = 1; state1 < nStates; state1++) {
                    if (alpha(n-1, state1)*edgePot(state1, state2) > bestVal) {
                        bestVal = alpha(n-1, state1)*edgePot(state1, state2);
                        best = state1;
                    }
                }
                maximizers[n + nNodes*state2] = best;
                alpha(n, state2) = nodePot(n, state2)*bestVal;
                scale += alpha(n, state2);
            }
            for(all_state)
                alpha(n, state) /= scale;
        }
        
        /* Best final state, then follow the backpointers */
        best = 0;
        bestVal = alpha(nNodes-1, 0);
        for(state = 1; state < nStates; state++) {
            if (alpha(nNodes-1, state) > bestVal) {
                bestVal = alpha(nNodes-1, state);
                best = state;
            }
        }
        yHat[nNodes-1] = best;
        for(n = nNodes-2; n >= 0; n--)
            yHat[n] = maximizers[n+1 + nNodes*yHat[n+1]];
        
        countSentenceErrors(yHat, &y[start], nNodes, &nErrs[2]);
    }
    
    /* Convert counts to rates; every sentence has at least one token, so
     * nTokens > 0 whenever nSentences > 0 (avoids 0/0 on empty input). */
    if (nSentences > 0) {
        nErrs[0] /= nTokens;
        nErrs[1] /= nSentences;
        nErrs[2] /= nTokens;
        nErrs[3] /= nSentences;
    }
    
    mxFree(nodePot);
    mxFree(edgePot);
    mxFree(alpha);
    mxFree(beta);
    mxFree(maximizers);
    mxFree(yHat);
}