#include <math.h>
#include "mex.h"

/* Loop-header fragments. Each expands to the three clauses of a for statement,
 * so they are written as for(all_n) { ... }. They rely on the loop variable
 * (n, state, state1, state2, f) and the bound (nNodes, nStates, nFeatureTypes)
 * being declared in the enclosing scope. */
#define all_n n = 0; n < nNodes; n++
#define all_state state = 0; state < nStates; state++
#define all_state1 state1 = 0; state1 < nStates; state1++
#define all_state2 state2 = 0; state2 < nStates; state2++
#define all_features f = 0; f < nFeatureTypes; f++
/* Flat-array accessors. Each maps (a,b) or (a,b,c) subscripts onto a
 * same-named pointer with the first subscript varying fastest. The first
 * argument is NOT parenthesized in the expansion, so it must be an expression
 * that is safe to splice in front of "+ ...". */
/* Per-node arrays: leading dimension is the current nNodes */
#define nodePot(a,b) nodePot[a + nNodes*(b)]
#define edgePot(a,b) edgePot[a + nStates*(b)]
#define nodeBel(a,b) nodeBel[a + nNodes*(b)]
#define edgeBel(a,b,c) edgeBel[a + nStates*(b + nStates*(c))]
#define alpha(a,b) alpha[a + nNodes*(b)]
#define beta(a,b) beta[a + nNodes*(b)]
/* nStates-by-nStates arrays */
#define tmp(a,b) tmp[a + nStates*(b)]
#define v(a,b) v[a + nStates*(b)]
#define dV(a,b) dV[a + nStates*(b)]
/* Stored beliefs: NB has leading dimension nWords, EB is nStates x nStates x (third index) */
#define NB(a,b) NB[a + nWords*(b)]
#define EB(a,b,c) EB[a + nStates*(b + nStates*(c))]

/* Returns the larger of two integers (either one when they are equal). */
int maxVal(int x,int y) {
    return (y > x) ? y : x;
}

double logBase2(double x) {
	return log(x)/log(2);
}

/* Fetch a required field of the data struct (element 0). Raises a MATLAB
 * error naming the field instead of handing back NULL when it is absent. */
static mxArray *getDataField(const mxArray *data, const char *name) {
    mxArray *field = mxGetField(data, 0, name);
    if (field == NULL)
        mexErrMsgIdAndTxt("SAG:missingField", "data struct is missing field '%s'", name);
    return field;
}

/*
 * MEX gateway.
 *
 * Inputs (prhs), in order:
 *   0 wv        double parameter vector laid out as
 *               [w (nFeaturesTotal*nStates), v_start (nStates), v_end (nStates),
 *                v (nStates*nStates)]; UPDATED IN PLACE
 *   1 data      struct with fields X, y, nStates, sentences, maxSentenceLength,
 *               featureStart, nFeatures (X, y, sentences, featureStart are read
 *               as 32-bit int arrays holding 1-based indices)
 *   2 lambda    scalar, added to the Lipschitz estimates and used in the
 *               multiplicative shrink of rho
 *   3 randVals  maxIter-by-(>=2) doubles; column 0 chooses the sampling scheme,
 *               column 1 chooses the sentence
 *   4 d         double direction vector, same layout as wv; UPDATED IN PLACE
 *   5 covered   int32, one flag per sentence; UPDATED IN PLACE
 *   6 NB        nWords-by-nStates stored node beliefs; UPDATED IN PLACE
 *   7 EB        nStates-by-nStates-by-nSentences stored (summed) edge beliefs;
 *               UPDATED IN PLACE
 *   8 Lmax      scalar; updated in place only when no outputs are requested
 *   9 Li        one value per sentence; UPDATED IN PLACE
 *
 * Outputs (plhs), all optional:
 *   0 number of forward(-backward) passes performed
 *   1 final value of Lmax
 *
 * NOTE(review): the structure (stored per-example beliefs, aggregated direction
 * d, per-example Lipschitz line search) looks like a stochastic-average-gradient
 * trainer for a chain-structured model -- confirm against the calling script.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    int n,s,state,state1,state2,nStates,nNodes,*sentences,nSentences,maxSentenceLength;
    int f,feature,featureInd,nFeatureTypes,nFeaturesTotal,*featureStart,nParams;
    int *X,*y,nWords;
    int iter,maxIter,nCovered, *lastVisited,  *covered, ind;
    int nextpow2, nLevels, levelMax,level;
    double *wv,*nodePot,*nodeBel,*edgePot,*edgeBel,*alpha,*beta,*tmp,*w,*v_start,*v_end,*v,Z,logZ,*kappa;
    double lambda, stepSize, rho=1.0, step, *cumSum, *d,*dW,*dV_start,*dV_end,*dV,*NB,*EB;
    double eps=1e-8,gg,*gW,*gV_start,*gV_end,*gV,fi,fi_new,*nFwdBwd=NULL;
    double *randVals, *Lmax, *Li, Lmean, Li_old, *LiMatrix,offset,u,z;
    mxArray *mxA,*mxLmax;
    const mwSize *ebDims;
    mwSize ebNDims;
    
    if (nrhs != 10)
        mexErrMsgTxt("10 arguments are needed: {wv,data,lambda,randVals,d,covered,NB,EB,Lmax,Li}");
    
    
    /* Process optimization arguments */
    lambda = mxGetScalar(prhs[2]);
    randVals = mxGetPr(prhs[3]);
    d = mxGetPr(prhs[4]);
    if (!mxIsClass(prhs[5],"int32"))
        mexErrMsgTxt("covered must be int32");
    covered = (int*)mxGetData(prhs[5]);
    NB = mxGetPr(prhs[6]);
    EB = mxGetPr(prhs[7]);
    Lmax = mxGetPr(prhs[8]);
    Li = mxGetPr(prhs[9]);
    maxIter = mxGetM(prhs[3]);
    /* Column 1 of randVals is read as randVals[iter+maxIter] below */
    if (maxIter > 0 && mxGetN(prhs[3]) < 2)
        mexErrMsgTxt("randVals must have at least 2 columns");
    
    /* ******************************
     * Process data struct
     ******************************** */
    if (!mxIsStruct(prhs[1]))
        mexErrMsgTxt("data argument needs to be a struct");
    
    mxA = getDataField(prhs[1],"X");
    X = (int*)mxGetData(mxA);
    nWords = mxGetM(mxA);
    
    y = (int*)mxGetData(getDataField(prhs[1],"y"));
    
    nStates = (int)mxGetScalar(getDataField(prhs[1],"nStates"));
    
    mxA = getDataField(prhs[1],"sentences");
    sentences = (int*)mxGetData(mxA);
    nSentences = mxGetM(mxA);
    if (nSentences < 1)
        mexErrMsgTxt("data.sentences must contain at least one sentence");
    
    maxSentenceLength = (int)mxGetScalar(getDataField(prhs[1],"maxSentenceLength"));
    
    featureStart = (int*)mxGetData(getDataField(prhs[1],"featureStart"));
    
    /* Only the length of nFeatures is used: it gives the number of feature types */
    mxA = getDataField(prhs[1],"nFeatures");
    nFeatureTypes = maxVal(mxGetM(mxA),mxGetN(mxA));
    nFeaturesTotal = featureStart[nFeatureTypes]-1;
    nParams = nFeaturesTotal*nStates + nStates + nStates + nStates*nStates;
    
    /* Error Checking for sizes of NB and EB */
    if (nWords != mxGetM(prhs[6]))
        mexErrMsgTxt("X and NB must have the same number of rows");
    if (nStates != mxGetN(prhs[6]))
        mexErrMsgTxt("NB must have nStates columns");
    /* MATLAB drops trailing singleton dimensions, so with a single sentence an
     * nStates-by-nStates-by-1 array is reported as 2-D; treat its third
     * dimension as 1 rather than rejecting it. */
    ebNDims = mxGetNumberOfDimensions(prhs[7]);
    ebDims = mxGetDimensions(prhs[7]);
    if (ebNDims > 3)
        mexErrMsgTxt("EB must be a 3D array");
    if (nStates != ebDims[0])
        mexErrMsgTxt("EB must have nStates rows");
    if (nStates != ebDims[1])
        mexErrMsgTxt("EB must have nStates columns");
    if (nSentences != (ebNDims == 3 ? ebDims[2] : 1))
        mexErrMsgTxt("EB must be (nStates,nStates,nSentences)");
    
    /* Error checking for the remaining vectors that are indexed below */
    if(maxVal(mxGetM(prhs[0]),mxGetN(prhs[0])) != maxVal(mxGetM(prhs[4]),mxGetN(prhs[4])))
        mexErrMsgTxt("wv and d must have the same length");
    if((int)mxGetNumberOfElements(prhs[0]) < nParams)
        mexErrMsgTxt("wv must have at least nFeaturesTotal*nStates + 2*nStates + nStates*nStates elements");
    if(nSentences != maxVal(mxGetM(prhs[5]),mxGetN(prhs[5])))
        mexErrMsgTxt("covered must have length equal to nSentences");
    if((int)mxGetNumberOfElements(prhs[9]) < nSentences)
        mexErrMsgTxt("Li must have at least nSentences elements");
    if(mxGetNumberOfElements(prhs[8]) < 1)
        mexErrMsgTxt("Lmax must not be empty");
    
    /* ******************************
     * Initialize intermediate values
     ******************************** */
    nodePot = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgePot = mxCalloc(nStates*nStates,sizeof(double));
    nodeBel = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    edgeBel = mxCalloc(maxSentenceLength*nStates*nStates,sizeof(double));
    alpha = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    beta = mxCalloc(maxSentenceLength*nStates,sizeof(double));
    tmp = mxCalloc(nStates*nStates,sizeof(double));
    kappa = mxCalloc(maxSentenceLength,sizeof(double));
    gW = mxCalloc(nFeaturesTotal*nStates,sizeof(double));
    gV_start = mxCalloc(nStates,sizeof(double));
    gV_end = mxCalloc(nStates,sizeof(double));
    gV = mxCalloc(nStates*nStates,sizeof(double));
    
    /* lastVisited[i]: iteration at which w[i] was last brought up to date (0 = never).
     * cumSum[k]: running sum over iterations 0..k of stepSize/(rho*nCovered). */
    lastVisited = mxCalloc(nFeaturesTotal*nStates,sizeof(int));
    cumSum = mxCalloc(maxIter,sizeof(double));
    
    /* Views into the packed parameter vector */
    wv = mxGetPr(prhs[0]);
    w = wv;
    v_start = &wv[nFeaturesTotal*nStates];
    v_end = &wv[nFeaturesTotal*nStates + nStates];
    v = &wv[nFeaturesTotal*nStates + nStates + nStates];
    
    /* Views into the packed direction vector (same layout) */
    dW = d;
    dV_start = &d[nFeaturesTotal*nStates];
    dV_end = &d[nFeaturesTotal*nStates + nStates];
    dV = &d[nFeaturesTotal*nStates + nStates + nStates];
    
    if (nlhs > 0) {
        plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
        nFwdBwd = mxGetPr(plhs[0]);
        *nFwdBwd = 0;
        /* When outputs are requested, Lmax is tracked in a fresh array so the
         * input stays untouched. It is only handed back when a second output
         * was actually requested: plhs[1] does not exist for nlhs == 1.
         * (An array not placed in plhs is reclaimed by MATLAB on return.) */
        mxLmax = mxCreateDoubleMatrix(1,1,mxREAL);
        *mxGetPr(mxLmax) = *Lmax;
        Lmax = mxGetPr(mxLmax);
        if (nlhs > 1)
            plhs[1] = mxLmax;
    }
    
    /* ***********************************************
     * Here is the non-uniform sampling initialization
     * *********************************************** */
    nCovered = 0;
    Lmean = 0;
    for(s = 0;s < nSentences;s++) {
        if (covered[s]!=0) {
            nCovered++;
            Lmean += Li[s];
        }
    }
    if(nCovered > 0)
        Lmean /= nCovered;
    /* Smallest power of two >= nSentences, and the number of levels of the
     * binary sum tree built over it (level 0 = leaves, last level = root).
     * Integer arithmetic avoids any dependence on floating-point rounding. */
    nextpow2 = 1;
    nLevels = 1;
    while (nextpow2 < nSentences) {
        nextpow2 *= 2;
        nLevels++;
    }
    
    LiMatrix = mxCalloc(nextpow2*nLevels,sizeof(double)); /* Sums of Lipschitz constants (including lambda) of loss over descendants */
    
    for(s=0;s<nSentences;s++) {
        if (covered[s]) 
            LiMatrix[s] = Li[s] + lambda;
    }
    levelMax = nextpow2;
    for (level=1;level<nLevels;level++) {
        levelMax = levelMax/2;
        for(s=0;s<levelMax;s++) {
            LiMatrix[s + nextpow2*level] = LiMatrix[2*s + nextpow2*(level-1)] + LiMatrix[2*s+1 + nextpow2*(level-1)];
        }
    }    
    /* ***********************************************
     * End of the non-uniform sampling initialization
     * *********************************************** */
    
    for(iter=0; iter < maxIter; iter++) {
        
        /* *****************************************
         * Non-uniform selection of training example
         * ***************************************** */
        u = randVals[iter+maxIter];
        if(iter ==0 || randVals[iter] < 0.5)
            /* Half the time we just pick a random example */
            s = floor(randVals[iter+maxIter]*nSentences);
        else {
            /* Half the time we sample according to the Lipschitz constants of the covered sentences */
            /* Walk the sum tree from the root down: at each level s is the
             * index of the parent, 2*s its left child; go left when u falls
             * inside the cumulative mass up to and including the left child. */
            offset = 0;
            s = 0;
            Z = LiMatrix[nextpow2*(nLevels-1)];
            for(level=nLevels-1;level>=0;level--) {
                z = offset + LiMatrix[2*s + nextpow2*level];
                if(u < z/Z)
                    s = 2*s;
                else {
                    offset = z;
                    s = 2*s+1;
                }
            }
        }
        /* *****************************************
         * End non-uniform selection of training example
         * ***************************************** */
        
        /* Compute node potentials */
        nNodes = sentences[s + nSentences]-sentences[s]+1;
        /* Update needed values of w and compute nodePot */
        for(all_n) {
            for(all_state)
                nodePot(n,state) = 0;
            for(all_features) {
                feature = X[sentences[s]-1 + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    for(all_state) {
                        ind = featureInd + nFeaturesTotal*state;
                        if (iter > 0) {
                            /* Lazy update: apply all steps skipped since this weight was last touched */
                            if (lastVisited[ind]==0)
                                w[ind] -= d[ind]*cumSum[iter-1];
                            else if (lastVisited[ind] != iter)
                                w[ind] -= d[ind]*(cumSum[iter-1]-cumSum[lastVisited[ind]-1]);
                            lastVisited[ind] = iter;
                        }
                        nodePot(n,state) += w[ind];
                    }
                }
            }
        }
        /* Add beginning/end of sentence modification */
        for(all_state) {
            nodePot(0,state) += v_start[state];
            nodePot(nNodes-1,state) += v_end[state];
        }
        /* Compute part of objective function based on true labels
           (needed for line-search variant) */
        fi = 0;
        for(all_n) {
            state = y[sentences[s]-1 + n]-1;
            fi -= rho*nodePot(n,state);
        }
        for(n = 0; n < nNodes-1; n++) {
            state1 = y[sentences[s]-1 + n]-1;
            state2 = y[sentences[s]-1 + n+1]-1;
            fi -= rho*v(state1,state2);
        }
        /* Exponentiate Potentials (stored parameters are scaled by rho) */
        for(all_n) {
            for(all_state) {
                nodePot(n,state) = exp(rho*nodePot(n,state));
            }
        }
        /* Edge potentials */
        for(all_state2) {
            for(all_state1) {
                edgePot(state1,state2) = exp(rho*v(state1,state2));
            }
        }
        
        
        /* Forward-Backward to Compute Marginals */
        if (nlhs > 0)
            (*nFwdBwd)++;
        /* kappa[n] is the normalizer of the forward message at node n */
        for(all_n)
            kappa[n] = 0;
        for(all_state) {
            alpha(0,state) = nodePot(0,state);
            kappa[0] += alpha(0,state);
        }
        for(all_state)
            alpha(0,state) /= kappa[0];
        /* Forward Pass */
        for(n = 1; n < nNodes; n++) {
            for(all_state1) {
                for(all_state2)
                    tmp(state1,state2) = alpha(n-1,state1) * edgePot(state1,state2);
            }
            for(all_state2) {
                alpha(n,state2) = 0;
                for(all_state1)
                    alpha(n,state2) += tmp(state1,state2);
                alpha(n,state2) = nodePot(n,state2) * alpha(n,state2);
                kappa[n] += alpha(n,state2);
            }
            for(all_state)
                alpha(n,state) /= kappa[n];
        }
        /* Backward Pass */
        for(all_state)
            beta(nNodes-1,state) = 1;
        for(n = nNodes-2; n >= 0; n--) {
            for(all_state1) {
                for(all_state2)
                    tmp(state1,state2) = nodePot(n+1,state2)*edgePot(state1,state2)*beta(n+1,state2);
            }
            Z = 0;
            for(all_state1) {
                beta(n,state1) = 0;
                for(all_state2)
                    beta(n,state1) += tmp(state1,state2);
                Z += beta(n,state1);
            }
            for(all_state)
                beta(n,state) /= Z;
        }
        
        /* Compute beliefs for use in gradient update */
        /* Node Beliefs */
        for(all_n) {
            Z = 0;
            for(all_state) {
                nodeBel(n,state) = alpha(n,state)*beta(n,state);
                Z += nodeBel(n,state);
            }
            for(all_state)
                nodeBel(n,state) /= Z;
        }
        /* Edge Beliefs */
        for(n = 0; n < nNodes-1; n++) {
            Z = 0;
            for(all_state1) {
                for(all_state2) {
                    tmp(state1,state2) = alpha(n,state1)*nodePot(n+1,state2)*beta(n+1,state2)*edgePot(state1,state2);
                    Z += tmp(state1,state2);
                }
            }
            for(all_state1) {
                for(all_state2) {
                    edgeBel(state1,state2,n) = tmp(state1,state2)/Z;
                }
            }
        }
        
        /* Update directions and number of examples we've seen*/
        if(!covered[s]) {
            /* Component of direction from observed data (added once, the first time a sentence is seen) */
            for(all_n) {
                for(all_features) {
                    feature = X[sentences[s]-1 + n + nWords*f];
                    if(feature != 0) {
                        featureInd = featureStart[f]-1 + feature-1;
                        state = y[sentences[s]-1 +n]-1;
                        dW[featureInd + nFeaturesTotal*state] -= 1;
                    }
                }
            }
            dV_start[y[sentences[s]-1]-1] -= 1;
            dV_end[y[sentences[s]-1 + nNodes-1]-1] -= 1;
            for(n = 0; n < nNodes-1; n++) {
                state1 = y[sentences[s]-1 +n]-1;
                state2 = y[sentences[s]-1 +n+1]-1;
                dV(state1,state2) -= 1;
            }
        }
        /* Component of direction from expectations and store beliefs:
         * replace the previously stored beliefs for this sentence by the new ones */
        for(all_n) {
            for(all_features) {
                feature = X[sentences[s]-1 + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    for(all_state) {
                        dW[featureInd + nFeaturesTotal*state] -= NB(sentences[s]-1+n,state) - nodeBel(n,state);
                    }
                }
            }
        }
        for(all_state) {
            dV_start[state] -= NB(sentences[s]-1,state) - nodeBel(0,state);
            dV_end[state] -= NB(sentences[s]-1+nNodes-1,state) - nodeBel(nNodes-1,state);
        }
        for(all_n) {
            for(all_state) {
                NB(sentences[s]-1+n,state) = nodeBel(n,state);
            }
        }
        /* Edge part: subtract old stored sum, rebuild it from the new edge beliefs, add it back */
        for(all_state1) {
            for(all_state2) {
                dV(state1,state2) -= EB(state1,state2,s);
            }
        }
        for(all_state1) {
            for(all_state2) {
                EB(state1,state2,s) = 0;
            }
        }
        for(n = 0; n < nNodes-1; n++)
        {
            for(all_state1) {
                for(all_state2) {
                    EB(state1,state2,s) += edgeBel(state1,state2,n);
                }
            }
        }
        for(all_state1) {
            for(all_state2) {
                dV(state1,state2) += EB(state1,state2,s);
            }
        }
        
        /* ****************************
         * Here is the line-search part
         * **************************** */
        
        /* Start from an optimistic estimate: shrink a known Li, or seed a new one from the mean */
        Li_old = Li[s];
        if (covered[s])
            Li[s] *= 0.9;
        else if (iter != 0) /* This could probably be changed to nCovered > 0 */
        {
            Li[s] = Lmean/2;
        }
        
        /* Compute gradient and squared norm of gradient */
        /* A weight can be touched by several nodes of the sentence, so for each
         * touch the old square is removed from gg before the entry is updated
         * and the new square is added afterwards. */
        gg = 0;
        for(all_n) {
            for(all_features) {
                feature = X[sentences[s]-1 + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    for(all_state) {
                        gg -= gW[featureInd + nFeaturesTotal*state]*gW[featureInd + nFeaturesTotal*state];
                    }
                    for(all_state) {
                        gW[featureInd + nFeaturesTotal*state] += nodeBel(n,state);
                    }
                    gW[featureInd + nFeaturesTotal*(y[sentences[s]-1 +n]-1)] -= 1;
                    for(all_state) {
                        gg += gW[featureInd + nFeaturesTotal*state]*gW[featureInd + nFeaturesTotal*state];
                    }
                }
            }
        }
        for(all_state) {
            gV_start[state] = nodeBel(0,state) - (state == y[sentences[s]-1]-1);
            gV_end[state] = nodeBel(nNodes-1,state) - (state == y[sentences[s]-1 + nNodes-1]-1);
            gg += gV_start[state]*gV_start[state];
            gg += gV_end[state]*gV_end[state];
        }
        for(n = 0; n < nNodes-1; n++)
        {
            state1 = y[sentences[s]-1 +n]-1;
            state2 = y[sentences[s]-1 +n+1]-1;
            gV[state1 + nStates*state2] -= 1;
        }
        for(all_state1) {
            for(all_state2) {
                gV[state1 + nStates*state2] += EB(state1,state2,s);
            }
        }
        for(all_state1) {
            for(all_state2) {
                gg += gV[state1 + nStates*state2]*gV[state1 + nStates*state2];
            }
        }
        
        /* Do the line-search if gradient of example is non-trivial */
        if (gg > eps) {
            /* Compute function value for example */
            logZ = 0;
            for(all_n)
                logZ += log(kappa[n]);
            fi += logZ;
            
            while (1) /* Loop breaks when satisfactory Li[s] value is found */
            {
                /* Compute function value if we took a step of 1/L in gradient direction
                 * (mostly copied and pasted from the above code) */
                for(all_n) {
                    for(all_state)
                        nodePot(n,state) = 0;
                    for(f = 0; f < nFeatureTypes; f++) {
                        feature = X[sentences[s]-1 + n + nWords*f];
                        if(feature != 0) {
                            featureInd = featureStart[f]-1 + feature-1;
                            for(all_state) {
                                nodePot(n,state) += rho*w[featureInd + nFeaturesTotal*state] - gW[featureInd + nFeaturesTotal*state]/Li[s];
                            }
                        }
                    }
                }
                for(all_state) {
                    nodePot(0,state) += rho*v_start[state] - gV_start[state]/Li[s];
                    nodePot(nNodes-1,state) += rho*v_end[state] - gV_end[state]/Li[s];
                }
                fi_new = 0;
                for(all_n) {
                    state = y[sentences[s]-1 + n]-1;
                    fi_new -= nodePot(n,state);
                }
                for(n = 0; n < nNodes-1; n++) {
                    state1 = y[sentences[s]-1 + n]-1;
                    state2 = y[sentences[s]-1 + n+1]-1;
                    fi_new -= rho*v(state1,state2) - gV[state1 + nStates*state2]/Li[s];
                }
                for(all_n) {
                    for(all_state) {
                        nodePot(n,state) = exp(nodePot(n,state));
                    }
                }
                for(all_state2) {
                    for(all_state1) {
                        edgePot(state1,state2) = exp(rho*v(state1,state2) - gV[state1 + nStates*state2]/Li[s]);
                    }
                }
                /* Only the forward pass is needed here: its normalizers give logZ */
                if (nlhs > 0)
                    (*nFwdBwd)++;
                for(all_n)
                    kappa[n] = 0;
                for(all_state) {
                    alpha(0,state) = nodePot(0,state);
                    kappa[0] += alpha(0,state);
                }
                for(all_state)
                    alpha(0,state) /= kappa[0];
                for(n = 1; n < nNodes; n++) {
                    for(all_state1) {
                        for(all_state2)
                            tmp(state1,state2) = alpha(n-1,state1) * edgePot(state1,state2);
                    }
                    for(all_state2) {
                        alpha(n,state2) = 0;
                        for(all_state1)
                            alpha(n,state2) += tmp(state1,state2);
                        alpha(n,state2) = nodePot(n,state2) * alpha(n,state2);
                        kappa[n] += alpha(n,state2);
                    }
                    for(all_state)
                        alpha(n,state) /= kappa[n];
                }
                
                logZ = 0;
                for(all_n)
                    logZ += log(kappa[n]);
                fi_new += logZ;
                
                /* Accept when the step gives sufficient decrease; otherwise double the estimate */
                if (fi_new <= fi - gg/(2*Li[s]))
                    break;
                else
                    Li[s] *= 2;
                
            }
        }
        
        /* Reset gradients (only the entries this sentence touched, so gW stays all-zero between iterations) */
        for(all_n) {
            for(all_features) {
                feature = X[sentences[s]-1 + n + nWords*f];
                if(feature != 0) {
                    featureInd = featureStart[f]-1 + feature-1;
                    for(all_state) {
                        gW[featureInd + nFeaturesTotal*state] = 0;
                    }
                }
            }
        }
        for(all_state1) {
            for(all_state2) {
                gV[state1 + nStates*state2] = 0;
            }
        }
        
        /* ****************************
         * End of the line-search part
         * **************************** */
        
        /* ************************************************
         * Updating data structures and step sizes
         * ************************************************ */
        
        if (Li[s] > *Lmax)
            *Lmax = Li[s];
        
        if (covered[s]==0) {
            covered[s]=1;
            nCovered++;
            Lmean = Lmean*((double)(nCovered-1)/(double)nCovered) + Li[s]/(double)nCovered;
            
            /* Update LiMatrix so we sample this guy proportional to its Lipschitz constant*/
            ind = s;
            for(level=0;level<nLevels;level++)
            {
                LiMatrix[ind + nextpow2*level] += Li[s] + lambda;
                ind = ind/2;
            }
        }
        else if (Li[s] != Li_old) {
            Lmean = Lmean + (Li[s] - Li_old)/(double)nCovered;
            /* Update LiMatrix with the new estimate of the Lipschitz constant */
            ind = s;
            for(level=0;level<nLevels;level++)
            {
                LiMatrix[ind + nextpow2*level] += (Li[s] - Li_old);
                ind = ind/2;
            }
        }
        stepSize = (1/(*Lmax + lambda) + 1/(Lmean + lambda))/2;
        
        /* Update Parameters */
        /* For lazy updates of node parameters */
        rho *= 1-stepSize*lambda;
        if (iter == 0)
            cumSum[0] = stepSize/(rho*nCovered);
        else
            cumSum[iter] = cumSum[iter-1] + stepSize/(rho*nCovered);
        
        /* Update edge parameters explicitly */
        step = stepSize/(rho*nCovered);
        for(all_state) {
            v_start[state] -= step*dV_start[state];
            v_end[state] -= step*dV_end[state];
        }
        for(all_state1) {
            for(all_state2) {
                v(state1,state2) -= step*dV(state1,state2);
            }
        }
        
        /* The below line is new for the line-search variant: */
        *Lmax *= pow(2.0,-1.0/nSentences);
        
    }
    
    /* Get final values of w: flush every pending lazy update.
     * Skipped when no iteration ran, since cumSum is then empty. */
    if (maxIter > 0) {
        for(f = 0; f < nFeaturesTotal*nStates; f++) {
            if (lastVisited[f]==0)
                w[f] -= d[f]*cumSum[maxIter-1];
            else
                w[f] -= d[f]*(cumSum[maxIter-1]-cumSum[lastVisited[f]-1]);
        }
    }
    
    /* Scale variables to get final result */
    for(f = 0; f < nParams; f++)
        wv[f] = wv[f]*rho;   
    
    mxFree(nodePot);
    mxFree(edgePot);
    mxFree(alpha);
    mxFree(beta);
    mxFree(tmp);
    mxFree(nodeBel);
    mxFree(edgeBel);
    mxFree(kappa);
    mxFree(lastVisited);
    mxFree(cumSum);
    mxFree(gW);
    mxFree(gV_start);
    mxFree(gV_end);
    mxFree(gV);
    mxFree(LiMatrix);
}
