package ve;

/**
 * This is the class of factors that are the result of projecting
 * another factor onto some observations.
 *
 * @author David Poole
 * @version 0.1 2001-01-29 
 * */
public class FactorObserved extends Factor {

    /**
     * domainSizeObserved[i] is the product of the domain sizes of the
     * variables in the i-th block of consecutive observed variables of
     * the underlying factor.
     */
    private final int[] domainSizeObserved;

    /**
     * valsObserved[i] is the mixed-radix index, within the i-th
     * observed block, of the values that were observed.
     */
    private final int[] valsObserved;

    /**
     * domainSizeBetween[i] is the product of the domain sizes of the
     * unobserved variables that precede observed block i (entry
     * numObsBlocks covers those after the last observed block). An
     * empty run has size 1.
     */
    private final int[] domainSizeBetween;

    /** The number of observed blocks found in the underlying factor. */
    private final int numObsBlocks;

    /** The factor being projected onto the observations. */
    private final Factor fact;

    /**
     * Constructs a factor with the given observations.
     *
     * @param f1 the factor that is to be updated
     * @param observedVars the array of variables that have observed
     * values. These are assumed to be in order of increasing
     * getId(). They do not necessarily appear in f1.
     * @param observedVals the array of values for the corresponding
     * observed variables. Each value observedVals[i] is the index into
     * observedVars[i].getDomain().
     **/
    public FactorObserved(Factor f1, Variable[] observedVars, int[] observedVals) {
        super(FactorSumOut.setDiff(f1.getVariables(), observedVars), BY_OBSERVED);
        fact = f1;
        domainSizeObserved = new int[observedVars.length];
        valsObserved = new int[observedVars.length]; // entries start at 0
        domainSizeBetween = new int[observedVars.length + 1];

        // Determine the number and size of the blocks of observed
        // variables and of the blocks between these. We need to
        // construct the tables of the sizes for iterating over.

        // This is closely related to the code in FactorSumOut.java.
        // Any bug fixes here should be checked with that file.
        Variable[] f1Vars = f1.getVariables();
        int varsIndex = 0;
        int f1Index = 0;
        int blocks = 0;
        domainSizeBetween[0] = 1;
        while (f1Index < f1Vars.length) {
            varsIndex = skipAbsentObserved(observedVars, varsIndex, f1Vars, f1Index);

            // variables of f1 that are not observed
            while (f1Index < f1Vars.length
                   && (varsIndex >= observedVars.length
                       || f1Vars[f1Index].getId() < observedVars[varsIndex].getId())) {
                domainSizeBetween[blocks] *= f1Vars[f1Index++].getDomain().length;
                varsIndex = skipAbsentObserved(observedVars, varsIndex, f1Vars, f1Index);
            }

            if (f1Index < f1Vars.length && varsIndex < observedVars.length) {
                // variables of f1 that are observed
                domainSizeObserved[blocks] = 1;
                // Variables are matched by id, like every other
                // comparison in this constructor, rather than by object
                // identity. NOTE(review): assumes getId() returns a
                // primitive that identifies the variable uniquely.
                while (f1Index < f1Vars.length
                       && varsIndex < observedVars.length
                       && observedVars[varsIndex].getId() == f1Vars[f1Index].getId()) {
                    // The block is laid out by f1's own variable.
                    int size = f1Vars[f1Index].getDomain().length;
                    domainSizeObserved[blocks] *= size;
                    valsObserved[blocks] = valsObserved[blocks] * size + observedVals[varsIndex];
                    varsIndex++;
                    f1Index++;
                    varsIndex = skipAbsentObserved(observedVars, varsIndex, f1Vars, f1Index);
                }
                domainSizeBetween[++blocks] = 1;
            }
        }
        numObsBlocks = blocks;
    }

    /**
     * Advances past the observed variables whose id is smaller than
     * the id of the current variable of f1, that is, the observed
     * variables that do not appear in f1.
     *
     * @param observedVars the observed variables
     * @param varsIndex the current index into observedVars
     * @param f1Vars the variables of the underlying factor
     * @param f1Index the current index into f1Vars
     * @return the new index into observedVars; unchanged when f1Index
     * is already past the end of f1Vars
     */
    private static int skipAbsentObserved(Variable[] observedVars, int varsIndex,
                                          Variable[] f1Vars, int f1Index) {
        while (f1Index < f1Vars.length
               && varsIndex < observedVars.length
               && observedVars[varsIndex].getId() < f1Vars[f1Index].getId()) {
            varsIndex++;
        }
        return varsIndex;
    }

    /**
     * Returns an iterator over the values of the factor. This is like
     * the iterator in FactorExpand. When none of the observed
     * variables appear in the underlying factor, that factor's own
     * iterator is returned.
     *
     * @return an iterator over the values of this factor
     */
    public EltsIterator iterator() {
        if (numObsBlocks == 0) {
            return fact.iterator();
        }
        return new Itr();
    }

    /**
     * Iterator that walks the underlying factor, visiting only the
     * entries that agree with the observed values.
     */
    private class Itr implements EltsIterator {

        /**
         * iterator for fact.
         */
        private final EltsIterator factItr;

        /**
         * array of the counts for each digit of domainSizeBetween.
         * currBet[i] starts off with domainSizeBetween[i] and counts
         * down. The last digit counts down to 0; the other digits
         * count down to 1.
         */
        private final int[] currBet;

        Itr() {
            currBet = new int[numObsBlocks + 1];
            System.arraycopy(domainSizeBetween, 0, currBet, 0, numObsBlocks + 1);
            factItr = fact.iterator();
            factItr.backTo(factPos());
        }

        /**
         * Returns true if the last digit still has entries left or if
         * any earlier digit can still be advanced.
         */
        public boolean hasNext() {
            if (currBet[numObsBlocks] > 0) {
                return true;
            }
            for (int i = 0; i < numObsBlocks; i++) {
                if (currBet[i] > 1) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns the next value. While the last digit has entries
         * left, this just reads on from the underlying iterator.
         * Otherwise the nearest earlier digit that can advance is
         * advanced, the digits after it are reset, and the underlying
         * iterator is repositioned.
         */
        public double next() {
            if (currBet[numObsBlocks] > 0) {
                // normal case
                currBet[numObsBlocks]--;
                return factItr.next();
            }
            for (int i = numObsBlocks - 1; i >= 0; i--) {
                currBet[i + 1] = domainSizeBetween[i + 1];
                if (currBet[i] > 1) {
                    currBet[i]--;
                    factItr.backTo(factPos());
                    // can we do something more efficient than that?
                    // for example by maintaining an offset to subtract?
                    currBet[numObsBlocks]--;
                    return factItr.next();
                }
            }
            // Only reached when next() is called although hasNext()
            // is false.
            System.out.println("This shouldnt occur");
            return factItr.next(); // I don't think this should occur
        }

        /**
         * Returns the position of this iterator, computed from the
         * currBet digits with the domainSizeBetween radices.
         */
        public int currPos() {
            int pos = domainSizeBetween[0] - currBet[0];
            for (int i = 1; i <= numObsBlocks; i++) {
                pos = (pos + 1) * domainSizeBetween[i] - currBet[i];
            }
            return pos;
        }

        /**
         * Moves this iterator to the given position, as reported by
         * currPos(), and repositions the underlying iterator to match.
         *
         * @param pos the position to move to
         */
        public void backTo(int pos) {
            for (int i = numObsBlocks; i > 0; i--) {
                currBet[i] = domainSizeBetween[i] - pos % domainSizeBetween[i];
                pos /= domainSizeBetween[i];
            }
            currBet[0] = domainSizeBetween[0] - pos;
            factItr.backTo(factPos());
        }

        /**
         * given the current values of currBet, returns the position
         * of the iterator for fact. The digits of the observed blocks
         * are fixed at valsObserved.
         */
        private int factPos() {
            int pos = domainSizeBetween[0] - currBet[0];
            for (int i = 0; i < numObsBlocks; i++) {
                pos = (pos * domainSizeObserved[i] + valsObserved[i] + 1) * domainSizeBetween[i + 1]
                      - currBet[i + 1];
            }
            return pos;
        }
    }
}
