#include <math.h>
#include "mex.h"

/* Loop headers: write `for(all_n)` to iterate n over 0..nNodes-1, and likewise
 * state/state1/state2 over 0..nStates-1 and f over 0..nFeatureTypes-1.
 * They expand to bare for-clauses, so they are only usable inside `for(...)`
 * and rely on the loop variable and its bound being in scope. */
#define all_n n = 0; n < nNodes; n++
#define all_state state = 0; state < nStates; state++
#define all_state1 state1 = 0; state1 < nStates; state1++
#define all_state2 state2 = 0; state2 < nStates; state2++
#define all_features f = 0; f < nFeatureTypes; f++
/* Column-major accessors that expand to lvalues of the same-named array.
 * Node arrays are indexed (node, state) with nNodes rows; edge arrays are
 * indexed (state1, state2, edge) with nStates x nStates slices; tmp and v are
 * nStates x nStates. They rely on nNodes / nStates being in scope. Every
 * argument is parenthesized so compound index expressions (e.g. `c ? i : j`)
 * are evaluated as a unit. */
#define nodePot(a,b) nodePot[(a) + nNodes*(b)]
#define edgePot(a,b,c) edgePot[(a) + nStates*((b) + nStates*(c))]
#define nodeBel(a,b) nodeBel[(a) + nNodes*(b)]
#define edgeBel(a,b,c) edgeBel[(a) + nStates*((b) + nStates*(c))]
#define alpha(a,b) alpha[(a) + nNodes*(b)]
#define beta(a,b) beta[(a) + nNodes*(b)]
#define tmp(a,b) tmp[(a) + nStates*(b)]
#define v(a,b) v[(a) + nStates*(b)]

/* Larger of two ints; when they are equal, x is returned. */
int maxVal(int x,int y) {
    return (x < y) ? y : x;
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    int n,s,state,state1,state2,nStates,nNodes,*sentences,nSentences,maxSentenceLength;
    int f,feature,featureInd,*nFeatures,nFeatureTypes,nFeaturesTotal,*featureStart,nParams;
    int *X,*y,nWords;
    int iter,maxBacktrack,*sVals,update,maxUpdates;
    double *w,*v_s,*v_e,*v,*nodePot,*edgePot,*nodeBel,*edgeBel,*alpha,*beta,*tmp,*kappa,Z;
    double lambda,entropy,tmp2,tmp3,*dualObj;
    double *thetaNode,*thetaEdge,*bta,dualObjNew,*nodeBels,*edgeBels;
    double *eta,etaS,*thetaN,*thetaE,btaS,*evals,*sumbta,*nrm2,nrm2new;
    double *deltaW,*deltaV;
    mwSize dims[2];
    mxArray *mxA;
    
    if (nrhs != 14)
        mexErrMsgTxt("14 arguments are needed: {w,thetaNode,thetaEdge,data,lambda,eta,nodeBels,edgeBels,bta,evals,dualObj,i,sumbta,nrm2}");
    
    /* Process optimization arguments */
    w = mxGetPr(prhs[0]);
    lambda = mxGetScalar(prhs[4]);
    eta = mxGetPr(prhs[5]);
    bta = mxGetPr(prhs[8]);
    evals = (double*)mxGetPr(prhs[9]);
    dualObj = mxGetPr(prhs[10]);
    sVals = (int*)mxGetPr(prhs[11]);
    sumbta = mxGetPr(prhs[12]);
    nrm2 = mxGetPr(prhs[13]);
    
    /* ******************************
     * Process data struct
     ******************************** */
    if (!mxIsStruct(prhs[3]))
        mexErrMsgTxt("data argument needs to be a struct");
    
    mxA = mxGetField(prhs[3],0,"X");
    X = (int*)mxGetPr(mxA);
    nWords = mxGetM(mxA);
    
    y = (int*)mxGetPr(mxGetField(prhs[3],0,"y"));
    
    nStates = (int)mxGetScalar(mxGetField(prhs[3],0,"nStates"));
    
    mxA = mxGetField(prhs[3],0,"sentences");
    sentences = (int*)mxGetPr(mxA);
    nSentences = mxGetM(mxA);
    
    maxSentenceLength = (int)mxGetScalar(mxGetField(prhs[3],0,"maxSentenceLength"));
    
    featureStart = (int*)mxGetPr(mxGetField(prhs[3],0,"featureStart"));
    
    mxA = mxGetField(prhs[3],0,"nFeatures");
    nFeatures = (int*)mxGetPr(mxA);
    nFeatureTypes = maxVal(mxGetM(mxA),mxGetN(mxA));
    nFeaturesTotal = featureStart[nFeatureTypes]-1;
    nParams = nFeaturesTotal*nStates + nStates + nStates + nStates*nStates;
    
    /* ******************************
     * Initialize intermediate values
     ******************************** */
    nodePot = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgePot = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    nodeBel = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgeBel = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    alpha = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    beta = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    tmp = mxCalloc(nStates*nStates,sizeof(double));
    kappa = mxCalloc(maxSentenceLength,sizeof(double));
    thetaN = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    thetaE = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    deltaW = mxCalloc(nFeaturesTotal*nStates,sizeof(double));
    deltaV = mxCalloc(nStates*nStates,sizeof(double));
    
    v_s = &w[nFeaturesTotal*nStates];
    v_e = &w[nFeaturesTotal*nStates + nStates];
    v = &w[nFeaturesTotal*nStates + nStates + nStates];
    
    maxUpdates = maxVal(mxGetM(prhs[11]),mxGetN(prhs[11]));
    
    for(update = 0; update < maxUpdates;update++) {
        s = sVals[update]-1;
        if(eta[s] == 0.5)
            maxBacktrack = 2;
        else
            maxBacktrack = 5;
        etaS = eta[s];
        
        nNodes = sentences[s + nSentences]-sentences[s]+1;
        mxA = mxGetCell(prhs[1],s);
        thetaNode = mxGetPr(mxA);
        mxA = mxGetCell(prhs[2],s);
        thetaEdge = mxGetPr(mxA);
        mxA = mxGetCell(prhs[6],s);
        nodeBels = mxGetPr(mxA);
        mxA = mxGetCell(prhs[7],s);
        edgeBels = mxGetPr(mxA);
        
        iter = 0;
        while (1) {
            /* Compute new value of theta */
            for(all_n) {
                for(all_state) {
                    thetaN[n + nNodes*state] = (1-etaS)*thetaNode[n + nNodes*state];
                }
                for(f = 0; f < nFeatureTypes; f++) {
                    feature = X[sentences[s]-1 + n + nWords*f];
                    if(feature != 0) {
                        featureInd = featureStart[f]-1 + feature-1;
                        for(all_state) {
                            thetaN[n + nNodes*state] += (etaS/lambda)*w[featureInd + nFeaturesTotal*state];
                        }
                    }
                }
            }
            for(all_state) {
                thetaN[nNodes*state] += (etaS/lambda)*v_s[state];
                thetaN[nNodes-1 + nNodes*state] += (etaS/lambda)*v_e[state];
            }
            for(n = 0;n < nNodes-1;n++) {
                for(all_state1) {
                    for(all_state2) {
                        thetaE[state1 + nStates*(state2 + nStates*n)] = (1-etaS)*thetaEdge[state1 + nStates*(state2 + nStates*n)] + (etaS/lambda)*v(state1,state2);
                    }
                }
            }
            
            /* Exponentiate new theta values to make potentials */
            for(all_n) {
                for(all_state) {
                    nodePot(n,state) = exp(thetaN[n + nNodes*state]);
                }
            }
            for(n = 0;n < nNodes-1;n++) {
                for(all_state2) {
                    for(all_state1) {
                        edgePot(state1,state2,n) = exp(thetaE[state1 + nStates*(state2 + nStates*n)]);
                    }
                }
            }
            
            /* Forward-Backward to Compute Marginals */
            for(all_n)
                kappa[n] = 0;
            for(all_state) {
                alpha(0,state) = nodePot(0,state);
                kappa[0] += alpha(0,state);
            }
            for(all_state)
                alpha(0,state) /= kappa[0];
            /* Forward Pass */
            for(n = 1; n < nNodes; n++) {
                for(all_state1) {
                    for(all_state2)
                        tmp(state1,state2) = alpha(n-1,state1) * edgePot(state1,state2,n-1);
                }
                for(all_state2) {
                    alpha(n,state2) = 0;
                    for(all_state1)
                        alpha(n,state2) += tmp(state1,state2);
                    alpha(n,state2) = nodePot(n,state2) * alpha(n,state2);
                    kappa[n] += alpha(n,state2);
                }
                for(all_state)
                    alpha(n,state) /= kappa[n];
            }
            /* Backward Pass */
            for(all_state)
                beta(nNodes-1,state) = 1;
            for(n = nNodes-2; n >= 0; n--) {
                for(all_state1) {
                    for(all_state2)
                        tmp(state1,state2) = nodePot(n+1,state2)*edgePot(state1,state2,n)*beta(n+1,state2);
                }
                Z = 0;
                for(all_state1) {
                    beta(n,state1) = 0;
                    for(all_state2)
                        beta(n,state1) += tmp(state1,state2);
                    Z += beta(n,state1);
                }
                for(all_state)
                    beta(n,state) /= Z;
            }
            
            /* Node Beliefs */
            for(all_n) {
                Z = 0;
                for(all_state) {
                    nodeBel(n,state) = alpha(n,state)*beta(n,state);
                    Z += nodeBel(n,state);
                }
                for(all_state)
                    nodeBel(n,state) /= Z;
            }
            /* Edge Beliefs */
            for(n = 0; n < nNodes-1; n++) {
                Z = 0;
                for(all_state1) {
                    for(all_state2) {
                        tmp(state1,state2) = alpha(n,state1)*nodePot(n+1,state2)*beta(n+1,state2)*edgePot(state1,state2,n);
                        Z += tmp(state1,state2);
                    }
                }
                for(all_state1) {
                    for(all_state2) {
                        edgeBel(state1,state2,n) = tmp(state1,state2)/Z;
                    }
                }
            }
            *evals += 1;
            
            /* Compute norm */
            nrm2new = *nrm2;
            for(all_n) {
                for(f = 0; f < nFeatureTypes; f++) {
                    feature = X[sentences[s]-1 + n + nWords*f];
                    if(feature != 0) {
                        featureInd = featureStart[f]-1 + feature-1;
                        for(all_state) {
                            tmp2 = w[featureInd + nFeaturesTotal*state] + deltaW[featureInd + nFeaturesTotal*state];
                            nrm2new -= tmp2*tmp2;
                        }
                        for(all_state) {
                            deltaW[featureInd + nFeaturesTotal*state] += nodeBels[n + nNodes*state] - nodeBel(n,state);
                        }
                        for(all_state) {
                            tmp2 = w[featureInd + nFeaturesTotal*state] + deltaW[featureInd + nFeaturesTotal*state];
                            nrm2new += tmp2*tmp2;
                        }
                    }
                }
            }
            /* Start/end parameters */
            for(all_state) {
                tmp2 = v_s[state];
                tmp3 = tmp2 + nodeBels[0 + nNodes*state] - nodeBel(0,state);
                nrm2new += tmp3*tmp3 - tmp2*tmp2;
                tmp2 = v_e[state];
                tmp3 = tmp2 + nodeBels[nNodes-1 + nNodes*state] - nodeBel(nNodes-1,state);
                nrm2new += tmp3*tmp3-tmp2*tmp2;
            }
            /* Edge parameters */
            for(n = 0; n < nNodes-1; n++)
            {
                for(all_state1) {
                    for(all_state2) {
                        deltaV[state1 + nStates*state2] += edgeBels[state1 + nStates*(state2 + nStates*n)] - edgeBel(state1,state2,n);
                    }
                }
            }
            for(all_state1) {
                for(all_state2) {
                    tmp2 = v(state1,state2);
                    tmp3 = tmp2 + deltaV[state1 + nStates*state2];
                    nrm2new += tmp3*tmp3-tmp2*tmp2;
                    deltaV[state1+nStates*state2] = 0;
                }
            }
            
            /* Reset deltaW variables */
            for(all_n) {
                for(f = 0; f < nFeatureTypes; f++) {
                    feature = X[sentences[s]-1 + n + nWords*f];
                    if(feature != 0) {
                        featureInd = featureStart[f]-1 + feature-1;
                        for(all_state)
                            deltaW[featureInd + nFeaturesTotal*state] = 0;
                    }
                }
            }
            
            /* Compute entropy */
            entropy = 0;
            for(all_n) {
                for(all_state) {
                    entropy -= nodeBel(n,state)*log(nodeBel(n,state));
                }
            }
            for(n = 0;n < nNodes-1;n++) {
                for(all_state1) {
                    for(all_state2) {
                        tmp2 = log(edgeBel(state1,state2,n)) - log(nodeBel(n,state1)) - log(nodeBel(n+1,state2));
                        entropy -= edgeBel(state1,state2,n)*tmp2;
                    }
                }
            }
            btaS = -entropy;
            dualObjNew = *sumbta - bta[s] + btaS + nrm2new/(2.0*lambda);
            
            if (dualObjNew <= *dualObj) {
                for(all_n) {
                    for(f = 0; f < nFeatureTypes; f++) {
                        feature = X[sentences[s]-1 + n + nWords*f];
                        if(feature != 0) {
                            featureInd = featureStart[f]-1 + feature-1;
                            for(all_state) {
                                w[featureInd + nFeaturesTotal*state] += nodeBels[n + nNodes*state] - nodeBel(n,state);
                            }
                        }
                    }
                }
                /* Start/end parameters */
                for(all_state) {
                    v_s[state] += nodeBels[0 + nNodes*state] - nodeBel(0,state);
                    v_e[state] += nodeBels[nNodes-1 + nNodes*state] - nodeBel(nNodes-1,state);
                }
                /* Edge parameters */
                for(n = 0; n < nNodes-1; n++)
                {
                    for(all_state1) {
                        for(all_state2) {
                            v(state1,state2) += edgeBels[state1 + nStates*(state2 + nStates*n)] - edgeBel(state1,state2,n);
                        }
                    }
                }
                eta[s] = etaS;
                *sumbta = *sumbta - bta[s] + btaS;
                bta[s] = btaS;
                *dualObj = dualObjNew;
                for(all_n) {
                    for(all_state) {
                        thetaNode[n+nNodes*state] = thetaN[n+nNodes*state];
                        nodeBels[n+nNodes*state] = nodeBel(n,state);
                    }
                }
                for(n = 0;n < nNodes-1;n++) {
                    for(all_state2) {
                        for(all_state1) {
                            thetaEdge[state1+nStates*(state2+nStates*n)] = thetaE[state1+nStates*(state2+nStates*n)];
                            edgeBels[state1+nStates*(state2+nStates*n)] = edgeBel(state1,state2,n);
                        }
                    }
                }
                *nrm2 = nrm2new;
                break;
            }
            else {
                iter++;
                if (iter <= maxBacktrack || mxIsNaN(dualObjNew) || mxIsInf(dualObjNew)) {
                    /*printf("Backtracking, eta(%d) = %.5e\n",s+1,etaS);*/
                    etaS = etaS/2;
                }
                else {
                    printf("Line Search Failed (i = %d)\n",s+1);
                    eta[s] = etaS;
                    break;
                }
            }
        }
        eta[s] *= 1.05;
    }
    
    
    /* ****************
     * Free memory
     ****************** */
    mxFree(nodePot);
    mxFree(edgePot);
    mxFree(alpha);
    mxFree(beta);
    mxFree(tmp);
    mxFree(nodeBel);
    mxFree(edgeBel);
    mxFree(kappa);
    mxFree(thetaN);
    mxFree(thetaE);
    mxFree(deltaW);
    mxFree(deltaV);
    
    if (nlhs == 4) {
        /* Return relevant quantities for re-calling the function */
        dims[0]=1;
        dims[1]=1;
        tmp2 = *evals;
        plhs[0] = mxCreateNumericArray(2,dims,mxDOUBLE_CLASS,mxREAL);
        evals = mxGetPr(plhs[0]);
        *evals = tmp2;
        tmp2 = *dualObj;
        plhs[1] = mxCreateNumericArray(2,dims,mxDOUBLE_CLASS,mxREAL);
        dualObj = mxGetPr(plhs[1]);
        *dualObj = tmp2;
        tmp2 = *sumbta;
        plhs[2] = mxCreateNumericArray(2,dims,mxDOUBLE_CLASS,mxREAL);
        sumbta = mxGetPr(plhs[2]);
        *sumbta = tmp2;
        tmp2 = *nrm2;
        plhs[3] = mxCreateNumericArray(2,dims,mxDOUBLE_CLASS,mxREAL);
        nrm2 = mxGetPr(plhs[3]);
        *nrm2 = tmp2;
    }
}