% Computational Intelligence: a logical approach. 
% Prolog Code.
% BELIEF NETWORK INTERPRETER (Revised)
% Copyright (c) 1998, Poole, Mackworth, Goebel and Oxford University Press.

% This is much more efficient than the code in appendix C.

% A belief network is represented with the relations
% variables(Xs) Xs is the list of random variables.
%   Xs is ordered: parents of node are before the node.
% parents(X,Ps) Ps list of parents of variable X.
%   Ps is ordered consistently with Xs
% values(X,Vs) Vs is the list of values of X
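% pr(X,As,D) X is a variable, As is a list of Pi=Vi where
%   Pi is a parent of X, and Vi is a value for variable Pi
%   The elements of As are ordered consistently with Ps.
%   D is the distribution of X given As: a list with one
%   number per value of X, ordered as in values(X,Vs).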


% p(Var,Obs,VDist) is true if VDist is the normalized
% distribution P(Var|Obs): a list with one number per value
% of Var.  Obs is a list of Vari=Vali pairs and must not
% mention Var itself.
% The query is answered in six steps: choose the relevant
% variables, fix the elimination ordering, build the factors,
% sum out every variable of the ordering, multiply together
% whatever factors are left, and normalize the result.
p(Var,Obs,VDist) :-
   relevant(Var,Obs,Relevant),
   to_sum_out(Relevant,Var,Obs,ElimOrder),
   joint(Relevant,Obs,Factors),
   sum_out_each(ElimOrder,Factors,Remaining),
   collect(Remaining,Unnormalized),
   normalize(Unnormalized,0,_,VDist).

% relevant(Var,Obs,RelVars): RelVars is the list of variables
% taken into account when answering the query Var under the
% observations Obs.
% This version prunes nothing: whatever Var and Obs are, it
% returns every variable listed by variables/1.
relevant(_Var,_Obs,RelVars) :-
   variables(RelVars).

% to_sum_out(Vs,Var,Obs,SO): SO is the elimination ordering,
% i.e. the variables of Vs that have to be summed out to
% answer the query Var given the observations Obs.
% SO is Vs with Var and every observed variable taken out.
% The order of Vs is kept, so the elimination ordering is,
% naively, just the variable ordering.
to_sum_out(Vs,Var,Obs,SO) :-
   remove(Var,Vs,WithoutQuery),
   remove_each_obs(Obs,WithoutQuery,SO).

% remove_each_obs(Obs,Vs,SO): SO is the list Vs after the
% variable of every Var=Val pair in Obs has been taken out.
% Observed variables that do not occur in Vs are ignored.
remove_each_obs([],Vs,Vs).
remove_each_obs([Observed=_|Rest],Vs,SO) :-
   remove_if_present(Observed,Vs,Fewer),
   remove_each_obs(Rest,Fewer,SO).

/* A joint probability distribution is represented
as a list of distribution trees, of the form
         dtree(Vars,DTree) 
where Vars is a list of Variables (ordered
consistently with the ordering of variables), and
DTree is a tree representation of the function from
values of variables into numbers such that if
Vars=[] then DTree is a number. Otherwise
Vars=[Var|RVars], and DTree is a list with one
element for each value of Var, and each element
is a tree representation for RVars. The ordering
of the elements in DTree is given by the ordering
of Vals given by values(Var,Vals). */

% joint(Vs,Obs,Joint): Joint is the list of factors (dtrees)
% for the variables Vs, already restricted to the observation
% list Obs.
% Every variable X of Vs contributes the factor built from
% pr/3 for X and its parents.  The variables of that factor
% are X and its parents minus the observed ones; when none
% is left (X and all of its parents are observed) the
% variable contributes no factor at all.
joint([],_,[]).
joint([X|Rest],Obs,[dtree(FreeVars,Tree)|Factors]) :-
   parents(X,Parents),
   make_dvars(Parents,X,Obs,FreeVars),
   FreeVars \== [], !,
   make_dtree(Parents,X,Obs,[],Tree),
   joint(Rest,Obs,Factors).
joint([_|Rest],Obs,Factors) :-
        % no unobserved variable in this factor: leave it out
   joint(Rest,Obs,Factors).

% make_dvars(PX,X,Obs,DVars): DVars lists the unobserved
% variables among the parents PX of X followed by X itself,
% that is  DVars = (PX U {X}) - observed variables.
% X is put after its parents, so DVars follows the variable
% ordering only if PX is ordered before X.
% The clauses without a membership test rely on the cut in
% the clause just before them.
make_dvars([],X,Obs,[]) :-
   member(X=_,Obs),!.
make_dvars([],X,_,[X]).
make_dvars([Par|Pars],X,Obs,Free) :-
   member(Par=_,Obs),!,
   make_dvars(Pars,X,Obs,Free).
make_dvars([Par|Pars],X,Obs,[Par|Free]) :-
   % only reached when Par is not observed
   make_dvars(Pars,X,Obs,Free).

% make_dtree(RP,X,Obs,Con,DTree): DTree is the tree for the
% factor p(X|PX) restricted by the observations Obs.
% RP holds the parents of X that are still to be handled and
% Con the Parent=Value assignments made so far, most recent
% first; Con is reversed before pr/3 is consulted.
% Every unobserved parent in RP adds one level to DTree with
% one subtree per value of that parent, while an observed
% parent only extends Con.  Once RP is empty pr/3 supplies
% the distribution of X; if X is observed only the entry for
% the observed value is kept.
make_dtree([],X,Obs,Con,Leaf) :-
   member(X=Seen,Obs),!,
   reverse(Con,Assignment),
   pr(X,Assignment,Probs),
   values(X,Vals),
   select_corresp_elt(Vals,Seen,Probs,Leaf).
make_dtree([],X,_,Con,Probs) :-
   reverse(Con,Assignment),
   pr(X,Assignment,Probs).
make_dtree([Par|Pars],X,Obs,Con,Tree) :-
   member(Par=Seen,Obs),!,
   make_dtree(Pars,X,Obs,[Par=Seen|Con],Tree).
make_dtree([Par|Pars],X,Obs,Con,Subtrees) :-
   values(Par,Vals),
   make_dtree_for_vals(Vals,Par,Pars,X,Obs,Con,Subtrees).

% make_dtree_for_vals(Vals,P,RP,X,Obs,Con,DX):
% DX is the list of subtrees obtained by calling make_dtree/5
% once for every value in Vals, each time with P bound to
% that value in the context.  The subtrees are in the order
% of Vals.  The other arguments are as for make_dtree/5.
make_dtree_for_vals([],_,_,_,_,_,[]).
make_dtree_for_vals([V|Vs],P,RP,X,Obs,Con,[Sub|Subs]):-
   make_dtree(RP,X,Obs,[P=V|Con],Sub),
   make_dtree_for_vals(Vs,P,RP,X,Obs,Con,Subs).

% select_corresp_elt(Vals,Val,List,Elt) is true if Elt sits
% at the position in List at which Val first occurs in Vals.
% Vals, Val and List are assumed to be bound.  It fails if
% Val is not in Vals or List is too short.
select_corresp_elt([Val|_],Val,[Found|_],Found) :-
   !.
select_corresp_elt([_|MoreVals],Val,[_|MoreElts],Found) :-
   select_corresp_elt(MoreVals,Val,MoreElts,Found).

% sum_out_each(SO,Joint0,Joint1) is true if Joint1 is the
% list of factors Joint0 after summing out the variables of
% SO one after the other, in the order given by SO.
sum_out_each([],Joint,Joint).
sum_out_each([V|Vs],Before,After) :-
   sum_out(V,Before,Between),
   sum_out_each(Vs,Between,After).

% sum_out(V,J0,J1) is true if J1 is the list of factors J0
% with variable V summed out.
% The factors that mention V are multiplied into a single
% table, V is summed out of that table, and the new factor
% is put in front of the factors that do not mention V.
% It fails if no factor in J0 mentions V.
sum_out(X,Factors,[dtree(SumVars,SumTree)|Untouched]) :-
   partition(Factors,X,Untouched,Mentioning),
   mult_tables(Mentioning,ProdVars,ProdTree),
   sum_from_table(X,ProdVars,ProdTree,SumVars,SumTree).

% partition(J0,X,NoX,SomeX) splits the list of dtrees J0
% into SomeX, the dtrees whose variable list contains X, and
% NoX, the dtrees whose variable list does not.  Both lists
% keep the order the dtrees have in J0.
partition([],_,[],[]).
partition([dtree(Vars,Tree)|Rest],X,Without,[dtree(Vars,Tree)|With]) :-
   member(X,Vars),
   !,
   partition(Rest,X,Without,With).
partition([dtree(Vars,Tree)|Rest],X,[dtree(Vars,Tree)|Without],With) :-
   partition(Rest,X,Without,With).


% collect(Dist,DT) multiplies all of the factors in Dist
% together, forming the single DTree DT.
% Every factor is expected to be either over just the query
% variable (its tree is a list with one number per value) or
% over no variable at all (its tree is a bare number).  The
% second kind is what sum_out/3 leaves behind when the
% variable it eliminated was the only variable of its factor,
% e.g. an unobserved variable all of whose neighbours are
% observed.  Such a constant scales every entry of DT.
% Fails on the empty list.  If every factor in Dist is a
% constant, DT is a bare number rather than a list.
collect([dtree(_,DT)],DT) :- !.
collect([dtree(_,DT0)|R],DT2) :-
   collect(R,DT1),
   collect_product(DT0,DT1,DT2).

% collect_product(T0,T1,T2) T2 is the product of T0 and T1,
% each of which is a number or a list of numbers.  Two lists
% are multiplied element by element; a number scales a list.
collect_product(N0,N1,N2) :-
   number(N0), number(N1), !,
   N2 is N0*N1.
collect_product(N,L,SL) :-
   number(N), !,
   collect_scale(L,N,SL).
collect_product(L,N,SL) :-
   number(N), !,
   collect_scale(L,N,SL).
collect_product(L0,L1,L2) :-
   multiply_corresp_elts(L0,L1,L2).

% collect_scale(L,N,SL) SL is the list L with every element
% multiplied by the number N.
collect_scale([],_,[]).
collect_scale([E|Es],N,[SE|SEs]) :-
   SE is E*N,
   collect_scale(Es,N,SEs).

% multiply_corresp_elts(DT0,DT1,DT2) DT2 is the element by
% element product of the number lists DT0 and DT1.
% It fails if the two lists differ in length.
multiply_corresp_elts([],[],[]).
multiply_corresp_elts([X|Xs],[Y|Ys],[XY|XYs]) :-
   XY is X*Y,
   multiply_corresp_elts(Xs,Ys,XYs).

% normalize(List,CumVal,Sum,NList) makes NList the same as
% List, but with every element divided by Sum.
% Sum is CumVal plus the sum of all elements of List, so
% with CumVal = 0 the elements of NList sum to 1.
% The work is done in two passes: the first one computes
% Sum, the second one divides.
normalize(List,CumVal,Sum,NList) :-
   normalize_total(List,CumVal,Sum),
   normalize_divide(List,Sum,NList).

% normalize_total(List,Acc,Sum) Sum is Acc plus the sum of
% the elements of List, added from left to right.
normalize_total([],Sum,Sum).
normalize_total([A|L],Acc,Sum) :-
   Acc1 is Acc + A,
   normalize_total(L,Acc1,Sum).

% normalize_divide(List,Sum,NList) NList is List with each
% element divided by Sum.
normalize_divide([],_,[]).
normalize_divide([A|L],Sum,[AN|LN]) :-
   AN is A/Sum,
   normalize_divide(L,Sum,LN).

%  ordered_union(L0,L1,R,RL) is true if R = L0 U L1, where the
%  reference list RL fixes the order of the elements: L0 and L1
%  are expected to be ordered as in RL, and R comes out in that
%  order as well.  L0, L1 and RL must all be bound.
%  RL is walked from the front; its head is copied to R when it
%  heads L0, L1 or both, and is skipped otherwise.  As soon as
%  one of L0 and L1 is empty the other one is the rest of R.
ordered_union([],Ys,Ys,_) :- !.
ordered_union(Xs,[],Xs,_) :- !.
ordered_union([E|Xs],[E|Ys],[E|Us],[E|Ref]) :-
   !,
   ordered_union(Xs,Ys,Us,Ref).
ordered_union([E|Xs],Ys,[E|Us],[E|Ref]) :-
   !,
   ordered_union(Xs,Ys,Us,Ref).
ordered_union(Xs,[E|Ys],[E|Us],[E|Ref]) :-
   !,
   ordered_union(Xs,Ys,Us,Ref).
ordered_union(Xs,Ys,Us,[_|Ref]) :-
   !,
   ordered_union(Xs,Ys,Us,Ref).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%  SUMMING A VARIABLE FROM A TABLE
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%sum_from_table(V,T0Vs,T0,T1Vs,T1)
%  T1 is the table that results from summing V out of table T0.
%  T0Vs is the list of variables of T0 and T1Vs the list of
%  variables of T1, i.e. T0Vs without V.
%  If V is the first variable of T0 its subtrees are added
%  together; otherwise V is summed out of every subtree.
%  Fails if V is not in T0Vs.
sum_from_table(V,[V|Rest],[First|Others],Rest,Sum) :- !,
   add_corresponding_elts(First,Others,Sum).
sum_from_table(V,[Top|Below],Subtrees,[Top|BelowOut],Summed) :- !,
   sum_for_each_value(Subtrees,V,Below,BelowOut,Summed).

%%add_corresponding_elts(Sum,T0,T1)
%%  T1 is the tree Sum with every tree in the list T0 added to
%%  it, position by position.  All trees must have one shape.
add_corresponding_elts(Total,[],Total).
add_corresponding_elts(Acc,[Tree|Trees],Total) :-
	sum_two_trees(Acc,Tree,Acc1),
	add_corresponding_elts(Acc1,Trees,Total).

%%sum_two_trees(T0,T1,T2)
%%  T2 is the position by position sum of the trees T0 and T1.
%%  A tree is either a list of trees or, at the leaves, an
%%  arithmetic value; T0 and T1 must have the same shape.
sum_two_trees([],[],[]) :-!.
sum_two_trees([A|As],[B|Bs],[S|Ss]) :-!,
	sum_two_trees(A,B,S),
	sum_two_trees(As,Bs,Ss).
sum_two_trees(X,Y,Z) :-
	% neither a pair of lists nor a pair of []: add as numbers
	Z is X+Y.

%% sum_for_each_value(T0,V,TVs,T1Vs,T1).
%%  T0 is a list of trees, each one over the variables TVs.
%%  T1 is the list obtained by summing V out of every element
%%  of T0 with sum_from_table/5, in the same order.
%%  T1Vs is the list of variables of the elements of T1.
sum_for_each_value([],_,_,_,[]).
sum_for_each_value([Tree|Trees],V,TVs,T1Vs,[Summed|Rest]) :-
	sum_from_table(V,TVs,Tree,T1Vs,Summed),
	sum_for_each_value(Trees,V,TVs,T1Vs,Rest).



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%     MULTIPLYING TABLES
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%mult_tables(Tables,TVs,T)
%  T is the product of all the dtrees in the list Tables and
%  TVs is the list of variables of T.  Fails if Tables is empty.
%  The first table seeds an accumulator into which the other
%  tables are multiplied one at a time by mult_tables_acc/5.
mult_tables([dtree(Vs0,Tab0)|Rest],TVs,T) :-
	mult_tables_acc(Rest,Vs0,Tab0,TVs,T).
mult_tables_acc([],AccVs,Acc,AccVs,Acc).
mult_tables_acc([dtree(Vs,Tab)|Rest],AccVs,Acc,TVs,T) :-
	mult_2tables(Vs,Tab,AccVs,Acc,NextVs,Next),
	mult_tables_acc(Rest,NextVs,Next,TVs,T).

%mult_2tables(T1Vs,T1,T2Vs,T2,T1T2Vs,T1T2)
% +  T1 and T2 are the two tables to multiply
% +  T1Vs is the ordered list of variables of T1
% +  T2Vs is the ordered list of variables of T2
% -  T1T2Vs is T1Vs merged with T2Vs by merge_vars/3
% -  T1T2 is the product of T1 and T2, a table over T1T2Vs
mult_2tables(Vs1,Tab1,Vs2,Tab2,Vs12,Tab12) :-
	merge_vars(Vs1,Vs2,Vs12),
	mult_tables6(Vs1,Tab1,Vs2,Tab2,Vs12,Tab12).

%mult_tables6(T1Vs,T1,T2Vs,T2,T1T2Vs,T1T2)
% T1T2 is the product of table T1 (over T1Vs) and table T2
% (over T2Vs), laid out over the variable list T1T2Vs, which
% is passed in already merged.  The clause is selected by
% where the first variable of T1T2Vs occurs.
mult_tables6([],N1,[],N2,[],Product) :-
	% no variable left: both tables are numbers
	Product is N1*N2.
mult_tables6([V|Vs1],Tab1,[V|Vs2],Tab2,[V|Vs12],Tab12) :- !,
	% the smallest variable is in both T1 and T2
	mult_tables_each_pair_of_values(Tab1,Vs1,Tab2,Vs2,Tab12,Vs12).
mult_tables6([V|Vs1],Tab1,Vs2,Tab2,[V|Vs12],Tab12) :- !,
	% the smallest variable is in T1 only
	mult_tables_each_value(Tab1,Vs1,Tab2,Vs2,Tab12,Vs12).
mult_tables6(Vs1,Tab1,[V|Vs2],Tab2,[V|Vs12],Tab12) :- !,
	% the smallest variable is in T2 only: swap the roles
	mult_tables_each_value(Tab2,Vs2,Tab1,Vs1,Tab12,Vs12).

        
% mult_tables_each_pair_of_values(T1,T1Vs,T2,T2Vs,T1T2,T1T2Vs)
%  T1 and T2 are lists of subtrees over T1Vs and T2Vs.  T1T2
%  is the list of products of the subtrees at equal positions,
%  each product being a table over T1T2Vs.
mult_tables_each_pair_of_values([],_,[],_,[],_).
mult_tables_each_pair_of_values([Sub1|Subs1],Vs1,[Sub2|Subs2],Vs2,[Prod|Prods],Vs12) :-
	mult_tables6(Vs1,Sub1,Vs2,Sub2,Vs12,Prod),
	mult_tables_each_pair_of_values(Subs1,Vs1,Subs2,Vs2,Prods,Vs12).

% mult_tables_each_value(T1,T1Vs,T2,T2Vs,T1T2,T1T2Vs)
%  T1 is a list of subtrees over T1Vs and T2 one table over
%  T2Vs.  T1T2 is the list obtained by multiplying each
%  subtree of T1 by the whole of T2, each product being a
%  table over T1T2Vs.
mult_tables_each_value([],_,_,_,[],_).
mult_tables_each_value([Sub|Subs],Vs1,Tab2,Vs2,[Prod|Prods],Vs12) :-
	mult_tables6(Vs1,Sub,Vs2,Tab2,Vs12,Prod),
	mult_tables_each_value(Subs,Vs1,Tab2,Vs2,Prods,Vs12).

% merge_vars(T1Vs,T2Vs,T1T2Vs)
%  T1T2Vs is the union of the variable lists T1Vs and T2Vs,
%  ordered like the list of all variables given by variables/1.
merge_vars(Vs1,Vs2,Merged) :-
	variables(AllVars),
	merge_vars4(AllVars,Vs1,Vs2,Merged).

% merge_vars4(Vars,T1Vs,T2Vs,T1T2Vs)
%  T1T2Vs is the union of T1Vs and T2Vs in the order of the
%  reference list Vars; T1Vs and T2Vs are expected to be
%  ordered as in Vars.  Vars is walked from the front: its
%  head goes into T1T2Vs when it heads T1Vs, T2Vs or both and
%  is skipped otherwise.  As soon as one of T1Vs and T2Vs is
%  empty the other one is the rest of the result.
merge_vars4(_,[],Vs,Vs) :- !.
merge_vars4(_,Vs,[],Vs) :- !.
merge_vars4([V|Ref],[V|As],[V|Bs],[V|Ms]) :- !,
	merge_vars4(Ref,As,Bs,Ms).
merge_vars4([V|Ref],As,[V|Bs],[V|Ms]) :- !,
	merge_vars4(Ref,As,Bs,Ms).
merge_vars4([V|Ref],[V|As],Bs,[V|Ms]) :- !,
	merge_vars4(Ref,As,Bs,Ms).
merge_vars4([_|Ref],As,Bs,Ms) :- !,
	merge_vars4(Ref,As,Bs,Ms).



% STANDARD DEFINITIONS
% reverse(L,R) is true if R has the elements of list L in
% the opposite order.
% rev(L,Acc,R) does the work: the elements of L are pushed
% one by one onto the accumulator Acc, which becomes R once
% L is exhausted.
reverse(List,Reversed) :-
   rev(List,[],Reversed).
rev([],Done,Done).
rev([X|Xs],SoFar,Reversed) :-
   rev(Xs,[X|SoFar],Reversed).

% remove(E,L,R) is true if R is the list L with one
% occurrence of E taken out.  It fails if E does not occur
% in L; on backtracking each occurrence is removed in turn.
remove(X,[X|Tail],Tail).
remove(X,[Y|Tail],[Y|Rest]) :-
   remove(X,Tail,Rest).

% remove_if_present(E,L,R) is true if R is the list L with
% the first occurrence of E taken out, or L itself if E
% does not occur in L.  Unlike remove/3 it does not fail
% when E is missing.
remove_if_present(_,[],[]).
remove_if_present(X,[X|Tail],Tail) :- !.
remove_if_present(X,[Y|Tail],[Y|Rest]) :-
   remove_if_present(X,Tail,Rest).

% member(E,L) is true if E is a member of list L.
% On backtracking the elements of L are tried from left to
% right.
member(X,[X|_]).
member(X,[_|Tail]) :-
   member(X,Tail).
