#include <math.h>
#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    /* Variables */
    int n, s,e,n1,n2,nNodes, *clamped,*edgeEnds, nStates, nEdges, *y, done, sizeEdgeBel[3];
    double *bias,*weights,Z=0,*nodeBel,*edgeBel,logPot,pot,*logZ;
    
    /* Input */
    clamped = mxGetPr(prhs[0]);
    nNodes = mxGetScalar(prhs[1]);
    edgeEnds = mxGetPr(prhs[2]);
    nStates = mxGetScalar(prhs[3]);
    bias = mxGetPr(prhs[4]);
    weights = mxGetPr(prhs[5]);
    
    /* Compute Sizes */
    nEdges = mxGetDimensions(prhs[2])[0];
    
    /* Output */
    plhs[0] = mxCreateDoubleMatrix(nNodes,nStates,mxREAL);
    nodeBel = mxGetPr(plhs[0]);
    
    if(nlhs >= 2)
    {
        sizeEdgeBel[0] = nStates;
        sizeEdgeBel[1] = nStates;
        sizeEdgeBel[2] = nEdges;
        plhs[1] = mxCreateNumericArray(3,sizeEdgeBel,mxDOUBLE_CLASS,mxREAL);
        edgeBel = mxGetPr(plhs[1]);
    }
        
    if(nlhs == 3)
    {
        plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);
        logZ = mxGetPr(plhs[2]);
        *logZ = log(Z);
    }
    
    y = mxCalloc(nNodes,sizeof(int));
    
    for(n=0;n < nNodes;n++)
    {
        if (clamped[n]!=0)
            y[n] = clamped[n]-1;
    }
    
    while(1)
    {
        
      /* Compute logPot */
        logPot = 0;
        for(n=0;n < nNodes;n++)
        {
            if(clamped[n]==0)
                logPot += bias[n + nNodes*(y[n])];
        }
        for(e=0;e < nEdges;e++)
        {
            n1 = edgeEnds[e]-1;
            n2 = edgeEnds[e+nEdges]-1;
            if(clamped[n2]==0)
                logPot += weights[y[n1] + nStates*(y[n2] + nStates*e)];
        }
        pot = exp(logPot);
        
        /* Update Z */
        Z += pot;
        
        /* Update nodeBel */
        for(n=0;n < nNodes;n++)
        {
            nodeBel[n + nNodes*(y[n])] += pot;
        }
        
        if(nlhs >= 2)
        {
	  /*  Update edgeBel */
            for(e=0; e < nEdges;e++)
            {
                n1 = edgeEnds[e]-1;
                n2 = edgeEnds[e+nEdges]-1;
                if(clamped[n2]==0)
                    edgeBel[y[n1] + nStates*(y[n2] + nStates*e)] += pot;
            }
        }
        
        /* Go to next state */
        done = 1;
        for(n=0;n < nNodes;n++)
        {
            if(clamped[n]==0) {
                
                if(y[n] < nStates-1)
                {
                    y[n]++;
                    done = 0;
                    break;
                }
                else
                {
                    y[n] = 0;
                }
            }
        }
        
        if(done) {
            break;
        }
    }
    
    for(n=0;n<nNodes*nStates;n++)
        nodeBel[n] /= Z;
    
    if(nlhs >= 2)
    {
        for(e=0;e<nStates*nStates*nEdges;e++)
            edgeBel[e] /= Z;
    }
    
    if(nlhs == 3)
    {
        *logZ = log(Z);
    }
    
    mxFree(y);
}
