#include <math.h>
#include "mex.h"

/* Loop-header shorthands, e.g. for(all_n) means for(n = 0; n < nNodes; n++).
 * They expand to code that refers to the locals n, state, state1, state2, f,
 * nNodes, nStates and nFeatureTypes, which must be in scope where used. */
#define all_n n = 0; n < nNodes; n++
#define all_state state = 0; state < nStates; state++
#define all_state1 state1 = 0; state1 < nStates; state1++
#define all_state2 state2 = 0; state2 < nStates; state2++
#define all_features f = 0; f < nFeatureTypes; f++

/* Column-major (MATLAB-order) indexing into flat arrays:
 *   nodePot, nodeBel, alpha, beta : nNodes x nStates
 *   edgePot, edgeBel              : nStates x nStates x (edge index)
 *   tmp, v                        : nStates x nStates
 * The strides come from whatever variables named nNodes / nStates are in
 * scope at the use site. Being function-like, these macros only rewrite
 * `name(...)`; a bare `name` still denotes the underlying pointer.
 * Every argument is parenthesized so that expressions such as `c ? i : j`
 * index the intended element. */
#define nodePot(a,b) nodePot[(a) + nNodes*(b)]
#define edgePot(a,b,c) edgePot[(a) + nStates*((b) + nStates*(c))]
#define nodeBel(a,b) nodeBel[(a) + nNodes*(b)]
#define edgeBel(a,b,c) edgeBel[(a) + nStates*((b) + nStates*(c))]
#define alpha(a,b) alpha[(a) + nNodes*(b)]
#define beta(a,b) beta[(a) + nNodes*(b)]
#define tmp(a,b) tmp[(a) + nStates*(b)]
#define v(a,b) v[(a) + nStates*(b)]

/* Return the larger of x and y (x when they compare equal). */
int maxVal(int x,int y) {
    return (y > x) ? y : x;
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    int n,s,state,state1,state2,nStates,nNodes,*sentences,nSentences,maxSentenceLength;
    int f,feature,featureInd,*nFeatures,nFeatureTypes,nFeaturesTotal,*featureStart,nParams;
    int *X,*y,nWords;
    double *w,*v_start,*v_end,*v,*nodePot,*edgePot,*nodeBel,*edgeBel,*alpha,*beta,*tmp,*kappa,Z;
    double lambda,entropy,tmp2,*dualObj;
    double *thetaNode,*thetaEdge,*bta;
    mxArray *mxA;
    
    if (nrhs != 6)
        mexErrMsgTxt("6 arguments are needed: {thetaNode,thetaEdge,data,lambda,nodeBels,edgeBels}");
    
    /* Process optimization arguments */
    lambda = mxGetScalar(prhs[3]);
    
    /* ******************************
     * Process data struct 
     ******************************** */
    if (!mxIsStruct(prhs[2]))
        mexErrMsgTxt("data argument needs to be a struct");
    
    mxA = mxGetField(prhs[2],0,"X");
    X = (int*)mxGetPr(mxA);
    nWords = mxGetM(mxA);
    
    y = (int*)mxGetPr(mxGetField(prhs[2],0,"y"));
    
    nStates = (int)mxGetScalar(mxGetField(prhs[2],0,"nStates"));
    
    mxA = mxGetField(prhs[2],0,"sentences");
    sentences = (int*)mxGetPr(mxA);
    nSentences = mxGetM(mxA);
    
    maxSentenceLength = (int)mxGetScalar(mxGetField(prhs[2],0,"maxSentenceLength"));
    
    featureStart = (int*)mxGetPr(mxGetField(prhs[2],0,"featureStart"));
    
    mxA = mxGetField(prhs[2],0,"nFeatures");
    nFeatures = (int*)mxGetPr(mxA);
    nFeatureTypes = maxVal(mxGetM(mxA),mxGetN(mxA));
    nFeaturesTotal = featureStart[nFeatureTypes]-1;
    nParams = nFeaturesTotal*nStates + nStates + nStates + nStates*nStates;
    
    /* ******************************
     * Initialize intermediate values
     ******************************** */
    nodePot = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgePot = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    nodeBel = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgeBel = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    alpha = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    beta = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    tmp = mxCalloc(nStates*nStates,sizeof(double));
    kappa = mxCalloc(maxSentenceLength,sizeof(double));
    
    /* *******************
     * Output
     * ******************* */
    
    plhs[0] = mxCreateDoubleMatrix(nParams,1,mxREAL);
    w = mxGetPr(plhs[0]);
    v_start = &w[nFeaturesTotal*nStates];
    v_end = &w[nFeaturesTotal*nStates + nStates];
    v = &w[nFeaturesTotal*nStates + nStates + nStates];
    
    plhs[1] = mxCreateDoubleMatrix(1,1,mxREAL);
    dualObj = mxGetPr(plhs[1]);
    
    plhs[2] = mxCreateDoubleMatrix(nSentences,1,mxREAL);
    bta = mxGetPr(plhs[2]);
    
    
    for(s = 0; s < nSentences; s++) {
         nNodes = sentences[s + nSentences]-sentences[s]+1;
         
         /* Exponentiate theta to make potentials */
         for(all_n) {
             mxA = mxGetCell(prhs[0],s);
             thetaNode = mxGetPr(mxA);
             for(all_state) {
                 nodePot(n,state) = exp(thetaNode[n + nNodes*state]);
             }
         }
         for(n = 0;n < nNodes-1;n++) {
            mxA = mxGetCell(prhs[1],s);
            thetaEdge = mxGetPr(mxA);
            for(all_state2) {
                for(all_state1) {
                    edgePot(state1,state2,n) = exp(thetaEdge[state1 + nStates*(state2 + nStates*n)]);
                }
            }
         }
         
         /* Forward-Backward to Compute Marginals */
        for(all_n)
            kappa[n] = 0;
        for(all_state) {
            alpha(0,state) = nodePot(0,state);
            kappa[0] += alpha(0,state);
        }
        for(all_state)
            alpha(0,state) /= kappa[0];
        /* Forward Pass */
        for(n = 1; n < nNodes; n++) {
            for(all_state1) {
                for(all_state2)
                    tmp(state1,state2) = alpha(n-1,state1) * edgePot(state1,state2,n-1);
            }
            for(all_state2) {
                alpha(n,state2) = 0;
                for(all_state1)
                    alpha(n,state2) += tmp(state1,state2);
                alpha(n,state2) = nodePot(n,state2) * alpha(n,state2);
                kappa[n] += alpha(n,state2);
            }
            for(all_state)
                alpha(n,state) /= kappa[n];
        }
        /* Backward Pass */
       for(all_state)
            beta(nNodes-1,state) = 1;
        for(n = nNodes-2; n >= 0; n--) {
            for(all_state1) {
                for(all_state2)
                    tmp(state1,state2) = nodePot(n+1,state2)*edgePot(state1,state2,n)*beta(n+1,state2);
            }
            Z = 0;
            for(all_state1) {
                beta(n,state1) = 0;
                for(all_state2)
                    beta(n,state1) += tmp(state1,state2);
                Z += beta(n,state1);
            }
            for(all_state)
                beta(n,state) /= Z;
        }
         
         /* Node Beliefs */
        for(all_n) {
            Z = 0;
            for(all_state) {
                nodeBel(n,state) = alpha(n,state)*beta(n,state);
                Z += nodeBel(n,state);
            }
            for(all_state)
                nodeBel(n,state) /= Z;
        }         
        /* Edge Beliefs */
        for(n = 0; n < nNodes-1; n++) {
            Z = 0;
            for(all_state1) {
                for(all_state2) {
                    tmp(state1,state2) = alpha(n,state1)*nodePot(n+1,state2)*beta(n+1,state2)*edgePot(state1,state2,n);
                    Z += tmp(state1,state2);
                }
            }
            for(all_state1) {
                for(all_state2) {
                    edgeBel(state1,state2,n) = tmp(state1,state2)/Z;
                }
            }
        }
         
         /* Update w */
         for(all_n) {
            for(f = 0; f < nFeatureTypes; f++) {
                feature = X[sentences[s]-1 + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    for(all_state) {
                        w[featureInd + nFeaturesTotal*state] -= nodeBel(n,state);
                    }
                    state = y[sentences[s]-1 +n]-1;
                    w[featureInd + nFeaturesTotal*state] += 1;
                }
            }
        }
         /* Start/end parameters */
        for(all_state) {
            v_start[state] -= nodeBel(0,state);
            v_end[state] -= nodeBel(nNodes-1,state);
        }
        v_start[y[sentences[s]-1]-1] += 1;
        v_end[y[sentences[s]-1 + nNodes-1]-1] += 1;
        /* Edge parameters */
        for(n = 0; n < nNodes-1; n++)
        {
            for(all_state1) {
                for(all_state2) {
                    v[state1 + nStates*state2] -= edgeBel(state1,state2,n);
                }
            }
            state1 = y[sentences[s]-1 +n]-1;
            state2 = y[sentences[s]-1 +n+1]-1;
            v[state1 + nStates*state2] += 1;
        }
         
        /* Compute entropy */
        entropy = 0;
        for(all_n) {
            for(all_state) {
                entropy -= nodeBel(n,state)*log(nodeBel(n,state));
            }
        }
        for(n = 0;n < nNodes-1;n++) {
            for(all_state1) {
                for(all_state2) {
                    tmp2 = log(edgeBel(state1,state2,n)) - log(nodeBel(n,state1)) - log(nodeBel(n+1,state2));
                   entropy -= edgeBel(state1,state2,n)*tmp2;
                }
            }
        }
        bta[s] = -entropy;
        
        /* Store marginals */
        for(all_n) {
             mxA = mxGetCell(prhs[4],s);
             thetaNode = mxGetPr(mxA); /* Probably not the best name */
             for(all_state) {
                 thetaNode[n+nNodes*state] = nodeBel(n,state);
             }
         }
         for(n = 0;n < nNodes-1;n++) {
            mxA = mxGetCell(prhs[5],s);
            thetaEdge = mxGetPr(mxA); /* Probably not the best name */
            for(all_state2) {
                for(all_state1) {
                    thetaEdge[state1+nStates*(state2+nStates*n)] = edgeBel(state1,state2,n);
                }
            }
         }
    }
    
    *dualObj = 0;
    for(f = 0;f < nParams;f++)
        *dualObj += w[f]*w[f];
    *dualObj /= (2.0*lambda);
    for(s = 0;s < nSentences;s++)
        *dualObj += bta[s];
    
    /* ****************
     * Free memory
     ****************** */
    mxFree(nodePot);
    mxFree(edgePot);
    mxFree(alpha);
    mxFree(beta);
    mxFree(tmp);
    mxFree(nodeBel);
    mxFree(edgeBel);
    mxFree(kappa);
}