/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.Regtree;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Static forward-pass routines over a {@link Regtree}: routing rows to leaves, and the
 * instance-marginalised prediction machinery built on {@link #preprocess_inst_splits}.
 *
 * <p>Split encoding used by every traversal in this class (see {@link #descend}):
 * <ul>
 *   <li>{@code var[node] == 0} marks a leaf;</li>
 *   <li>{@code var[node] > 0} is a 1-based continuous column compared against {@code cut[node]}:
 *       {@code value <= cut} goes to {@code children[node][0]}, otherwise to
 *       {@code children[node][1]};</li>
 *   <li>{@code var[node] < 0} is a 1-based categorical column; its (1-based) category value is
 *       looked up in {@code catsplit[(int) cut[node]]}, where 0 selects the left child, 1 the right
 *       child, and any other entry is rejected as a missing value.</li>
 * </ul>
 */
public class RegtreeFwd {

    /** Message used whenever a categorical lookup maps to neither branch. */
    private static final String MISSING_VALUE_MSG =
            "Missing value -- not allowed in this implementation.";

    /** Starting capacity of a per-leaf Theta-index bucket in {@link #fwdThetas}. */
    private static final int INITIAL_BUCKET_CAPACITY = 16;

    /**
     * Routes every row of {@code X} from the root to a leaf.
     *
     * @param tree tree to evaluate; {@code cut}, {@code nodepred} and {@code children} must have
     *     one entry per node
     * @param X one row per data point; columns are addressed by the 1-based indices in
     *     {@code tree.var}
     * @return for each row of {@code X}, the index of the leaf it lands in
     * @throws RuntimeException if the tree arrays have inconsistent lengths, or a categorical
     *     value maps to neither branch
     */
    public static int[] fwd(Regtree tree, double[][] X) {
        int numdata = X.length;
        checkTreeShape(tree, false, false);
        int[] result = new int[numdata];
        for (int row = 0; row < numdata; row++) {
            double[] point = X[row];
            int node = 0;
            for (int splitvar = tree.var[node]; splitvar != 0; splitvar = tree.var[node]) {
                node = descend(tree, node, splitvar, point, 0);
            }
            result[row] = node;
        }
        return result;
    }

    /**
     * Computes, for each row of {@code Theta}, the sum of {@code weightedpred} over every leaf
     * reachable from the root. Splits on the first {@code Theta[0].length} columns follow the
     * Theta value; splits on any later column are followed in both directions.
     *
     * <p>If {@code tree} is not yet preprocessed, it is first passed through
     * {@link #preprocess_inst_splits} with {@code X}; the input tree itself is not modified in
     * that case. When {@code tree.logModel == 1} the sums are returned as {@code log10}.
     *
     * @param tree tree to evaluate (preprocessed or not)
     * @param Theta non-empty matrix of configurations; all rows are assumed to have the width of
     *     row 0
     * @param X instance-feature matrix used only when {@code tree} is not preprocessed (may be
     *     {@code null})
     * @return one marginal prediction per row of {@code Theta}
     * @throws RuntimeException if {@code Theta} is null/empty, the tree is empty or inconsistent,
     *     or a categorical value maps to neither branch
     */
    public static double[] marginalFwd(Regtree tree, double[][] Theta, double[][] X) {
        if (Theta == null || Theta.length == 0) {
            throw new RuntimeException("Theta must not be empty");
        }
        int thetarows = Theta.length;
        int thetacols = Theta[0].length;
        checkTreeShape(tree, true, true);
        Regtree preprocessed = tree.preprocessed ? tree : preprocess_inst_splits(tree, X);

        double[] result = new double[thetarows];
        // FIFO order is kept deliberately: it fixes the floating-point summation order per row.
        ArrayDeque<Integer> pending = new ArrayDeque<Integer>();
        for (int i = 0; i < thetarows; i++) {
            double[] theta = Theta[i];
            pending.add(0);
            while (!pending.isEmpty()) {
                int node = pending.poll();
                while (true) {
                    int splitvar = preprocessed.var[node];
                    if (splitvar == 0) {
                        result[i] += preprocessed.weightedpred[node];
                        break;
                    }
                    if (Math.abs(splitvar) > thetacols) {
                        // Column is not part of Theta: explore both subtrees.
                        pending.add(preprocessed.children[node][1]);
                        node = preprocessed.children[node][0];
                    } else {
                        node = descend(preprocessed, node, splitvar, theta, 0);
                    }
                }
            }
        }
        if (tree.logModel == 1) {
            for (int i = 0; i < thetarows; i++) {
                result[i] = Math.log10(result[i]);
            }
        }
        return result;
    }

    /**
     * For a preprocessed tree, determines which Theta rows reach which leaves. Splits on the
     * first {@code Theta[0].length} columns follow the Theta value; splits on later columns are
     * followed in both directions, so one Theta row may reach several leaves.
     *
     * @param tree preprocessed tree ({@code tree.preprocessed} must be true)
     * @param Theta non-empty matrix of configurations
     * @return a two-element array: {@code [0]} is an {@code int[]} of the leaf indices reached by
     *     at least one Theta row (ascending), and {@code [1]} is an {@code int[][]} whose k-th
     *     entry lists, in ascending order, the Theta row indices that reach leaf {@code [0][k]}
     * @throws RuntimeException if {@code Theta} is null/empty, the tree is empty, inconsistent or
     *     not preprocessed, or a categorical value maps to neither branch
     */
    public static Object[] fwdThetas(Regtree tree, double[][] Theta) {
        if (Theta == null || Theta.length == 0) {
            throw new RuntimeException("Theta must not be empty");
        }
        int thetarows = Theta.length;
        int thetacols = Theta[0].length;
        int numnodes = checkTreeShape(tree, true, true);
        if (!tree.preprocessed) {
            throw new RuntimeException("fwdThetas can only be called on a preprocessed tree.");
        }

        // Buckets are allocated lazily and grown on demand: only leaves that are actually reached
        // pay for storage (the old code allocated numnodes x thetarows ints up front).
        int[][] buckets = new int[numnodes][];
        int[] bucketLens = new int[numnodes];
        int numLeavesWithTheta = 0;

        ArrayDeque<Integer> pending = new ArrayDeque<Integer>();
        for (int i = 0; i < thetarows; i++) {
            double[] theta = Theta[i];
            pending.add(0);
            while (!pending.isEmpty()) {
                int node = pending.poll();
                while (true) {
                    int splitvar = tree.var[node];
                    if (splitvar == 0) {
                        int[] bucket = buckets[node];
                        int len = bucketLens[node];
                        if (bucket == null) {
                            bucket = new int[Math.min(thetarows, INITIAL_BUCKET_CAPACITY)];
                            buckets[node] = bucket;
                            ++numLeavesWithTheta;
                        } else if (len == bucket.length) {
                            bucket = Arrays.copyOf(bucket, 2 * len);
                            buckets[node] = bucket;
                        }
                        bucket[len] = i;
                        bucketLens[node] = len + 1;
                        break;
                    }
                    if (Math.abs(splitvar) > thetacols) {
                        // Column is not part of Theta: explore both subtrees.
                        pending.add(tree.children[node][1]);
                        node = tree.children[node][0];
                    } else {
                        node = descend(tree, node, splitvar, theta, 0);
                    }
                }
            }
        }

        int[] leafIdxs = new int[numLeavesWithTheta];
        int[][] condensedThetaIdxs = new int[numLeavesWithTheta][];
        int out = 0;
        for (int node = 0; node < numnodes; node++) {
            if (bucketLens[node] != 0) {
                leafIdxs[out] = node;
                condensedThetaIdxs[out] = Arrays.copyOf(buckets[node], bucketLens[node]);
                ++out;
            }
        }
        return new Object[]{leafIdxs, condensedThetaIdxs};
    }

    /**
     * Returns a preprocessed copy of {@code tree} (via {@code new Regtree(tree)}) with
     * {@code weights}, {@code weightedpred} and {@code preprocessed} filled in.
     *
     * <p>With {@code X == null}, every node gets weight 0 and {@code weightedpred} equal to its
     * {@code nodepred} (or {@code 10^nodepred} when {@code logModel == 1}); no nodes are
     * collapsed.
     *
     * <p>Otherwise the first {@code npred - X[0].length} predictor columns are treated as Theta
     * columns and the rest as instance columns, which index into {@code X}. Each row of {@code X}
     * is pushed down the tree, following both branches of every Theta split. Each leaf it reaches
     * accumulates {@code 1/numinsts} weight and {@code pred/numinsts} of weighted prediction,
     * where {@code pred} is {@code nodepred} or {@code 10^nodepred}. Finally,
     * {@link #cut_instance_leaf_split_helper} folds these into ancestors and collapses
     * instance-only and zero-weight subtrees.
     *
     * @param tree source tree; not modified
     * @param X instance-feature matrix, or {@code null}
     * @return the preprocessed copy
     * @throws RuntimeException if the tree is empty or inconsistent, {@code X} has no rows, or a
     *     categorical value maps to neither branch
     */
    public static Regtree preprocess_inst_splits(Regtree tree, double[][] X) {
        Regtree newtree = new Regtree(tree);
        int numnodes = checkTreeShape(newtree, true, true);
        newtree.weights = new double[numnodes];
        newtree.weightedpred = new double[numnodes];

        if (X == null) {
            for (int node = 0; node < numnodes; node++) {
                newtree.weightedpred[node] = newtree.logModel == 1
                        ? Math.pow(10.0, newtree.nodepred[node])
                        : newtree.nodepred[node];
            }
            newtree.preprocessed = true;
            return newtree;
        }
        if (X.length == 0) {
            // Previously this surfaced as an ArrayIndexOutOfBoundsException on X[0].
            throw new RuntimeException("X must not be empty");
        }

        int numinsts = X.length;
        int thetacols = newtree.npred - X[0].length;
        double instWeight = 1.0 / (double) numinsts;
        ArrayDeque<Integer> pending = new ArrayDeque<Integer>();
        for (int i = 0; i < numinsts; i++) {
            double[] inst = X[i];
            pending.add(0);
            while (!pending.isEmpty()) {
                int node = pending.poll();
                while (true) {
                    int splitvar = newtree.var[node];
                    if (splitvar == 0) {
                        double pred = newtree.logModel == 1
                                ? Math.pow(10.0, newtree.nodepred[node])
                                : newtree.nodepred[node];
                        newtree.weightedpred[node] += pred / (double) numinsts;
                        newtree.weights[node] += instWeight;
                        break;
                    }
                    if (Math.abs(splitvar) <= thetacols) {
                        // Theta column: the instance is compatible with both subtrees.
                        pending.add(newtree.children[node][1]);
                        node = newtree.children[node][0];
                    } else {
                        // Instance column: X holds only the columns after the Theta block.
                        node = descend(newtree, node, splitvar, inst, thetacols);
                    }
                }
            }
        }
        cut_instance_leaf_split_helper(newtree, thetacols, 0);
        newtree.preprocessed = true;
        return newtree;
    }

    /**
     * Post-order pass over the subtree rooted at {@code thisnode}. It:
     * <ul>
     *   <li>collapses a split on an instance column ({@code |var| > thetacols}) into a leaf when
     *       both children are leaves after their own pass;</li>
     *   <li>adds this node's {@code weightedpred} and {@code weights} into its parent (except for
     *       the root);</li>
     *   <li>collapses any remaining internal node whose accumulated weight is exactly 0.</li>
     * </ul>
     *
     * <p>NOTE(review): recursion depth equals tree depth.
     *
     * @return 1 if {@code thisnode} is a leaf after the pass, 0 otherwise
     */
    private static int cut_instance_leaf_split_helper(Regtree tree, int thetacols, int thisnode) {
        int isLeaf = 0;
        if (tree.var[thisnode] == 0) {
            isLeaf = 1;
        } else {
            int leftKid = tree.children[thisnode][0];
            int rightKid = tree.children[thisnode][1];
            // Both children are always processed (no short-circuit), left first.
            int leafKids = cut_instance_leaf_split_helper(tree, thetacols, leftKid)
                    + cut_instance_leaf_split_helper(tree, thetacols, rightKid);
            if (leafKids == 2 && Math.abs(tree.var[thisnode]) > thetacols) {
                make_into_leaf(tree, thisnode);
                isLeaf = 1;
            }
        }
        if (thisnode != 0) {
            int parent = tree.parent[thisnode];
            tree.weightedpred[parent] += tree.weightedpred[thisnode];
            tree.weights[parent] += tree.weights[thisnode];
        }
        if (isLeaf == 0 && tree.weights[thisnode] == 0.0) {
            make_into_leaf(tree, thisnode);
            isLeaf = 1;
        }
        return isLeaf;
    }

    /**
     * Turns {@code thisnode} into a leaf. It sets {@code var} to 0, zeroes both child pointers,
     * merges the two former children's {@code ysub} data into it, and calls
     * {@code recalculateStats}.
     *
     * <p>When {@code resultsStoredInLeaves} is set, the children's {@code ysub} and
     * {@code is_censored} arrays are concatenated (left first). This assumes
     * {@code nodesize[thisnode] == nodesize[left] + nodesize[right]} — TODO confirm. Otherwise,
     * the existing {@code ysub[thisnode]} is overwritten element-wise with the sum of the
     * children's {@code ysub} (presumably per-node summary values; verify against
     * {@code Regtree}).
     */
    private static void make_into_leaf(Regtree tree, int thisnode) {
        tree.var[thisnode] = 0;
        int leftKid = tree.children[thisnode][0];
        int rightKid = tree.children[thisnode][1];
        tree.children[thisnode][0] = 0;
        tree.children[thisnode][1] = 0;
        if (tree.resultsStoredInLeaves) {
            int leftSize = tree.nodesize[leftKid];
            int rightSize = tree.nodesize[rightKid];
            tree.ysub[thisnode] = new double[tree.nodesize[thisnode]];
            System.arraycopy(tree.ysub[leftKid], 0, tree.ysub[thisnode], 0, leftSize);
            System.arraycopy(tree.ysub[rightKid], 0, tree.ysub[thisnode], leftSize, rightSize);
            tree.is_censored[thisnode] = new boolean[tree.nodesize[thisnode]];
            System.arraycopy(tree.is_censored[leftKid], 0, tree.is_censored[thisnode], 0, leftSize);
            System.arraycopy(tree.is_censored[rightKid], 0, tree.is_censored[thisnode], leftSize, rightSize);
        } else {
            double[] merged = tree.ysub[thisnode];
            double[] fromLeft = tree.ysub[leftKid];
            double[] fromRight = tree.ysub[rightKid];
            for (int k = 0; k < merged.length; k++) {
                merged[k] = fromLeft[k] + fromRight[k];
            }
        }
        tree.recalculateStats(thisnode);
    }

    /**
     * Validates that the per-node arrays of {@code tree} agree in length. Checks run in this
     * order: non-empty (optional), cut, nodepred, parent (optional), children.
     *
     * @param requireNodes if true, an empty tree is rejected with "Tree must exist."
     * @param checkParent if true, {@code parent} is also length-checked
     * @return the number of nodes
     */
    private static int checkTreeShape(Regtree tree, boolean requireNodes, boolean checkParent) {
        int numnodes = tree.node.length;
        if (requireNodes && numnodes == 0) {
            throw new RuntimeException("Tree must exist.");
        }
        if (tree.cut.length != numnodes) {
            throw new RuntimeException("cut must be Nx1 vector.");
        }
        if (tree.nodepred.length != numnodes) {
            throw new RuntimeException("nodepred must be Nx1 vector.");
        }
        if (checkParent && tree.parent.length != numnodes) {
            throw new RuntimeException("parent must be Nx1 matrix.");
        }
        if (tree.children.length != numnodes) {
            throw new RuntimeException("children must be Nx2 matrix.");
        }
        return numnodes;
    }

    /**
     * Takes one step down from internal node {@code node}, whose split variable is
     * {@code splitvar} (non-zero), using the values in {@code row}.
     *
     * @param colOffset subtracted from the 1-based column index before reading {@code row}; used
     *     when {@code row} holds only the columns after the first {@code colOffset} predictors
     * @return the index of the child to visit next
     * @throws RuntimeException if a categorical lookup yields neither 0 nor 1
     */
    private static int descend(Regtree tree, int node, int splitvar, double[] row, int colOffset) {
        double cutoff = tree.cut[node];
        int leftKid = tree.children[node][0];
        int rightKid = tree.children[node][1];
        if (splitvar > 0) {
            return row[splitvar - 1 - colOffset] <= cutoff ? leftKid : rightKid;
        }
        int category = (int) row[-splitvar - 1 - colOffset];
        int side = tree.catsplit[(int) cutoff][category - 1];
        if (side == 0) {
            return leftKid;
        }
        if (side == 1) {
            return rightKid;
        }
        throw new RuntimeException(MISSING_VALUE_MSG);
    }
}

