package ve;

/**
 * This is the class of factors that are created by summing out a set
 * of variables from another factor.
 *
 * Copyright 2001, David Poole. All rights reserved.
 *
 * @author David Poole
 * @version 0.1 2001-01-26
 **/
public class FactorSumOut extends FactorStored {

    /**
     * Constructs the factor obtained by summing the variables
     * <code>vars</code> out of <code>f1</code>.
     *
     * The variables of the new factor are
     * <code>setDiff(f1.getVariables(), vars)</code>, which is what gets
     * passed to the <code>FactorStored</code> constructor.
     *
     * The variables of <code>f1</code> are split into alternating runs:
     * a run of kept variables ("between" block), then a run of
     * variables being summed out ("sum-out" block), and so on. For each
     * run only the product of the domain sizes is needed. The values of
     * <code>f1</code> are then read in iterator order and accumulated
     * into <code>factorValues</code>; whenever a kept run has been
     * traversed and the enclosing sum-out run moves to its next value,
     * the write position is rewound so the same output cells are
     * accumulated into again.
     *
     * Both arrays are walked merge-style by <code>getId()</code>, so
     * both are expected to be ordered by increasing id. Variables are
     * matched by object identity (<code>==</code>), not by id.
     *
     * NOTE(review): this presumably relies on <code>f1.iterator()</code>
     * enumerating values with the last variable of
     * <code>f1.getVariables()</code> varying fastest, and on
     * <code>factorValues</code> being allocated and zero-filled by the
     * superclass constructor (the summing branches use <code>+=</code>)
     * -- confirm against <code>EltsIterator</code> and
     * <code>FactorStored</code>.
     *
     * @param f1 the original factor.
     * @param vars the list of variables to be summed out. This is
     * documented as a subset of f1.getVariables(); elements of vars
     * whose id does not occur in f1 are skipped over by the scan.
     **/
    public FactorSumOut(Factor f1, Variable[] vars) {
        super(setDiff(f1.getVariables(), vars), BY_SUM_OUT);
        if (getSavingForTracing()) {
            theFactor = f1;
            theVariables = vars;
        }

        // Fetch the variable list once instead of on every loop test.
        final Variable[] f1Vars = f1.getVariables();

        // valsSumOut[k]  : product of domain sizes of the k-th sum-out block
        // valsBetween[k] : product of domain sizes of the kept variables
        //                  that precede the k-th sum-out block (the last
        //                  entry is for the kept variables after the
        //                  final sum-out block)
        int[] valsSumOut = new int[vars.length];
        int[] valsBetween = new int[vars.length + 1];
        int varsIndex = 0;
        int f1Index = 0;
        int numExpBlocks = 0;
        valsBetween[0] = 1;

        // determine the number and size of blocks of variables to be
        // summed out and the blocks between these.
        while (f1Index < f1Vars.length) {
            varsIndex = skipAbsent(vars, varsIndex, f1Vars, f1Index);

            // variables not summed out
            while (f1Index < f1Vars.length
                   && (varsIndex >= vars.length
                       || f1Vars[f1Index].getId() < vars[varsIndex].getId())) {
                valsBetween[numExpBlocks] *= f1Vars[f1Index++].getDomain().length;
                varsIndex = skipAbsent(vars, varsIndex, f1Vars, f1Index);
            }

            if (varsIndex < vars.length) {
                // variables to be summed out
                valsSumOut[numExpBlocks] = 1;
                while (f1Index < f1Vars.length
                       && varsIndex < vars.length
                       && vars[varsIndex] == f1Vars[f1Index]) {
                    valsSumOut[numExpBlocks] *= vars[varsIndex++].getDomain().length;
                    f1Index++;
                    varsIndex = skipAbsent(vars, varsIndex, f1Vars, f1Index);
                }
                valsBetween[++numExpBlocks] = 1;
            }
        }

        // put the sum in the right place
        EltsIterator f1Iter = f1.iterator();
        int curpos = 0;
        if (numExpBlocks == 0) {
            // not summing out any variables: plain copy
            while (f1Iter.hasNext()) {
                factorValues[curpos++] = f1Iter.next();
            }
            return;
        }

        // Running counters for each block, and for each sum-out block
        // the distance the write position must be rewound when that
        // block advances: the product of the sizes of all kept blocks
        // to its right.
        int[] countSumOut = new int[numExpBlocks];
        int[] countBetween = new int[numExpBlocks + 1];
        int[] offsets = new int[numExpBlocks];
        offsets[numExpBlocks - 1] = valsBetween[numExpBlocks];
        for (int i = numExpBlocks - 1; i > 0; i--) {
            offsets[i - 1] = offsets[i] * valsBetween[i];
        }

        if (valsBetween[numExpBlocks] == 1) {
            // Nothing of size > 1 follows the last sum-out block (we are
            // summing out the right-most variable(s)), so that block is
            // a run of consecutive input values: add them up locally
            // and write one output cell per run.
            final int last = numExpBlocks - 1;
            final int innerMostLoopSize = valsSumOut[last];
            while (f1Iter.hasNext()) {
                double sum = 0;
                for (int i = 0; i < innerMostLoopSize; i++) {
                    sum += f1Iter.next();
                }
                factorValues[curpos++] += sum;
                if (++countBetween[last] == valsBetween[last]) {
                    countBetween[last] = 0;
                    curpos = advanceCounters(last - 1, curpos,
                                             countSumOut, valsSumOut,
                                             countBetween, valsBetween,
                                             offsets);
                }
            }
        }
        else {
            // We are not summing out the right-most variable: input
            // values map one-to-one onto output cells within the
            // trailing kept block.
            while (f1Iter.hasNext()) {
                factorValues[curpos++] += f1Iter.next();
                if (++countBetween[numExpBlocks] == valsBetween[numExpBlocks]) {
                    countBetween[numExpBlocks] = 0;
                    curpos = advanceCounters(numExpBlocks - 1, curpos,
                                             countSumOut, valsSumOut,
                                             countBetween, valsBetween,
                                             offsets);
                }
            }
        }
    }

    /**
     * Advances varsIndex past the variables to be summed out whose id
     * is smaller than that of f1Vars[f1Index], i.e. variables that do
     * not appear at this point of f1's variable list.
     *
     * @return the new value of varsIndex (unchanged when f1Index is
     * already past the end of f1Vars).
     */
    private static int skipAbsent(Variable[] vars, int varsIndex,
                                  Variable[] f1Vars, int f1Index) {
        while (f1Index < f1Vars.length
               && varsIndex < vars.length
               && vars[varsIndex].getId() < f1Vars[f1Index].getId()) {
            varsIndex++;
        }
        return varsIndex;
    }

    /**
     * Odometer-style carry over the block counters, starting at sum-out
     * block <code>startBlock</code> and moving left.
     *
     * If a sum-out block can still advance, the write position is
     * rewound by that block's offset and the carry stops. Otherwise the
     * block's counter wraps to 0 and the kept block before it is
     * advanced; if that one wraps too, the carry continues to the left.
     *
     * @return the (possibly rewound) write position.
     */
    private static int advanceCounters(int startBlock, int curpos,
                                       int[] countSumOut, int[] valsSumOut,
                                       int[] countBetween, int[] valsBetween,
                                       int[] offsets) {
        for (int j = startBlock; j >= 0; j--) {
            if (++countSumOut[j] < valsSumOut[j]) {
                return curpos - offsets[j];
            }
            countSumOut[j] = 0;
            if (++countBetween[j] < valsBetween[j]) {
                return curpos;
            }
            countBetween[j] = 0;
        }
        return curpos;
    }

    /**
     * returns the array representing the set difference of the variables.
     *
     * Both arrays are scanned in a single merge pass ordered by
     * <code>getId()</code>; an element of vars1 is dropped only when
     * the very same object (<code>==</code>) is met in vars2.
     *
     * @param vars1 an ordered list of variables.
     * @param vars2 an ordered list of variables.
     * @return a new array with the elements of vars1 not in vars2, in
     * their original order.
     */
    public static Variable[] setDiff(Variable[] vars1, Variable[] vars2) {
        Variable[] vars = new Variable[vars1.length];
        int pos = 0;  // number of elements currently in the difference
        int i1 = 0;   // next position in vars1
        int i2 = 0;   // next position in vars2
        while (i1 < vars1.length && i2 < vars2.length) {
            if (vars1[i1] == vars2[i2]) {
                i1++;
                i2++;
            }
            else if (vars1[i1].getId() < vars2[i2].getId()) {
                vars[pos++] = vars1[i1++];
            }
            else {
                i2++;
            }
        }

        // vars2 is exhausted: everything left in vars1 is kept.
        int remaining = vars1.length - i1;
        System.arraycopy(vars1, i1, vars, pos, remaining);
        pos += remaining;

        // trim to the number of elements actually kept
        Variable[] result = new Variable[pos];
        System.arraycopy(vars, 0, result, 0, pos);
        return result;
    }


    /**
     * the factor that the variable is being summed out from; only set
     * when getSavingForTracing() was true at construction time.
     **/
    private Factor theFactor;

    /**
     * returns the factor that the variable is being summed out from,
     * or null if it was not saved (getSavingForTracing() was false
     * when this factor was constructed).
     **/
    public Factor getTheFactor() {
        return theFactor;
    }

    /**
     * the variables being summed out; only set when
     * getSavingForTracing() was true at construction time.
     **/
    private Variable[] theVariables;

    /**
     * returns the variables being summed out, or null if they were not
     * saved (getSavingForTracing() was false when this factor was
     * constructed). This is the array that was passed to the
     * constructor, not a copy.
     **/
    public Variable[] getTheVariables() {
        return theVariables;
    }

}
