%% Load the data set and build the chain-CRF data structure
clear all
load('ocr.mat'); % presumably supplies X, y and fold, which are used below
addBias = 1;
[data,w] = crfChain_prepareData(X,y,addBias);

% Divide the sentences into a training part and a testing part.
% A sentence goes to the test set when the fold value indexed by its first
% column entry is 0; every other sentence is used for training.
testNdx = fold(data.sentences(:,1))==0;
trainNdx = fold(data.sentences(:,1))~=0;
dataTest = data; % Same structure as data, only the sentence list differs
dataTest.sentences = data.sentences(testNdx,:);
data.sentences = data.sentences(trainNdx,:);

%% Optimization settings
optTol = 1e-3; % A solve stops once its optimality measure falls below this
n = size(data.sentences,1); % Number of training examples (rows of data.sentences)
p = length(w); % Dimension of parameter vector
funObjTest = @(w)crfChain_testErrC(w,dataTest); % Testing objective, evaluated on dataTest

%% Starting state for the uniform sampling method
% All of these arrays are handed to the solver, which updates them in place,
% so each one is created here as a fresh array.
w = zeros(p,1); % Parameter vector, started at zero
d = zeros(p,1); % Initial sum of gradient approximations
L = 1; % Initial guess of Lipschitz constant
covered = zeros(n,1,'int32'); % Examples we have already visited
NB = zeros(data.nWords,data.nStates); % Old values of node marginals
EB = zeros(data.nStates,data.nStates,n); % Old values of sum of edges marginals

%% Solve for sequence of regularization parameters
% Sweeps lambda over 2^5/n, 2^4/n, ..., 2^-5/n. The solver state (w, d, covered,
% NB, EB, L) is not reset between values, so each solve continues from where the
% previous one ended. The sweep stops as soon as the test error fails to
% improve, and w is rolled back to the weights of the best lambda seen.
evals = int32(0); % Running total of forward-backward calls reported by the solver
testOld = inf; % Lowest test error seen so far
% Fallback for the rollback below in case the very first test error is not an
% improvement (e.g. it is NaN); without it wOld would be undefined there.
% The "+ 0" forces a separate buffer, see the note where wOld is updated.
wOld = w + 0;
for lambda = 2.^(5:-1:-5)/n
    fprintf('Training with lambda = %.2f/n using SAG...\n',lambda*n);
    funObj = @(w)penalizedL2(w,@crfChain_lossC,n*lambda/2,data); % Training objective (not evaluated in this loop)
    optCond = norm(d/n + lambda*w);
    while 1
        % Draw n example indices uniformly at random from 1..n for this batch of updates
        iVals = int32(ceil(rand(n,1)*n));
        [subEvals,L] = crfChain_SAGC_LS(w,data,lambda,iVals,d,covered,NB,EB,L); % Modifies everything in place, returns number of forward-backward calls
        evals = evals + subEvals;
        nCovered = sum(covered);
        if nCovered == n
            % The optimality measure is only checked once every example has
            % been visited; until then only coverage progress is reported.
            optCond = norm(d/n + lambda*w);
            fprintf('passes = %.2f, optCond = %.3f\n',double(evals)/n,optCond);
            if optCond < optTol
                break;
            end
        else
            fprintf('passes = %.2f, nCovered = %d\n',double(evals)/n,nCovered);
        end
    end
    
    testErrs = funObjTest(w);
    fprintf('Test error = %.3f\n\n',testErrs(1));
    if testErrs(1) < testOld
        testOld = testErrs(1);
        % Snapshot the best weights into their own buffer. A plain "wOld = w"
        % would leave both names sharing one buffer under MATLAB's lazy copy,
        % and the solver's in-place writes to w during the next lambda would
        % then overwrite the snapshot too, turning the rollback into a no-op.
        wOld = w + 0;
    else
        fprintf('Test error increased, stopping\n');
        w = wOld; % Roll back to the weights of the best lambda
        break
    end
end


%% Starting state for the non-uniform sampling method with skipping
% All of these arrays are handed to the solver, which updates them in place,
% so each one is created here as a fresh array.
w = zeros(p,1); % Parameter vector, started at zero
d = zeros(p,1); % Initial sum of gradient approximations
Lmax = 1; % Initial guess of maximum Lipschitz constant
Li = ones(n,1); % Initial guess of each individual Lipschitz constant
covered = zeros(n,1,'int32'); % Examples we have already visited
passes = zeros(n,1,'int32'); % Number of consecutive times line-search was satisfied
skip = zeros(n,1,'int32'); % Number of times to skip line-search for this example
NB = zeros(data.nWords,data.nStates); % Old values of node marginals
EB = zeros(data.nStates,data.nStates,n); % Old values of sum of edges marginals

%% Solve for sequence of regularization parameters
% Sweeps lambda over 2^5/n, 2^4/n, ..., 2^-5/n. The solver state (w, d, covered,
% NB, EB, Lmax, Li) is not reset between values, so each solve continues from
% where the previous one ended; only the line-search counters passes and skip
% are reset per lambda. The sweep stops as soon as the test error fails to
% improve, and w is rolled back to the weights of the best lambda seen.
evals = int32(0); % Running total of forward-backward calls reported by the solver
testOld = inf; % Lowest test error seen so far
% Fallback for the rollback below in case the very first test error is not an
% improvement (e.g. it is NaN); without it wOld could be undefined or stale there.
% The "+ 0" forces a separate buffer, see the note where wOld is updated.
wOld = w + 0;
for lambda = 2.^(5:-1:-5)/n
    fprintf('Training with lambda = %.2f/n with SAG-NUS*...\n',lambda*n);
    funObj = @(w)penalizedL2(w,@crfChain_lossC,n*lambda/2,data); % Training objective (not evaluated in this loop)
    optCond = norm(d/n + lambda*w);
    passes = int32(zeros(n,1)); % Number of consecutive times line-search was satisfied
    skip = int32(zeros(n,1)); % Number of times to skip line-search for this example
    while 1
        % Two uniform random numbers per update are handed to the solver
        randVals = rand(n,2);
        [subEvals,Lmax] = crfChain_SAGC_LipschitzSimpleSkip(w,data,lambda,randVals,d,covered,NB,EB,Lmax,Li,passes,skip); % Modifies everything in place, returns number of forward-backward calls
        evals = evals + subEvals;
        nCovered = sum(covered);
        if nCovered == n
            % The optimality measure is only checked once every example has
            % been visited; until then only coverage progress is reported.
            optCond = norm(d/n + lambda*w);
            fprintf('passes = %.2f, optCond = %.3f\n',double(evals)/n,optCond);
            if optCond < optTol
                break;
            end
        else
            fprintf('passes = %.2f, nCovered = %d\n',double(evals)/n,nCovered);
        end
    end
    
    testErrs = funObjTest(w);
    fprintf('Test error = %.3f\n\n',testErrs(1));
    if testErrs(1) < testOld
        testOld = testErrs(1);
        % Snapshot the best weights into their own buffer. A plain "wOld = w"
        % would leave both names sharing one buffer under MATLAB's lazy copy,
        % and the solver's in-place writes to w during the next lambda would
        % then overwrite the snapshot too, turning the rollback into a no-op.
        wOld = w + 0;
    else
        fprintf('Test error increased, stopping\n');
        w = wOld; % Roll back to the weights of the best lambda
        break
    end
end
