% ActiveLearn(X, Y, sigma)
% Written by: Eric Brochu.
%
%   An implementation of active learning for semi-supervised data, as
%   described in these two papers:
%
%       Zhu, Ghahramani and Lafferty.  2003.  Semi-Supervised Learning 
%       Using Gaussian Fields and Harmonic Functions.
%
%       Zhu, Lafferty and Ghahramani. 2003. Combining Active Learning
%       and Semi-Supervised Learning Using Gaussian Fields and Harmonic
%       Functions.
%
%   Parameters:
%   
%       X       The two-dimensional data set, with each datum in a row.
%
%       Y       The initial labels of X, as a column vector.  Y(i) is the
%               label of row i of X, so the labeled data must be the first
%               length(Y) rows of X.  The only allowable labels are {0, 1},
%               and each of 0 and 1 must occur at least once in Y.
%
%               For example, X=[1 0; 2 1; 3 0], Y=[0; 1] is a set of two
%               labeled data at (1, 0) and (2, 1) and an unlabeled datum
%               at (3, 0).
%
%       sigma   Kernel width parameters, one per column of X (a 2-element
%               vector for 2-D data).  The algorithm is very sensitive to
%               these.  Zhu, Ghahramani and Lafferty give a way of learning
%               them.
%
function ActiveLearn(X, Y, sigma)
% Interactive active learning on a harmonic-function (Gaussian field)
% semi-supervised classifier.  Each round, the unlabeled datum with minimum
% expected risk is shown (green circled cross) and the user is asked for its
% label.  Entering anything other than a single value in [0, 1] ends the
% session.  X, Y and the Laplacian blocks are dumped to the file 'testdat'.

% Accept Y as a row or a column; the linear algebra below needs a column.
Y = Y(:);
% Force sigma to a row so it lines up element-wise with rows of X (a column
% sigma would otherwise broadcast into a matrix under implicit expansion).
sigma = sigma(:)';

fid = fopen('testdat', 'w');
if fid < 0
    error('ActiveLearn:fopen', 'Could not open ''testdat'' for writing.');
end
write_data(fid, X);
write_data(fid, Y);

n_d = size(X, 1);   % number of data
n_l = length(Y);    % number of labeled data (the first n_l rows of X)

% Compute the weight matrix.  Any similarity method will do, but this is the
% RBF one used in [Zhu, Lafferty and Ghahramani 2003].
sigma2 = sigma.^2;
W = zeros(n_d);
for i = 1:n_d
    for j = 1:i
        W(i,j) = exp(-sum(((X(i,:) - X(j,:)).^2) ./ sigma2));
        W(j,i) = W(i,j);
    end
end

% Diagonal degree matrix: D(i,i) = sum_j W(i,j).  Built directly with diag;
% seeding it with sparse(size(W)) would create the 1x2 matrix [n_d n_d] and
% leave a stray off-diagonal entry D(1,2) = n_d.
D = diag(sum(W, 2));

% The graph Laplacian is D - W; only its unlabeled rows are needed.
delta_full = D - W;
delta_ul = delta_full(n_l+1:n_d, 1:n_l);
delta_uu = delta_full(n_l+1:n_d, n_l+1:n_d);
delta_uu_inv = inv(delta_uu);

write_data(fid, delta_ul);
write_data(fid, delta_uu);
write_data(fid, delta_uu_inv);
fclose(fid);

% Active learning loop: stop once at most one unlabeled datum remains.
while n_l < n_d - 1

    % Minimum-energy (harmonic) function: labeled data keep their labels,
    % unlabeled data get f_u = -inv(Delta_uu) * Delta_ul * f_l.
    f = zeros(1, n_d);
    f(1:n_l) = Y(1:n_l);
    f(n_l+1:n_d) = -delta_uu_inv * (delta_ul * Y(1:n_l));

    % Expected risk of querying each unlabeled datum.
    n_u = n_d - n_l;
    fu = f(n_l+1:n_d);
    R = zeros(1, n_u);
    for i = 1:n_u
        k0 = fu_plus(fu, 0, delta_uu_inv, i);
        k1 = fu_plus(fu, 1, delta_uu_inv, i);
        R(i) = estimated_risk(fu(i), k0, k1);
    end

    % Query the risk minimizer.  ku indexes the unlabeled data, k all data.
    [min_risk, ku] = min(R); %#ok<ASGLU>
    k = ku + n_l;

    plot_labeling(X, f, Y, n_l, k);
    response = input('label for + [0=in, 1=out]: ');

    % Anything other than a single value in [0, 1] ends the session.  Check
    % before touching any state so an empty reply cannot corrupt Y.
    if isempty(response) || ~isscalar(response) || response < 0 || response > 1
        break;
    end

    % Update the inverse using the method of ZLG Appendix B; this needs the
    % delta_uu from before the row/column is removed.
    delta_uu_inv = updateInverse(delta_uu, delta_uu_inv, ku);

    % Move column ku of delta_uu into delta_ul as the newest labeled column,
    % then drop row ku from delta_ul and row+column ku from delta_uu.
    keep = [1:ku-1, ku+1:n_u];
    delta_ul(:, n_l+1) = delta_uu(:, ku);
    delta_ul = delta_ul(keep, :);
    delta_uu = delta_uu(keep, keep);

    % Record the label the user gave us.
    n_l = n_l + 1;
    Y(n_l) = response;

    % Move datum k to row n_l, shifting the intervening unlabeled rows down
    % by one so X stays ordered as [labeled; unlabeled].
    X(n_l:k, :) = X([k, n_l:k-1], :);
end


%=======================================================
% Draw the current labeling in figure 1: each datum is a dot colored by f
% (blue if f > 0.5, red otherwise), labeled data are circled (red for
% label 0, blue otherwise), and datum k gets a green circled cross.
%=======================================================
function plot_labeling(X, f, Y, n_l, k)

figure(1);
clf;
hold on;
for i = 1:size(X, 1)
    if f(i) > 0.5
        plot(X(i,1), X(i,2), '.b');
    else
        plot(X(i,1), X(i,2), '.r');
    end
end
for i = 1:n_l
    if Y(i) == 0
        plot(X(i,1), X(i,2), 'or');
    else
        plot(X(i,1), X(i,2), 'ob');
    end
end
plot(X(k,1), X(k,2), '+g');
plot(X(k,1), X(k,2), 'og');
hold off;



%=======================================================
% Harmonic solution on the unlabeled data after hypothetically labeling
% unlabeled datum k with y (ZLG 2003):
%
%     f_u^+ = f_u + (y - f_k) * G(:,k) / G(k,k),   G = inv(Delta_uu)
%
% Datum k itself is removed from the result, since it is now labeled.
%
%   fu            current harmonic values on the unlabeled data (row or
%                 column vector)
%   y             hypothesised label for datum k
%   delta_uu_inv  inverse of the unlabeled/unlabeled Laplacian block
%   k             index of the queried datum within the unlabeled data
%   newf          column vector of length numel(fu)-1
%=======================================================
function newf = fu_plus(fu, y, delta_uu_inv, k)

% Normalize to a column so adding the column of delta_uu_inv never
% broadcasts into a matrix, whatever orientation the caller passed.
fu = fu(:);
g_k = delta_uu_inv(:, k);
g_kk = delta_uu_inv(k, k);
fk = fu(k);

% Indices of every unlabeled datum except k.
keep = [1:k-1, k+1:numel(fu)];

newf = fu(keep) + (y - fk) .* (g_k(keep) ./ g_kk);



%=======================================================
% Expected risk of querying a datum whose current harmonic value is fk.
% k0 and k1 are the unlabeled harmonic values after labeling that datum
% 0 and 1 respectively.  The risk of a labeling v is sum_i min(v_i, 1-v_i);
% the two outcomes are weighted by (1-fk) and fk.
%=======================================================
function r_est = estimated_risk(fk, k0, k1)

risk_if0 = 0;
risk_if1 = 0;
for idx = 1:length(k0)
    % Per-datum risk is the distance to the nearer of the labels 0 and 1.
    risk_if0 = risk_if0 + min(k0(idx), 1 - k0(idx));
    risk_if1 = risk_if1 + min(k1(idx), 1 - k1(idx));
end

r_est = (1 - fk) * risk_if0 + fk * risk_if1;



%=======================================================
% updateInverse(A, Ainv, i)
%
% Given a square matrix A and its inverse Ainv, return the inverse of A
% with row and column i deleted, without re-inverting.  Method from
% [Zhu, Ghahramani and Lafferty 2003], Appendix B.
%
% Row/column i is first moved to the front.  Two rank-one Sherman-Morrison
% updates then follow: the first turns the first row into e_1', and the
% second zeroes the rest of the first column.  The trailing (n-1)x(n-1)
% block of the resulting inverse is the requested one.  Numerical error may
% make the result differ slightly from inverting the reduced matrix
% directly.
%=======================================================
function newInv = updateInverse(A, Ainv, i)

n = size(Ainv, 1);

% Permuting rows and columns identically commutes with inversion, so the
% inverse of the permuted A is the same permutation of Ainv.
P = perm(A, i);
Pinv = perm(Ainv, i);

% e = -e_1
e = zeros(n, 1);
e(1) = -1;

% Row correction: P + e*rowVec' has first row e_1'.
rowVec = P(1, :)';
rowVec(1) = rowVec(1) - 1;

% Column correction: subtracting colVec*e_1' clears column 1 below (1,1).
colVec = P(:, 1);
colVec(1) = 0;

% Sherman-Morrison for the row correction.
rowFixedInv = Pinv - (Pinv * e) * (rowVec' * Pinv) / (1 + rowVec' * Pinv * e);

% Sherman-Morrison for the column correction.
bothFixedInv = rowFixedInv - (rowFixedInv * colVec) * (e' * rowFixedInv) / (1 + e' * rowFixedInv * colVec);

newInv = bothFixedInv(2:n, 2:n);


%=======================================================
% Move row and column k to the front of A.  Rows/columns 1..k-1 shift back
% by one; everything after k keeps its position.
%=======================================================
function permA = perm(A,k)

rowOrder = [k, 1:k-1, k+1:size(A,1)];
colOrder = [k, 1:k-1, k+1:size(A,2)];
permA = A(rowOrder, colOrder);


function write_data(fid,data)
% Append a matrix to an already-open text file.  The record is a '==='
% separator line, the row count, the column count, and then every element
% on its own line in '%f' format, traversed row by row.

nRows = size(data, 1);
nCols = size(data, 2);
fprintf(fid, '===\n%d\n%d\n', nRows, nCols);

for r = 1:nRows
    for c = 1:nCols
        fprintf(fid, '%f\n', data(r, c));
    end
end
