% Implementation of the gridworld from Sutton & Barto 3.8 for the
% model-based algorithms Value Iteration and Policy Iteration.
%
%   modelDemo('policy') for policy iteration.
%   modelDemo('value') for value iteration.
function modelDemo(algo)
% MODELDEMO  Build the 5x5 gridworld model, solve it, and plot the result.
%
%   modelDemo('policy') calls policyIteration(P,R,gamma).
%   modelDemo('value')  calls valueIteration(P,R,gamma).
%
%   algo  - prefix of the solver to run; the function named
%           [algo 'Iteration'] must be on the path and return [policy, V].
%
%   States are numbered 1..25 row by row starting at the top-left cell,
%   i.e. state = column + (row-1)*5.  P(s,s2,a) is the probability of
%   moving from s to s2 under action a and R(s,s2,a) the matching reward.
%   Actions: 1 = North, 2 = East, 3 = South, 4 = West.

% discount factor handed to the solver
% NOTE(review): the textbook gridworld example is usually quoted with
% gamma = 0.9 -- confirm that 0.5 is intentional.
gamma = 0.5;

% the world is 5x5 with 4 actions from each state; good
% candidate for sparsification
P = zeros(25,25,4);
R = zeros(25,25,4);

% ACTION 1 = North
% moving off the grid produces no motion and incurs R = -1
for s = 1:5
    P(s,s,1) = 1;
    R(s,s,1) = -1;
end
% otherwise state = state - 5
for s = 6:25
    P(s,s-5,1) = 1;
end

% ACTION 2 = EAST
% states in the last column (multiples of 5) bump into the wall
for s = 1:25
    if mod(s,5) == 0
        P(s,s,2) = 1;
        R(s,s,2) = -1;
    else
        P(s,s+1,2) = 1;
    end
end

% ACTION 3 = SOUTH
% moving off the grid produces no motion and incurs R = -1
for s = 21:25
    P(s,s,3) = 1;
    R(s,s,3) = -1;
end
% otherwise state = state + 5
for s = 1:20
    P(s,s+5,3) = 1;
end

% ACTION 4 = WEST
% states in the first column (s mod 5 == 1) bump into the wall
for s = 1:25
    if mod(s,5) == 1
        P(s,s,4) = 1;
        R(s,s,4) = -1;
    else
        P(s,s-1,4) = 1;
    end
end

% every action taken in state 2 gets +10 reward and moves to state 22
P(2,:,:) = 0; % clear old values
R(2,:,:) = 0; % drop the wall penalty set by the North loop
P(2,22,:) = 1;
R(2,22,:) = 10;

% every action taken in state 4 gets +5 reward and moves to state 14
P(4,:,:) = 0; % clear old values
R(4,:,:) = 0; % drop the wall penalty set by the North loop
P(4,14,:) = 1;
R(4,14,:) = 5;

% now call the requested algorithm; feval invokes the solver by name so
% the caller's argument is never executed as arbitrary code
[policy,V] = feval([char(algo),'Iteration'],P,R,gamma);

% displayed on purpose (no trailing semicolon); the transpose makes the
% printed matrix match the row-by-row state layout
TheValueFunction = reshape(V,5,5)'
ThePolicy = reshape(policy,5,5)'

figure(1);
clf;
axis([1 6 1 6]);
hold on;

% dotted grid lines between the cells
for i = 2:5
    plot([1 6],[i i],'k:');
    plot([i i],[1 6],'k:');
end

% colour every cell according to its value; x is the column and y the
% row counted from the top (doFill flips the vertical axis)
for x = 1:5
    for y = 1:5
        X = [x x+1 x+1 x  ];
        Y = [y y   y+1 y+1];
        doFill(X,Y,V(x + (y-1) * 5));
    end
end

% now plot the policy as one arrow per cell, starting at the cell centre
for x = 1:5
    for y = 1:5
        startx = x + 0.5;
        starty = (6-y) + 0.5; % row 1 is drawn at the top of the axes
        destx = startx;
        desty = starty;
        switch policy(x + (y-1)*5)
        case 1
            desty = desty + 0.5;
        case 2
            destx = destx + 0.5;
        case 3
            desty = desty - 0.5;
        case 4
            destx = destx - 0.5;
        end
        % states 2 and 4 teleport regardless of the action, so no arrow
        isSpecial = (y == 1) && (x == 2 || x == 4);
        if ~isSpecial
            plot_arrow(startx, starty, destx, desty);
        end
    end
end



function doFill(X,Y,q)
% DOFILL  Paint one polygon whose colour encodes the sign and size of q.
%
%   X, Y - polygon vertex coordinates; Y is mirrored as 7-Y before
%          drawing.
%   q    - value to visualise: positive values are drawn in green,
%          negative values in red, and the channel intensity is a
%          logistic function of |q|.  Nothing is drawn when q is zero.

mirroredY = 7 - Y;

if q > 0
    greenLevel = 1/(1+exp(-q));
    fill(X,mirroredY,[0 greenLevel 0]);
elseif q < 0
    redLevel = 1/(1+exp(q));
    fill(X,mirroredY,[redLevel 0 0]);
end


