/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.RegtreeFwd;
import ca.ubc.cs.beta.models.fastrf.utils.Utils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.Vector;

/**
 * A single regression tree stored as flat, node-indexed arrays.
 *
 * <p>Node 0 is the root. For node {@code i}, {@code var[i] == 0} marks a leaf; a positive
 * {@code var[i]} is a continuous split on dimension {@code var[i] - 1} with threshold
 * {@code cut[i]}; a negative {@code var[i]} is a categorical split on dimension
 * {@code -var[i] - 1} whose value routing is {@code catsplit[(int) cut[i]]}.
 */
public strictfp class Regtree
implements Serializable {
    private static final long serialVersionUID = -7861532246973394125L;
    /** Number of nodes; all per-node arrays are indexed {@code 0..numNodes-1}, node 0 is the root. */
    public int numNodes;
    public int[] node;
    /** Parent index of every node. */
    public int[] parent;
    /**
     * Per-node response data. If {@link #resultsStoredInLeaves} is true, each leaf holds its raw
     * responses (null when empty); otherwise each node holds {@code {sum, sumOfSquares}}.
     */
    public double[][] ysub;
    /** Split variable per node: 0 = leaf, &gt; 0 = continuous dim {@code var-1}, &lt; 0 = categorical dim {@code -var-1}. */
    public int[] var;
    /** Continuous split threshold, or the {@link #catsplit} row index for categorical splits. */
    public double[] cut;
    /** {@code children[i][0]} is the left child and {@code children[i][1]} the right child of node i. */
    public int[][] children;
    /** Number of training responses routed through each node. */
    public int[] nodesize;
    public int npred;
    /** Categorical routing tables: {@code catsplit[c][value]} is 0 (go left) or 1 (go right). */
    public int[][] catsplit;
    /** Mean response per leaf, filled by {@link #recalculateStats()}. */
    public double[] nodepred;
    /** Response variance per leaf, filled by {@link #recalculateStats()}. */
    public double[] nodevar;
    public boolean resultsStoredInLeaves;
    public boolean preprocessed;
    public double[] weightedpred;
    public double[] weightedvar;
    /** Per-node weights used by the marginal / variance-decomposition computations. */
    public double[] weights;
    public boolean preprocessed_for_classification;
    public double[][] bestClasses;
    /** Per-leaf lower bound of every continuous dimension (set by {@link #precomputeLeafInfo}). */
    public double[][] leafContLB;
    /** Per-leaf upper bound of every continuous dimension (set by {@link #precomputeLeafInfo}). */
    public double[][] leafContUB;
    /** Per-leaf set of admissible values of every categorical dimension. */
    public Set<Integer>[][] leafallCatValues;
    /** Per-leaf, per-dimension fraction of the domain covered by the leaf. */
    public double[][] leafDomainPercentage;
    /** Indices of all leaves, collected by {@link #precomputeLeafInfo}. */
    public Vector<Integer> leafIndices;
    public boolean leafInfoIsPrecomputed = false;
    public boolean[] isCatDimension;
    public int[] categoricalDomainSizes;
    /** When &gt; 0, {@link #update} maps incoming responses y to 10^y before storing them. */
    public int logModel;

    /**
     * Creates an empty tree shell without allocating node arrays.
     *
     * @param numNodes number of nodes
     * @param logModel log-model flag, see {@link #logModel}
     */
    public Regtree(int numNodes, int logModel) {
        this.numNodes = numNodes;
        this.logModel = logModel;
        this.preprocessed = false;
        this.preprocessed_for_classification = false;
    }

    /**
     * Creates a tree with all per-node arrays allocated.
     *
     * @param numNodes             number of nodes
     * @param ncatsplit            number of categorical routing tables
     * @param storeResultsInLeaves whether leaves keep raw responses (otherwise sum / sum of squares)
     * @param logModel             log-model flag, see {@link #logModel}
     */
    public Regtree(int numNodes, int ncatsplit, boolean storeResultsInLeaves, int logModel) {
        this(numNodes, logModel);
        this.resultsStoredInLeaves = storeResultsInLeaves;
        this.node = new int[numNodes];
        this.parent = new int[numNodes];
        this.var = new int[numNodes];
        this.cut = new double[numNodes];
        this.children = new int[numNodes][2];
        this.nodesize = new int[numNodes];
        this.catsplit = new int[ncatsplit][];
        // Raw-response rows are allocated lazily per leaf; otherwise each node holds {sum, sumSq}.
        this.ysub = this.resultsStoredInLeaves ? new double[numNodes][] : new double[numNodes][2];
    }

    /**
     * Deep-copies the structure, responses and (if present) preprocessing weights of {@code t},
     * then recomputes the per-leaf statistics. Classification and leaf-info caches are not copied.
     *
     * @param t tree to copy
     */
    public Regtree(Regtree t) {
        this(t.numNodes, t.catsplit.length, t.resultsStoredInLeaves, t.logModel);
        this.npred = t.npred;
        System.arraycopy(t.node, 0, this.node, 0, this.numNodes);
        System.arraycopy(t.parent, 0, this.parent, 0, this.numNodes);
        System.arraycopy(t.var, 0, this.var, 0, this.numNodes);
        System.arraycopy(t.cut, 0, this.cut, 0, this.numNodes);
        System.arraycopy(t.nodesize, 0, this.nodesize, 0, this.numNodes);
        for (int i = 0; i < t.catsplit.length; ++i) {
            this.catsplit[i] = t.catsplit[i].clone();
        }
        for (int i = 0; i < this.numNodes; ++i) {
            this.children[i][0] = t.children[i][0];
            this.children[i][1] = t.children[i][1];
            if (!this.resultsStoredInLeaves) {
                System.arraycopy(t.ysub[i], 0, this.ysub[i], 0, t.ysub[i].length);
                continue;
            }
            // Only non-empty leaves (var == 0) carry stored responses; other entries stay null.
            int nnode = this.var[i] == 0 ? this.nodesize[i] : 0;
            if (nnode != 0) {
                this.ysub[i] = new double[nnode];
                System.arraycopy(t.ysub[i], 0, this.ysub[i], 0, nnode);
            }
        }
        this.preprocessed = t.preprocessed;
        if (this.preprocessed) {
            this.weights = new double[this.numNodes];
            this.weightedpred = new double[this.numNodes];
            this.weightedvar = new double[this.numNodes];
            System.arraycopy(t.weights, 0, this.weights, 0, this.numNodes);
            System.arraycopy(t.weightedpred, 0, this.weightedpred, 0, this.numNodes);
            System.arraycopy(t.weightedvar, 0, this.weightedvar, 0, this.numNodes);
        }
        this.recalculateStats();
    }

    /**
     * Predicts every row of {@code X}.
     *
     * @return one {@code {nodepred, nodevar}} pair per row, taken from the leaf the row falls into
     */
    public static double[][] apply(Regtree tree, double[][] X) {
        int[] nodes = RegtreeFwd.fwd(tree, X);
        double[][] retn = new double[X.length][2];
        for (int i = 0; i < X.length; ++i) {
            retn[i][0] = tree.nodepred[nodes[i]];
            retn[i][1] = tree.nodevar[nodes[i]];
        }
        return retn;
    }

    /**
     * Classifies every row of {@code X}, preprocessing the tree for classification on first use.
     * Ties among a leaf's best classes are broken uniformly at random.
     */
    public static double[] classify(Regtree tree, double[][] X) {
        if (!tree.preprocessed_for_classification) {
            RegtreeFwd.preprocess_for_classification(tree);
        }
        int[] nodes = RegtreeFwd.fwd(tree, X);
        double[] retn = new double[X.length];
        for (int i = 0; i < X.length; ++i) {
            double[] best = tree.bestClasses[nodes[i]];
            retn[i] = best[(int)(Math.random() * (double)best.length)];
        }
        return retn;
    }

    /** Delegates to {@code RegtreeFwd.marginalFwd}. */
    public static Object[] applyMarginal(Regtree tree, double[][] Theta, double[][] X) {
        return RegtreeFwd.marginalFwd(tree, Theta, X);
    }

    /**
     * Adds new observations to the leaves they fall into and refreshes those leaves' statistics.
     *
     * <p>When {@code tree.logModel > 0} each response y is stored as 10^y. The transform is applied
     * to a private copy; the caller's {@code newy} array is left unchanged.
     *
     * @param tree tree to update (must not be preprocessed)
     * @param newx new inputs, one row per observation
     * @param newy new responses, same length as {@code newx}
     * @throws RuntimeException if the tree is preprocessed, an argument is null, or sizes differ
     */
    public static void update(Regtree tree, double[][] newx, double[] newy) {
        if (tree.preprocessed) {
            throw new RuntimeException("Cannot update preprocessed forests.");
        }
        if (null == newx || null == newy) {
            throw new RuntimeException("Input newx or newy to update is null.");
        }
        if (newx.length != newy.length) {
            throw new RuntimeException("Argument sizes mismatch.");
        }
        // Transform into a copy: mutating newy in place would compound the transform whenever the
        // same array is passed to update() more than once.
        double[] ys = newy;
        if (tree.logModel > 0) {
            ys = new double[newy.length];
            for (int i = 0; i < newy.length; ++i) {
                ys[i] = Math.pow(10.0, newy[i]);
            }
        }
        int[] nodes = RegtreeFwd.fwd(tree, newx);
        boolean[] nodeChanged = new boolean[tree.node.length];
        for (int i = 0; i < newx.length; ++i) {
            int leaf = nodes[i];
            nodeChanged[leaf] = true;
            int nnode = tree.nodesize[leaf];
            if (tree.resultsStoredInLeaves) {
                double[] grown = new double[nnode + 1];
                if (nnode != 0) {
                    System.arraycopy(tree.ysub[leaf], 0, grown, 0, nnode);
                }
                grown[nnode] = ys[i];
                tree.ysub[leaf] = grown;
            } else {
                tree.ysub[leaf][0] += ys[i];
                tree.ysub[leaf][1] += ys[i] * ys[i];
            }
            // Count the observation in the leaf and in every ancestor up to the root (node 0).
            tree.nodesize[leaf]++;
            for (int n = leaf; n != 0; ) {
                n = tree.parent[n];
                tree.nodesize[n]++;
            }
        }
        for (int i = 0; i < nodeChanged.length; ++i) {
            if (nodeChanged[i]) {
                tree.recalculateStats(i);
            }
        }
    }

    /** Reallocates {@link #nodepred} / {@link #nodevar} and recomputes them for every leaf. */
    public void recalculateStats() {
        this.nodepred = new double[this.numNodes];
        this.nodevar = new double[this.numNodes];
        for (int i = 0; i < this.numNodes; ++i) {
            this.recalculateStats(i);
        }
    }

    /**
     * Recomputes mean and variance of a single leaf; internal nodes ({@code var != 0}) are skipped.
     * In sum/sum-of-squares mode an empty leaf (nodesize 0) yields NaN.
     *
     * @param node node index
     */
    public void recalculateStats(int node) {
        if (this.var[node] != 0) {
            return;
        }
        if (this.resultsStoredInLeaves) {
            this.nodepred[node] = Utils.mean(this.ysub[node]);
            this.nodevar[node] = Utils.var(this.ysub[node]);
        } else {
            double sum = this.ysub[node][0];
            double sumOfSq = this.ysub[node][1];
            int N = this.nodesize[node];
            this.nodepred[node] = sum / (double)N;
            // Sample variance; max(N-1, 1) avoids dividing by zero for a single observation.
            this.nodevar[node] = (sumOfSq - sum * sum / (double)N) / (double)Math.max(N - 1, 1);
        }
    }

    /**
     * Validates a set of partial observations against the precomputed leaf info.
     *
     * @param indicesOfObservations dimension indices being fixed
     * @param observations          the fixed values, parallel to {@code indicesOfObservations}
     * @throws RuntimeException         if leaf info has not been precomputed
     * @throws IllegalArgumentException on length mismatch, out-of-range dimension, a non-integral
     *                                  categorical value, or a categorical value outside its domain
     */
    public void verifyInputsAreConsistent(int[] indicesOfObservations, double[] observations) {
        if (!this.leafInfoIsPrecomputed) {
            throw new RuntimeException("Leaf info has to be precomputed before predicting marginal performance.");
        }
        if (indicesOfObservations.length != observations.length) {
            throw new IllegalArgumentException("indicesOfObservations and observations vectors must be of same length");
        }
        for (int j = 0; j < observations.length; ++j) {
            int index = indicesOfObservations[j];
            if (index < 0 || index >= this.isCatDimension.length) {
                throw new IllegalArgumentException("Dimension index " + index + " is out of range [0, " + this.isCatDimension.length + ").");
            }
            if (!this.isCatDimension[index]) {
                continue;
            }
            int value = (int)observations[j];
            if (Math.abs(observations[j] - (double)value) > 1.0E-6) {
                throw new IllegalArgumentException("Dimension " + indicesOfObservations[j] + " must be categorical.");
            }
            if (value < 0 || value >= this.categoricalDomainSizes[index]) {
                throw new IllegalArgumentException("Dimension " + indicesOfObservations[j] + " is categorical with domain size " + this.categoricalDomainSizes[index] + ", but the input value passed was " + value + ". Note that this is 0-indexed.");
            }
        }
    }

    /**
     * Product of the leaf's domain fractions over every dimension NOT in {@code dimensionIndexSet}.
     */
    private double probabilityOfLeafForRemainingDimensions(Set<Integer> dimensionIndexSet, int leafIdx) {
        double p = 1.0;
        for (int j = 0; j < this.isCatDimension.length; ++j) {
            if (!dimensionIndexSet.contains(j)) {
                p *= this.leafDomainPercentage[leafIdx][j];
            }
        }
        return p;
    }

    /**
     * Checks whether the given partial observation lies inside the region of a leaf.
     * Returns true when both arrays are null (no observations).
     */
    public boolean observationsAreConsistentWithLeaf(int[] indicesOfObservations, double[] observations, int leafIdx) {
        if (indicesOfObservations == null && observations == null) {
            return true;
        }
        for (int j = 0; j < indicesOfObservations.length; ++j) {
            int o = indicesOfObservations[j];
            if (this.isCatDimension[o]) {
                int value = (int)observations[j];
                assert (Math.abs(observations[j] - (double)value) < 1.0E-6);
                if (!this.leafallCatValues[leafIdx][o].contains(value)) {
                    return false;
                }
            } else {
                double value = observations[j];
                if (value < this.leafContLB[leafIdx][o] || this.leafContUB[leafIdx][o] < value) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns every subset of {@code originalSet}, including the empty set and the set itself. */
    public static Set<HashSet<Integer>> powerSet(Set<Integer> originalSet) {
        HashSet<HashSet<Integer>> setOfSubsets = new HashSet<HashSet<Integer>>();
        if (originalSet.isEmpty()) {
            setOfSubsets.add(new HashSet<Integer>());
            return setOfSubsets;
        }
        ArrayList<Integer> originalList = new ArrayList<Integer>(originalSet);
        Integer head = originalList.get(0);
        HashSet<Integer> rest = new HashSet<Integer>(originalList.subList(1, originalList.size()));
        for (HashSet<Integer> subset : Regtree.powerSet(rest)) {
            setOfSubsets.add(subset);
            HashSet<Integer> subsetPlusHead = new HashSet<Integer>();
            subsetPlusHead.add(head);
            subsetPlusHead.addAll(subset);
            setOfSubsets.add(subsetPlusHead);
        }
        return setOfSubsets;
    }

    /**
     * Variance attributable to exactly the factor {@code indicesOfFactor}: its total variance minus
     * the (memoized) variances of all of its proper sub-factors. Results are cached in
     * {@code precomputedFactorVariance} under a copy of the key. Prints progress to stdout.
     */
    public double computeFactorVariance(int[] indicesOfObservations, double[] observations, HashSet<Integer> indicesOfFactor, HashMap<Set<Integer>, Double> precomputedFactorVariance) {
        System.out.println("Computing factorVariance for " + indicesOfFactor + " ... ");
        if (indicesOfFactor.isEmpty()) {
            return 0.0;
        }
        if (precomputedFactorVariance.containsKey(indicesOfFactor)) {
            return precomputedFactorVariance.get(indicesOfFactor);
        }
        double totalVariance = this.computeTotalVariance(indicesOfObservations, observations, indicesOfFactor);
        System.out.println("Total variance for " + indicesOfFactor + " is " + totalVariance);
        Set<HashSet<Integer>> setOfSubsets = Regtree.powerSet(indicesOfFactor);
        setOfSubsets.remove(indicesOfFactor);
        for (HashSet<Integer> subset : setOfSubsets) {
            double subFactorVariance = this.computeFactorVariance(indicesOfObservations, observations, subset, precomputedFactorVariance);
            totalVariance -= subFactorVariance;
            System.out.println("Subtracting subFactorVariance " + subFactorVariance);
        }
        // Copy the key so later mutation of the caller's set cannot corrupt the cache.
        precomputedFactorVariance.put(new HashSet<Integer>(indicesOfFactor), totalVariance);
        return totalVariance;
    }

    /**
     * Variance of the tree's prediction over the whole input domain, computed from pairwise
     * leaf-region overlaps: sum_{i,j} P(i and j overlap) * w_i*pred_i * w_j*pred_j - a_0^2.
     */
    public double computeProperTotalVariance() {
        double totalVariance = 0.0;
        int dim = this.isCatDimension.length;
        for (int i : this.leafIndices) {
            for (int j : this.leafIndices) {
                double probBothLeaves = 1.0;
                for (int d = 0; d < dim; ++d) {
                    if (this.isCatDimension[d]) {
                        HashSet<Integer> intersection = new HashSet<Integer>(this.leafallCatValues[i][d]);
                        intersection.retainAll(this.leafallCatValues[j][d]);
                        probBothLeaves *= (double)intersection.size() / (double)this.categoricalDomainSizes[d];
                    } else {
                        double lower = Math.max(this.leafContLB[i][d], this.leafContLB[j][d]);
                        double upper = Math.min(this.leafContUB[i][d], this.leafContUB[j][d]);
                        probBothLeaves *= upper - lower;
                    }
                    // Disjoint regions: no need to look at further dimensions.
                    if (probBothLeaves <= 0.0) {
                        probBothLeaves = 0.0;
                        break;
                    }
                }
                totalVariance += probBothLeaves * this.nodepred[i] * this.nodepred[j] * this.weights[i] * this.weights[j];
            }
        }
        double a_0 = 0.0;
        for (int leafIdx : this.leafIndices) {
            double pThisFactor = this.probabilityOfLeafForRemainingDimensions(new TreeSet<Integer>(), leafIdx);
            a_0 += pThisFactor * (this.weights[leafIdx] * this.nodepred[leafIdx]);
        }
        return totalVariance - Math.pow(a_0, 2.0);
    }

    /**
     * Total variance of the marginal prediction over the dimensions in {@code indicesOfFactor}.
     * Observations are not supported yet: passing a non-empty {@code indicesOfObservations}
     * throws after validation. Prints per-leaf diagnostics to stdout.
     */
    public double computeTotalVariance(int[] indicesOfObservations, double[] observations, HashSet<Integer> indicesOfFactor) {
        if (indicesOfObservations != null && indicesOfObservations.length > 0) {
            this.verifyInputsAreConsistent(indicesOfObservations, observations);
            throw new RuntimeException("observations not quite supported yet...");
        }
        HashSet<Integer> indicesNotInFactor = new HashSet<Integer>();
        for (int i = 0; i < this.isCatDimension.length; ++i) {
            indicesNotInFactor.add(i);
        }
        indicesNotInFactor.removeAll(indicesOfFactor);
        double a_0 = 0.0;
        for (int leafIdx : this.leafIndices) {
            if (!this.observationsAreConsistentWithLeaf(indicesOfObservations, observations, leafIdx)) {
                continue;
            }
            double pThisFactor = this.probabilityOfLeafForRemainingDimensions(indicesNotInFactor, leafIdx);
            double pRemainingVars = this.probabilityOfLeafForRemainingDimensions(indicesOfFactor, leafIdx);
            double value = pRemainingVars * this.weights[leafIdx] * this.nodepred[leafIdx];
            a_0 += pThisFactor * value;
        }
        double totalVariance = 0.0;
        for (Integer leafIdx : this.leafIndices) {
            double pThisFactor = this.probabilityOfLeafForRemainingDimensions(indicesNotInFactor, (int)leafIdx);
            double pRemainingVars = this.probabilityOfLeafForRemainingDimensions(indicesOfFactor, (int)leafIdx);
            double value = pRemainingVars * this.weights[leafIdx] * this.nodepred[leafIdx];
            System.out.println("Leaf " + leafIdx + ": weight=" + this.weights[leafIdx] + ", nodepred=" + this.nodepred[leafIdx] + ", value=" + value + ", a_0=" + a_0 + ", pThisFactor=" + pThisFactor + ", pRemainingVars=" + pRemainingVars);
            totalVariance += pThisFactor * pRemainingVars * Math.pow(value - a_0, 2.0);
        }
        return totalVariance;
    }

    /**
     * Weighted prediction marginalized over all dimensions not fixed by the observations.
     *
     * @param indicesOfObservations fixed dimensions (non-null)
     * @param observations          fixed values, parallel to {@code indicesOfObservations}
     * @return sum over consistent leaves of P(leaf | free dims) * weight * nodepred
     */
    public double marginalPerformance(int[] indicesOfObservations, double[] observations) {
        this.verifyInputsAreConsistent(indicesOfObservations, observations);
        // Build the set of fixed dimensions once rather than once per leaf.
        HashSet<Integer> observedDimensions = new HashSet<Integer>();
        for (int index : indicesOfObservations) {
            observedDimensions.add(index);
        }
        double result = 0.0;
        for (int leafIdx : this.leafIndices) {
            if (!this.observationsAreConsistentWithLeaf(indicesOfObservations, observations, leafIdx)) {
                continue;
            }
            double p = this.probabilityOfLeafForRemainingDimensions(observedDimensions, leafIdx);
            result += p * this.weights[leafIdx] * this.nodepred[leafIdx];
        }
        return result;
    }

    /**
     * Computes, for every leaf, the region of input space it covers (continuous bounds, admissible
     * categorical values, per-dimension domain fraction) and records the leaf indices.
     *
     * <p>{@code contLB}/{@code contUB} are modified temporarily during the traversal and restored.
     *
     * @param isCat        which dimensions are categorical
     * @param allCatValues full value set per categorical dimension (null for continuous ones)
     * @param contLB       lower bound per continuous dimension
     * @param contUB       upper bound per continuous dimension
     */
    public void precomputeLeafInfo(boolean[] isCat, HashSet<Integer>[] allCatValues, double[] contLB, double[] contUB) {
        // Do not leave a stale "true" behind if this recomputation fails midway.
        this.leafInfoIsPrecomputed = false;
        int dim = isCat.length;
        this.isCatDimension = isCat;
        this.leafIndices = new Vector<Integer>();
        this.categoricalDomainSizes = new int[dim];
        for (int i = 0; i < dim; ++i) {
            this.categoricalDomainSizes[i] = allCatValues[i] == null ? 0 : allCatValues[i].size();
        }
        this.leafContLB = new double[this.numNodes][dim];
        this.leafContUB = new double[this.numNodes][dim];
        @SuppressWarnings("unchecked") // generic array creation; rows are only ever filled with Set<Integer>
        Set<Integer>[][] catValuesPerNode = new Set[this.numNodes][dim];
        this.leafallCatValues = catValuesPerNode;
        this.leafDomainPercentage = new double[this.numNodes][dim];
        this.precomputeLeafInfoInSubtree(0, isCat, allCatValues, contLB, contUB);
        this.leafInfoIsPrecomputed = true;
    }

    /**
     * Recursively narrows the region bounds along the path to each leaf and records them there.
     * Splits on a variable index beyond {@code isCat.length} pass the bounds through unchanged.
     */
    private void precomputeLeafInfoInSubtree(int thisnode, boolean[] isCat, HashSet<Integer>[] allCatValues, double[] contLB, double[] contUB) {
        int splitvar = this.var[thisnode];
        if (splitvar == 0) {
            this.leafIndices.add(thisnode);
            // Safe to share: categorical value arrays are cloned, never mutated, before narrowing.
            this.leafallCatValues[thisnode] = allCatValues;
            for (int i = 0; i < isCat.length; ++i) {
                this.leafContLB[thisnode][i] = contLB[i];
                this.leafContUB[thisnode][i] = contUB[i];
                this.leafDomainPercentage[thisnode][i] = isCat[i]
                        ? (double)allCatValues[i].size() / (double)this.categoricalDomainSizes[i]
                        : contUB[i] - contLB[i];
            }
            return;
        }
        int leftKid = this.children[thisnode][0];
        int rightKid = this.children[thisnode][1];
        if (Math.abs(splitvar) > isCat.length) {
            this.precomputeLeafInfoInSubtree(leftKid, isCat, allCatValues, contLB, contUB);
            this.precomputeLeafInfoInSubtree(rightKid, isCat, allCatValues, contLB, contUB);
            return;
        }
        double cutoff = this.cut[thisnode];
        if (splitvar > 0) {
            int v = splitvar - 1;
            assert !isCat[v] : "Node " + thisnode + " has a continuous split on categorical dimension " + v;
            // Left child: upper bound tightened to the cut; right child: lower bound raised to it.
            double previousUB = contUB[v];
            contUB[v] = cutoff;
            try {
                this.precomputeLeafInfoInSubtree(leftKid, isCat, allCatValues, contLB, contUB);
            } finally {
                contUB[v] = previousUB;
            }
            double previousLB = contLB[v];
            contLB[v] = cutoff;
            try {
                this.precomputeLeafInfoInSubtree(rightKid, isCat, allCatValues, contLB, contUB);
            } finally {
                contLB[v] = previousLB;
            }
        } else {
            int v = -splitvar - 1;
            assert isCat[v] : "Node " + thisnode + " has a categorical split on continuous dimension " + v;
            HashSet<Integer> thisValues = allCatValues[v];
            HashSet<Integer> leftValues = new HashSet<Integer>();
            HashSet<Integer> rightValues = new HashSet<Integer>();
            for (Integer thisValue : thisValues) {
                int split = this.catsplit[(int)cutoff][thisValue];
                if (split == 0) {
                    leftValues.add(thisValue);
                } else if (split == 1) {
                    rightValues.add(thisValue);
                } else {
                    throw new IllegalStateException("Error in node " + thisnode + ": catsplit does not state which kid to propagate value " + thisValue + " to. Note that input allCatValues should be 0-indexed!");
                }
            }
            // Copy-on-narrow: each child gets its own array so siblings never see each other's sets.
            HashSet<Integer>[] allLeftValues = allCatValues.clone();
            allLeftValues[v] = leftValues;
            this.precomputeLeafInfoInSubtree(leftKid, isCat, allLeftValues, contLB, contUB);
            HashSet<Integer>[] allRightValues = allCatValues.clone();
            allRightValues[v] = rightValues;
            this.precomputeLeafInfoInSubtree(rightKid, isCat, allRightValues, contLB, contUB);
        }
    }
}

