/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.Regtree;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import ca.ubc.cs.beta.models.fastrf.RegtreeFit;
import ca.ubc.cs.beta.models.fastrf.RegtreeFwd;
import ca.ubc.cs.beta.models.fastrf.Utils;
import ca.ubc.cs.beta.models.fastrf.WeibullFit;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;

public class RandomForest
implements Serializable {
    private static final long serialVersionUID = 5204746081208095703L;
    public int numTrees;
    public Regtree[] Trees;

    public RandomForest(int numtrees) {
        if (numtrees <= 0) {
            throw new RuntimeException("Invalid number of regression trees in forest: " + numtrees);
        }
        this.numTrees = numtrees;
        this.Trees = new Regtree[numtrees];
    }

    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, boolean[] cens, RegtreeBuildParams params) {
        Random r = params.random;
        if (r == null) {
            r = new Random();
            if (params.seed != -1L) {
                r.setSeed(params.seed);
            }
        }
        int N = y.length;
        int[][] dataIdxs = new int[numTrees][N];
        int i = 0;
        while (i < numTrees) {
            int j = 0;
            while (j < N) {
                dataIdxs[i][j] = r.nextInt(N);
                ++j;
            }
            ++i;
        }
        return RandomForest.learnModel(numTrees, allTheta, allX, theta_inst_idxs, y, cens, dataIdxs, params);
    }

    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, boolean[] cens, int[][] dataIdxs, RegtreeBuildParams params) {
        if (dataIdxs.length != numTrees) {
            throw new RuntimeException("length(dataIdxs) must be equal to numtrees.");
        }
        ArrayList<Integer> censored_idxs = new ArrayList<Integer>();
        int i = 0;
        while (i < cens.length) {
            if (cens[i]) {
                censored_idxs.add(i);
            }
            ++i;
        }
        boolean storeResponses = params.storeResponses;
        if (censored_idxs.size() > 0) {
            params.storeResponses = true;
        }
        RandomForest rf = new RandomForest(numTrees);
        int i2 = 0;
        while (i2 < numTrees) {
            int N = dataIdxs[i2].length;
            int[][] this_theta_inst_idxs = new int[N][];
            double[] thisy = new double[N];
            boolean[] thiscens = new boolean[N];
            int j = 0;
            while (j < N) {
                int idx = dataIdxs[i2][j];
                this_theta_inst_idxs[j] = theta_inst_idxs[idx];
                thisy[j] = y[idx];
                thiscens[j] = cens[idx];
                ++j;
            }
            rf.Trees[i2] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, thisy, thiscens, params);
            ++i2;
        }
        if (censored_idxs.size() > 0) {
            double[] lowerBoundForSamples = new double[censored_idxs.size()];
            double[][] censoredX = new double[censored_idxs.size()][allTheta[0].length + allX[0].length];
            int i3 = 0;
            while (i3 < lowerBoundForSamples.length) {
                int idx = (Integer)censored_idxs.get(i3);
                lowerBoundForSamples[i3] = y[idx];
                int j = 0;
                while (j < allTheta[0].length) {
                    censoredX[i3][j] = allTheta[theta_inst_idxs[idx][0]][j];
                    ++j;
                }
                j = 0;
                while (j < allX[0].length) {
                    censoredX[i3][j + allTheta[0].length] = allX[theta_inst_idxs[idx][1]][j];
                    ++j;
                }
                ++i3;
            }
            double maxY = y[0];
            int i4 = 1;
            while (i4 < y.length) {
                maxY = Math.max(maxY, y[i4]);
                ++i4;
            }
            if (params.logModel == 1) {
                i4 = 0;
                while (i4 < lowerBoundForSamples.length) {
                    lowerBoundForSamples[i4] = Math.pow(10.0, lowerBoundForSamples[i4]);
                    ++i4;
                }
                maxY = Math.pow(10.0, maxY);
            }
            if (params.logModel == 3) {
                params.kappa = Math.log10(params.kappa);
            }
            double valueForAllCens = params.cutoffPenaltyFactor * params.kappa;
            valueForAllCens = Math.max(valueForAllCens, maxY);
            int[][] numOccurrence = new int[censored_idxs.size()][numTrees];
            int[] numCensoredBefore = new int[y.length];
            int i5 = 1;
            while (i5 < y.length) {
                numCensoredBefore[i5] = numCensoredBefore[i5 - 1] + (cens[i5 - 1] ? 1 : 0);
                ++i5;
            }
            int m = 0;
            while (m < numTrees) {
                int i6 = 0;
                while (i6 < dataIdxs[m].length) {
                    int idx = dataIdxs[m][i6];
                    if (cens[idx]) {
                        int[] nArray = numOccurrence[numCensoredBefore[idx]];
                        int n = m;
                        nArray[n] = nArray[n] + 1;
                    }
                    ++i6;
                }
                ++m;
            }
            Object[][] result = RandomForest.hallucinateData(rf, censoredX, numOccurrence, lowerBoundForSamples, valueForAllCens, params.logModel);
            double[][][] convertedResult = new double[censored_idxs.size()][numTrees][];
            int i7 = 0;
            while (i7 < censored_idxs.size()) {
                int m2 = 0;
                while (m2 < numTrees) {
                    convertedResult[i7][m2] = (double[])result[i7][m2];
                    ++m2;
                }
                ++i7;
            }
            params.storeResponses = storeResponses;
            int m3 = 0;
            while (m3 < numTrees) {
                int N = dataIdxs[m3].length;
                int[][] this_theta_inst_idxs = new int[N][];
                double[] thisy = new double[N];
                boolean[] thiscens = new boolean[N];
                int i8 = 0;
                while (i8 < N) {
                    int idx = dataIdxs[m3][i8];
                    this_theta_inst_idxs[i8] = theta_inst_idxs[idx];
                    thisy[i8] = y[idx];
                    thiscens[i8] = false;
                    if (cens[idx]) {
                        double[] samples = convertedResult[numCensoredBefore[idx]][m3];
                        int[] nArray = numOccurrence[numCensoredBefore[idx]];
                        int n = m3;
                        int n2 = nArray[n];
                        nArray[n] = n2 - 1;
                        thisy[i8] = samples[n2];
                    }
                    ++i8;
                }
                rf.Trees[m3] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, thisy, thiscens, params);
                ++m3;
            }
        }
        return rf;
    }

    public static int[][] fwd(RandomForest forest, double[][] X) {
        int[][] retn = new int[forest.numTrees][X.length];
        int i = 0;
        while (i < forest.numTrees) {
            int[] result = RegtreeFwd.fwd(forest.Trees[i], X);
            System.arraycopy(result, 0, retn[i], 0, result.length);
            ++i;
        }
        return retn;
    }

    public static Object[][] fwdThetas(RandomForest forest, double[][] Theta) {
        Object[][] retn = new Object[forest.numTrees][2];
        int i = 0;
        while (i < forest.numTrees) {
            Object[] result = RegtreeFwd.fwdThetas(forest.Trees[i], Theta);
            int[] leafIdxs = (int[])result[0];
            int[][] ThetaIdxs = (int[][])result[1];
            retn[i][0] = leafIdxs;
            retn[i][1] = ThetaIdxs;
            ++i;
        }
        return retn;
    }

    public static double[][] apply(RandomForest forest, double[][] X) {
        double[][] treemeans = new double[X.length][forest.numTrees];
        double[][] treevars = new double[X.length][forest.numTrees];
        int i = 0;
        while (i < forest.numTrees) {
            int[] result = RegtreeFwd.fwd(forest.Trees[i], X);
            int j = 0;
            while (j < result.length) {
                treemeans[j][i] = forest.Trees[i].nodepred[result[j]];
                treevars[j][i] = forest.Trees[i].nodevar[result[j]];
                ++j;
            }
            ++i;
        }
        double[][] retn = new double[X.length][2];
        int i2 = 0;
        while (i2 < X.length) {
            retn[i2][0] = Utils.mean(treemeans[i2]);
            retn[i2][1] = Utils.var(treemeans[i2]);
            ++i2;
        }
        return retn;
    }

    public static double[][] applyMarginal(RandomForest forest, double[][] Theta) {
        return RandomForest.applyMarginal(forest, Theta, null);
    }

    public static double[][] applyMarginal(RandomForest forest, double[][] Theta, double[][] X) {
        int nTheta = Theta.length;
        double[][] treemeans = new double[nTheta][forest.numTrees];
        int i = 0;
        while (i < forest.numTrees) {
            double[] result = RegtreeFwd.marginalFwd(forest.Trees[i], Theta, X);
            int j = 0;
            while (j < nTheta) {
                treemeans[j][i] = result[j];
                ++j;
            }
            ++i;
        }
        double[][] retn = new double[nTheta][2];
        int i2 = 0;
        while (i2 < nTheta) {
            retn[i2][0] = Utils.mean(treemeans[i2]);
            retn[i2][1] = Utils.var(treemeans[i2]);
            ++i2;
        }
        return retn;
    }

    /*
     * Unable to fully structure code
     * Could not resolve type clashes
     */
    public static double[] getMarginal(RandomForest forest, double[][] X, double[][] toBeMarginalized, int[] variableIndicesForColumnsOfX) {
        if (X == null) {
            X = new double[1][0];
        }
        if (toBeMarginalized /* !! */  == null) {
            toBeMarginalized /* !! */  = new double[0][];
        }
        if (variableIndicesForColumnsOfX == null) {
            variableIndicesForColumnsOfX = new int[]{};
        }
        if (X[0].length + toBeMarginalized /* !! */ .length != forest.Trees[0].npred) {
            throw new RuntimeException("d+p must be equal to numvars.");
        }
        if (X[0].length != variableIndicesForColumnsOfX.length) {
            throw new RuntimeException("The number of columns of X and the length of variableIndicesForColumnsOfX must match up.");
        }
        toBeMarginalizedVariables = new int[forest.Trees[0].npred - variableIndicesForColumnsOfX.length];
        isXvar = new int[forest.Trees[0].npred + 1];
        counter = 0;
        next = 1;
        i = 0;
        ** GOTO lbl23
        {
            toBeMarginalizedVariables[counter++] = next++;
            do {
                if (variableIndicesForColumnsOfX[i] != next) continue block0;
                isXvar[variableIndicesForColumnsOfX[i]] = i + 1;
                ++next;
                ++i;
lbl23:
                // 2 sources

            } while (i < variableIndicesForColumnsOfX.length);
        }
        while (next <= forest.Trees[0].npred) {
            toBeMarginalizedVariables[counter++] = next++;
        }
        logModel = forest.Trees[0].logModel;
        treeMeans = new double[forest.numTrees][X == null ? 1 : X.length];
        queue = new LinkedList<Integer>();
        m = 0;
        while (m < forest.numTrees) {
            tree = forest.Trees[m];
            results = new double[tree.numNodes];
            Arrays.fill(results, 1.0);
            i = 0;
            while (i < toBeMarginalized /* !! */ .length) {
                nextvar = toBeMarginalizedVariables[i];
                counts = new int[tree.numNodes];
                numValues = toBeMarginalized /* !! */ [i].length;
                j = 0;
                while (j < numValues) {
                    queue.add(0);
                    block6: while (!queue.isEmpty()) {
                        thisnode = (Integer)queue.poll();
                        while (true) {
                            splitvar = tree.var[thisnode];
                            cutoff = tree.cut[thisnode];
                            left_kid = tree.children[thisnode][0];
                            right_kid = tree.children[thisnode][1];
                            if (splitvar == 0) {
                                v0 = thisnode;
                                counts[v0] = counts[v0] + 1;
                                continue block6;
                            }
                            if (Math.abs(splitvar) != nextvar) {
                                queue.add(right_kid);
                                thisnode = left_kid;
                                continue;
                            }
                            if (splitvar > 0) {
                                thisnode = toBeMarginalized /* !! */ [i][j] <= cutoff ? left_kid : right_kid;
                                continue;
                            }
                            x = (int)toBeMarginalized /* !! */ [i][j];
                            split = tree.catsplit[(int)cutoff][x - 1];
                            if (split == 0) {
                                thisnode = left_kid;
                                continue;
                            }
                            if (split != 1) break;
                            thisnode = right_kid;
                        }
                        throw new RuntimeException("Missing value -- not allowed in this implementation.");
                    }
                    ++j;
                }
                j = 0;
                while (j < tree.numNodes) {
                    v1 = j;
                    results[v1] = results[v1] * ((double)counts[j] * 1.0 / (double)numValues);
                    ++j;
                }
                ++i;
            }
            if (X[0].length == 0) {
                sum = 0.0;
                i = 0;
                while (i < results.length) {
                    sum += results[i];
                    ++i;
                }
                if (sum - 1.0 > 1.0E-6) {
                    throw new RuntimeException("Something is wrong. Sum: " + sum + "(" + Arrays.toString(results) + ")");
                }
                i = 0;
                while (i < results.length) {
                    pred = logModel == 1 ? Math.pow(10.0, tree.nodepred[i]) : tree.nodepred[i];
                    v2 = treeMeans[m];
                    v2[0] = v2[0] + results[i] * pred;
                    ++i;
                }
            } else {
                i = 0;
                while (i < X.length) {
                    queue.add(0);
                    block12: while (!queue.isEmpty()) {
                        thisnode = (Integer)queue.poll();
                        while (true) {
                            splitvar = tree.var[thisnode];
                            cutoff = tree.cut[thisnode];
                            left_kid = tree.children[thisnode][0];
                            right_kid = tree.children[thisnode][1];
                            varidx = isXvar[Math.abs(splitvar)];
                            if (splitvar == 0) {
                                pred = logModel == 1 ? Math.pow(10.0, tree.nodepred[thisnode]) : tree.nodepred[thisnode];
                                v3 = treeMeans[m];
                                v4 = i;
                                v3[v4] = v3[v4] + results[thisnode] * pred;
                                continue block12;
                            }
                            if (varidx == 0) {
                                queue.add(right_kid);
                                thisnode = left_kid;
                                continue;
                            }
                            if (splitvar > 0) {
                                thisnode = X[i][varidx - 1] <= cutoff ? left_kid : right_kid;
                                continue;
                            }
                            x = (int)X[i][varidx - 1];
                            split = tree.catsplit[(int)cutoff][x - 1];
                            if (split == 0) {
                                thisnode = left_kid;
                                continue;
                            }
                            if (split != 1) break;
                            thisnode = right_kid;
                        }
                        throw new RuntimeException("Missing value -- not allowed in this implementation.");
                    }
                    ++i;
                }
            }
            ++m;
        }
        retn = new double[X.length];
        i = 0;
        while (i < retn.length) {
            sum = 0.0;
            m = 0;
            while (m < forest.numTrees) {
                sum += treeMeans[m][i];
                ++m;
            }
            retn[i] = sum * 1.0 / (double)forest.numTrees;
            if (logModel == 1) {
                retn[i] = Math.log10(retn[i]);
            }
            ++i;
        }
        return retn;
    }

    public static RandomForest preprocessForest(RandomForest forest, double[][] X) {
        RandomForest prepared = new RandomForest(forest.numTrees);
        int i = 0;
        while (i < forest.numTrees) {
            prepared.Trees[i] = RegtreeFwd.preprocess_inst_splits(forest.Trees[i], X);
            ++i;
        }
        return prepared;
    }

    public static Object[] collectData(RandomForest forest, double[][] X) {
        int[][] leafNodes = RandomForest.fwd(forest, X);
        return RandomForest.collectData(forest, leafNodes);
    }

    public static Object[] collectData(RandomForest forest, int[][] leafNodes) {
        if (!forest.Trees[0].resultsStoredInLeaves) {
            throw new RuntimeException("Cannot collect data if they were not stored.");
        }
        Object[] retn = new Object[3];
        int numdata = leafNodes[0].length;
        double[][] y_pred_all = new double[numdata][];
        boolean[][] cens_pred_all = new boolean[numdata][];
        double[][] weights_all = new double[numdata][];
        int i = 0;
        while (i < numdata) {
            int size = 0;
            int j = 0;
            while (j < forest.numTrees) {
                size += forest.Trees[j].nodesize[leafNodes[j][i]];
                ++j;
            }
            double[] y_pred = new double[size];
            boolean[] cens_pred = new boolean[size];
            double[] weights = new double[size];
            int counter = 0;
            int j2 = 0;
            while (j2 < forest.numTrees) {
                Regtree tree = forest.Trees[j2];
                int node = leafNodes[j2][i];
                int nodesize = tree.nodesize[node];
                int k = 0;
                while (k < nodesize) {
                    y_pred[counter + k] = tree.ysub[node][k];
                    cens_pred[counter + k] = tree.is_censored[node][k];
                    if (!cens_pred[counter + k]) {
                        int n = counter + k;
                        y_pred[n] = y_pred[n] + 1.0E-10;
                    }
                    ++k;
                }
                Arrays.fill(weights, counter, counter + nodesize, 1.0 / (double)nodesize);
                counter += nodesize;
                ++j2;
            }
            y_pred_all[i] = y_pred;
            cens_pred_all[i] = cens_pred;
            weights_all[i] = weights;
            ++i;
        }
        retn[0] = y_pred_all;
        retn[1] = cens_pred_all;
        retn[2] = weights_all;
        return retn;
    }

    public static Object[][] hallucinateData(RandomForest forest, double[][] X, int[][] numOccurrence, double[] lowerBoundForSamples, double valueForAllCens, int logModel) {
        Object[][] retn = new Object[X.length][forest.numTrees];
        Object[] collectedFromLeaves = RandomForest.collectData(forest, X);
        double[][] y_pred_all = (double[][])collectedFromLeaves[0];
        boolean[][] cens_pred_all = (boolean[][])collectedFromLeaves[1];
        double[][] weights_all = (double[][])collectedFromLeaves[2];
        int numdata = X.length;
        int i = 0;
        while (i < numdata) {
            int totalOccurrences = 0;
            int j = 0;
            while (j < forest.numTrees) {
                totalOccurrences += numOccurrence[i][j];
                ++j;
            }
            if (totalOccurrences != 0) {
                double[] y_pred = y_pred_all[i];
                boolean[] cens_pred = cens_pred_all[i];
                double[] weights = weights_all[i];
                double[] samples = WeibullFit.fit_dist_and_sample(y_pred, cens_pred, weights, totalOccurrences, lowerBoundForSamples[i], valueForAllCens);
                int counter = 0;
                int j2 = 0;
                while (j2 < forest.numTrees) {
                    int numOccurrenceHere = numOccurrence[i][j2];
                    double[] result = new double[numOccurrenceHere];
                    int k = 0;
                    while (k < numOccurrenceHere) {
                        result[k] = logModel == 1 ? Math.log10(samples[counter++]) : samples[counter++];
                        ++k;
                    }
                    retn[i][j2] = result;
                    ++j2;
                }
            }
            ++i;
        }
        return retn;
    }

    public static void update(RandomForest forest, double[][] newx, double[] newy, boolean[] newcens, double valueForAllCens, int logModel) {
        int i;
        if (newx.length != newy.length || newx.length != newcens.length) {
            throw new RuntimeException("Argument sizes mismatch.");
        }
        int[][] nodes = RandomForest.fwd(forest, newx);
        if (logModel == 1 || logModel == 2) {
            double[] tmp = new double[newy.length];
            i = 0;
            while (i < newy.length) {
                tmp[i] = Math.pow(10.0, newy[i]);
                ++i;
            }
            newy = tmp;
        }
        boolean hascens = false;
        i = 0;
        while (i < newcens.length) {
            if (newcens[i]) {
                hascens = true;
                break;
            }
            ++i;
        }
        double[] samples = null;
        double[][] y_pred_all = null;
        boolean[][] cens_pred_all = null;
        double[][] weights_all = null;
        if (hascens) {
            Object[] collectedFromLeaves = RandomForest.collectData(forest, nodes);
            y_pred_all = (double[][])collectedFromLeaves[0];
            cens_pred_all = (boolean[][])collectedFromLeaves[1];
            weights_all = (double[][])collectedFromLeaves[2];
        }
        boolean[][] nodeChanged = new boolean[forest.numTrees][];
        int i2 = 0;
        while (i2 < forest.numTrees) {
            nodeChanged[i2] = new boolean[forest.Trees[i2].node.length];
            ++i2;
        }
        i2 = 0;
        while (i2 < newx.length) {
            if (newcens[i2]) {
                double lowerBoundForSamples = newy[i2];
                double[] y_pred = y_pred_all[i2];
                boolean[] cens_pred = cens_pred_all[i2];
                double[] weights = weights_all[i2];
                samples = WeibullFit.fit_dist_and_sample(y_pred, cens_pred, weights, forest.numTrees, lowerBoundForSamples, valueForAllCens);
            }
            int m = 0;
            while (m < forest.numTrees) {
                Regtree tree = forest.Trees[m];
                int node = nodes[m][i2];
                nodeChanged[m][node] = true;
                int Nnode = tree.nodesize[node];
                if (tree.resultsStoredInLeaves) {
                    double[] newysub = new double[Nnode + 1];
                    boolean[] newcenssub = new boolean[Nnode + 1];
                    if (Nnode != 0) {
                        System.arraycopy(tree.ysub[node], 0, newysub, 0, Nnode);
                        System.arraycopy(tree.is_censored[node], 0, newcenssub, 0, Nnode);
                    }
                    newysub[Nnode] = newcens[i2] ? (logModel == 1 ? Math.log10(samples[m]) : samples[m]) : newy[i2];
                    tree.ysub[node] = newysub;
                    newcenssub[Nnode] = false;
                    tree.is_censored[node] = newcenssub;
                } else {
                    double[] dArray = tree.ysub[node];
                    dArray[0] = dArray[0] + newy[i2];
                    double[] dArray2 = tree.ysub[node];
                    dArray2[1] = dArray2[1] + newy[i2] * newy[i2];
                }
                int n = node;
                tree.nodesize[n] = tree.nodesize[n] + 1;
                while (node != 0) {
                    int n2 = node = tree.parent[node];
                    tree.nodesize[n2] = tree.nodesize[n2] + 1;
                }
                ++m;
            }
            ++i2;
        }
        int m = 0;
        while (m < forest.numTrees) {
            Regtree tree = forest.Trees[m];
            int i3 = 0;
            while (i3 < nodeChanged[m].length) {
                if (nodeChanged[m][i3]) {
                    tree.recalculateStats(i3);
                }
                ++i3;
            }
            ++m;
        }
    }
}

