%% Setting things up and some basic operations

% Start from an empty workspace. 'clear variables' is used rather than
% 'clear all': the latter also unloads compiled functions and MEX files,
% forcing them to be reloaded (which slows things down and is unnecessary).
clear variables

% Turn on using .mex files to speed things up
useMex = 1;

% We will use a simple demo with 5 nodes, where each node can take 3 states
nNodes = 5;
nStates = 3;

% Make a random adjacency matrix: each entry is 1 with probability 0.5.
% setdiag(adj,0) presumably zeroes the diagonal so there are no self-loops
% (project helper -- confirm).
adj = rand(nNodes) > .5;
adj = double(setdiag(adj,0));

% Make the edgeEnds structure from the adjacency matrix
% (each row contains the two nodes on the end-points of an edge)
edgeEnds = DCG_makeEdgeEnds(adj)
nEdges = size(edgeEnds,1);

% Make some random parameters (in the exponential family parameterization)
[bias,weights] = DCG_initWeights(nNodes,edgeEnds,nStates,1)

% Evaluate the potential of a random configuration (states are 1..nStates)
y = ceil(rand(nNodes,1)*nStates);
logPot = DCG_logPotential(nNodes,edgeEnds,nStates,y,bias,weights)
pot = exp(logPot)

% Compute the normalizing constant
Z = DCG_computeZ(nNodes,edgeEnds,nStates,bias,weights)

% Evaluate the probability of a (new) random configuration: potential / Z
y = ceil(rand(nNodes,1)*nStates);
logPot = DCG_logPotential(nNodes,edgeEnds,nStates,y,bias,weights);
pot = exp(logPot)
prob = pot/Z

% Compute marginals: node beliefs (nodeBel) and, presumably, edge beliefs
% (edgeBel), given the normalizing constant Z
[nodeBel,edgeBel] = DCG_infer(nNodes,edgeEnds,nStates,bias,weights,Z)

%% Generating observational samples and maximum-likelihood estimation from observational data

% Draw nSamples configurations from the model (no intervention applied)
nSamples = 50;
samples = DCG_sample(nNodes,edgeEnds,nStates,bias,weights,nSamples,Z);

% Negative log-likelihood of those samples under the generating
% parameters; bias and weights are stacked into one column vector
nll = DCG_nll(samples,nNodes,edgeEnds,nStates,[bias(:);weights(:)],useMex)

% Maximum-likelihood fit: start from all-zero parameters and minimize the
% negative log-likelihood of the samples with minFunc
[bias,weights] = DCG_initWeights(nNodes,edgeEnds,nStates,0);
funObj = @(w)DCG_nll(samples,nNodes,edgeEnds,nStates,w,useMex);
[w,fval] = minFunc(funObj,[bias(:);weights(:)],[]);

% Unpack the optimized vector: the first nNodes*nStates entries form the
% nNodes-by-nStates bias matrix, the rest form one nStates-by-nStates
% weight table per edge
bias = reshape(w(1:nNodes*nStates),[nNodes nStates]);
weights = reshape(w(nNodes*nStates+1:end),[nStates nStates nEdges]);

%% Basic operations with interventions

% Randomly initialize weights
[bias,weights] = DCG_initWeights(nNodes,edgeEnds,nStates,1);

% Set up an intervention: clamped(n) = k > 0 forces variable n to take
% state k; entries left at 0 are not intervened on
clamped = zeros(nNodes,1);
clamped(1) = 2; % Force variable 1 to take state 2
clamped(3) = 1; % and force variable 3 to take state 1

% Compute the log-potential under the interventional distribution.
% The random configuration y from above generally disagrees with the
% clamps, so evaluate a copy of it whose clamped variables are set to
% their forced states (y itself is left unchanged).
clamped_y = y;
clamped_y(clamped > 0) = clamped(clamped > 0);
clamped_logPot = DCG_intervLogPotential(clamped,nNodes,edgeEnds,nStates,clamped_y,bias,weights)

% Compute the normalizing constant under the interventional distribution
clamped_Z = DCG_intervComputeZ(clamped,nNodes,edgeEnds,nStates,bias,weights)

% Compute marginals under the interventional distribution
[clamped_nodeBel,clamped_edgeBel] = DCG_intervInfer(clamped,nNodes,edgeEnds,nStates,bias,weights,clamped_Z)

% Generate nSamples interventional samples (all under the same intervention)
clamped_samples = DCG_intervSample(clamped,nNodes,edgeEnds,nStates,bias,weights,nSamples,clamped_Z)

% Maximum likelihood estimation with interventional data
[bias,weights] = DCG_initWeights(nNodes,edgeEnds,nStates,0); % Initialize to zero
funObj = @(w)DCG_intervNLL(clamped,clamped_samples,nNodes,edgeEnds,nStates,w,useMex);
[w,fval]=minFunc(funObj,[bias(:);weights(:)]);

% Split the optimized vector into bias (nNodes-by-nStates) and per-edge
% weights (nStates-by-nStates-by-nEdges)
bias = reshape(w(1:nNodes*nStates),nNodes,nStates);
weights = reshape(w(nNodes*nStates+1:end),nStates,nStates,nEdges);

%% Training when different interventions are done on different samples

% Build a data set in which every sample has its own random intervention:
% each variable is independently clamped with probability 0.25, to a state
% drawn uniformly from 1..nStates. Samples are drawn using the bias and
% weights currently in the workspace (those left by the previous section).
clampedMatrix = zeros(nSamples,nNodes);
clampedSamples = zeros(nSamples,nNodes);
for iSample = 1:nSamples
    % Intervention for this sample (0 = variable not clamped)
    clampRow = zeros(1,nNodes);
    for iNode = 1:nNodes
        if rand < .25
            clampRow(iNode) = ceil(rand*nStates);
        end
    end
    clampedMatrix(iSample,:) = clampRow;
    clampedSamples(iSample,:) = DCG_intervSample(clampRow,nNodes,edgeEnds,nStates,bias,weights,1);
end

% Maximum-likelihood fit from all-zero parameters, where row s of
% clampedMatrix is the intervention that produced row s of clampedSamples
[bias,weights] = DCG_initWeights(nNodes,edgeEnds,nStates,0);
funObj = @(w)DCG_intervNLL2(clampedMatrix,clampedSamples,nNodes,edgeEnds,nStates,w,useMex);
[w,fval] = minFunc(funObj,[bias(:);weights(:)]);
bias = reshape(w(1:nNodes*nStates),[nNodes nStates]);
weights = reshape(w(nNodes*nStates+1:end),[nStates nStates nEdges]);