/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.Regtree;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import ca.ubc.cs.beta.models.fastrf.RegtreeFit;
import ca.ubc.cs.beta.models.fastrf.RegtreeFwd;
import ca.ubc.cs.beta.models.fastrf.utils.Utils;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;

/**
 * A random forest of {@link Regtree} regression trees.
 *
 * <p>Each tree is fit with {@link RegtreeFit#fit} on a (bootstrap) subset of the training rows,
 * selected by a per-tree index array. Predictions are aggregated across trees into a per-row
 * {@code [mean, variance]} pair; see {@link #apply(RandomForest, double[][])} and
 * {@link #applyMarginal(RandomForest, int[], double[][], double[][])}.
 *
 * <p>NOTE(review): this class appears to be CFR-decompiled. The {@code main*} methods and the
 * static {@code arr} holder are debugging aids that compare serialized inputs/forests found at
 * hard-coded {@code /tmp} paths.
 */
public strictfp class RandomForest
implements Serializable {
    private static final long serialVersionUID = 5204746081208095705L;
    /** Number of trees in the forest; the constructor allocates {@code Trees} with this length. */
    public int numTrees;
    /** The fitted trees, indexed {@code 0 .. numTrees-1}. */
    public Regtree[] Trees;
    /** When {@code > 0}, per-tree predictions are mapped into log10 space before aggregation. */
    public int logModel;
    /**
     * Aggregated variances below this threshold print a warning (and fail an assertion when the
     * JVM runs with {@code -ea}) before being clamped to {@link #minVariance}.
     */
    public static final double MIN_VARIANCE_RESULT = -1.0 * Math.pow(10.0, -6.0);
    /** Lower bound applied to every aggregated variance. */
    public double minVariance;
    private RegtreeBuildParams buildParams;
    /**
     * Debug holder written by {@link #fromFile(String)}: slot 0 receives the first loaded argument
     * tuple, slot 1 the most recent subsequent one. Read by {@link #main2(String[])}.
     */
    private static Object[][] arr = new Object[2][];
    /** Debug switch: when {@code true}, the learnModel* methods dump their arguments to a temp file. */
    private static final boolean WRITE_CALL_ARGUMENTS = false;

    /**
     * Creates an empty forest with room for {@code numtrees} trees.
     *
     * @param numtrees number of trees; must be positive
     * @param buildParams build parameters; {@code logModel} and {@code minVariance} are copied from it
     * @throws RuntimeException if {@code numtrees <= 0}
     */
    public RandomForest(int numtrees, RegtreeBuildParams buildParams) {
        if (numtrees <= 0) {
            throw new RuntimeException("Invalid number of regression trees in forest: " + numtrees);
        }
        this.logModel = buildParams.logModel;
        this.numTrees = numtrees;
        this.Trees = new Regtree[numtrees];
        this.minVariance = buildParams.minVariance;
        this.buildParams = buildParams;
    }

    /**
     * Two forests are equal when they have the same tree count, the same {@code logModel}, and
     * element-wise equal trees. {@code minVariance} and build parameters are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof RandomForest)) {
            return false;
        }
        RandomForest rf = (RandomForest)o;
        return this.numTrees == rf.numTrees
                && this.logModel == rf.logModel
                && Arrays.equals(this.Trees, rf.Trees);
    }

    /** @return the build parameters this forest was constructed with */
    public RegtreeBuildParams getBuildParams() {
        return this.buildParams;
    }

    /** Hash over the same fields {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return this.logModel ^ (2 * this.numTrees) ^ Arrays.deepHashCode(this.Trees);
    }

    /**
     * Returns a non-negative hash in {@code [0, 32452867)}, derived from {@link #hashCode()}.
     *
     * <p>The absolute value is taken in {@code long} arithmetic: {@code Math.abs(Integer.MIN_VALUE)}
     * is still negative in {@code int}, which would otherwise yield a negative result.
     */
    public int matlabHashCode() {
        return (int)(Math.abs((long)this.hashCode()) % 32452867L);
    }

    /**
     * Learns a forest in which each tree is fit on a bootstrap sample (drawn with replacement) of
     * {@code y.length} rows.
     *
     * <p>Randomness comes from {@code params.random} if set; otherwise a new {@link Random} is
     * created and seeded with {@code params.seed} unless that is {@code -1}.
     *
     * @return the fitted forest, see
     *     {@link #learnModel(int, double[][], double[][], int[][], double[], int[][], RegtreeBuildParams)}
     */
    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, RegtreeBuildParams params) {
        Random r = RandomForest.resolveRandom(params);
        int[][] dataIdxs = RandomForest.bootstrapIndices(numTrees, y.length, r);
        return RandomForest.learnModel(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
    }

    /**
     * Imputed-values variant of {@link #learnModel(int, double[][], double[][], int[][], double[], RegtreeBuildParams)}.
     * It draws {@code y.length} bootstrap indices per tree and delegates to the overload taking
     * explicit {@code dataIdxs}.
     *
     * <p>NOTE(review): the indices drawn here are used to index {@code theta_inst_idxs}, while
     * {@code y[i]} is used whole as tree {@code i}'s targets. Sampling from {@code [0, y.length)}
     * therefore looks suspicious. Confirm {@code y}'s intended shape against callers before relying
     * on this.
     */
    public static RandomForest learnModelImputedValues(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[][] y, RegtreeBuildParams params) {
        Random r = RandomForest.resolveRandom(params);
        int[][] dataIdxs = RandomForest.bootstrapIndices(numTrees, y.length, r);
        return RandomForest.learnModelImputedValues(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
    }

    /** Returns {@code params.random}, or a fresh Random seeded with {@code params.seed} unless it is -1. */
    private static Random resolveRandom(RegtreeBuildParams params) {
        Random r = params.random;
        if (r == null) {
            r = new Random();
            if (params.seed != -1L) {
                r.setSeed(params.seed);
            }
        }
        return r;
    }

    /** Draws {@code numTrees} rows of {@code n} indices uniform in {@code [0, n)}, consuming {@code r} row by row. */
    private static int[][] bootstrapIndices(int numTrees, int n, Random r) {
        int[][] idxs = new int[numTrees][n];
        for (int i = 0; i < numTrees; ++i) {
            for (int j = 0; j < n; ++j) {
                idxs[i][j] = r.nextInt(n);
            }
        }
        return idxs;
    }

    /**
     * Rounds every value to single precision, in place.
     *
     * <p>The long-bits round trip only canonicalizes NaN payloads; it is kept so results stay
     * bit-identical to earlier versions.
     */
    public static void fixInputs(double[] input) {
        for (int j = 0; j < input.length; ++j) {
            double canonical = Double.longBitsToDouble(Double.doubleToLongBits(input[j]));
            input[j] = (float)canonical;
        }
    }

    /** Applies {@link #fixInputs(double[])} to every row. */
    public static void fixInputs(double[][] input) {
        for (int i = 0; i < input.length; ++i) {
            RandomForest.fixInputs(input[i]);
        }
    }

    /**
     * Debug helper: renders {@code o}, formatting {@code double[]}, {@code int[][]} and
     * {@code double[][]} element-wise.
     *
     * @throws IllegalStateException if {@code o} is {@code null} (after printing it)
     */
    public static String print(Object o) {
        if (o instanceof Integer) {
            return o.toString();
        }
        if (o instanceof double[]) {
            return Arrays.toString((double[])o);
        }
        if (o instanceof int[][]) {
            return Arrays.deepToString((Object[])((int[][])o));
        }
        if (o instanceof double[][]) {
            return Arrays.deepToString((Object[])((double[][])o));
        }
        if (o != null) {
            return o.toString();
        }
        System.out.println(o);
        throw new IllegalStateException();
    }

    /**
     * Debug helper: compares {@code o} and {@code o2}, using element-wise equality for
     * {@code double[]}, {@code int[][]} and {@code double[][]}. {@code o2} is cast to
     * {@code o}'s array type.
     *
     * @throws IllegalStateException if {@code o} is {@code null} (after printing it)
     */
    public static boolean equalTest(Object o, Object o2) {
        if (o instanceof Integer) {
            return o.equals(o2);
        }
        if (o instanceof double[]) {
            return Arrays.equals((double[])o, (double[])o2);
        }
        if (o instanceof int[][]) {
            return Arrays.deepEquals((Object[])((int[][])o), (Object[])((int[][])o2));
        }
        if (o instanceof double[][]) {
            return Arrays.deepEquals((Object[])((double[][])o), (Object[])((double[][])o2));
        }
        if (o != null) {
            return o.equals(o2);
        }
        System.out.println(o);
        throw new IllegalStateException();
    }

    /** Debug entry point: compares two serialized forests' rounded marginal predictions. */
    public static void main3(String[] args) {
        RandomForest matlabForest = RandomForest.fromForestFile("/tmp/RandomForest4433899701602217560Build");
        RandomForest javaForest = RandomForest.fromForestFile("/tmp/RandomForest8044841660237959103Build");
        double[][] configs = new double[][]{{2.0, 3.0, 2.0, 2.0, 4.0, 12.0, 5.0, 3.0, 5.0, 3.0, 19.0, 6.0, 4.0, 5.0, 2.0, 1.0, 1.0, 3.0, 16.0, 3.0, 5.0, 2.0, 1.0, 2.0, 7.0, 2.0}};
        int[] treesUsed = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        System.out.println(matlabForest.equals(javaForest));
        System.out.println(Arrays.deepToString((Object[])RandomForest.applyRoundedMarginal(matlabForest, treesUsed, configs)));
        System.out.println(Arrays.deepToString((Object[])RandomForest.applyRoundedMarginal(javaForest, treesUsed, configs)));
    }

    /**
     * Debug entry point: rebuilds two forests from dumped arguments and prints tree and argument
     * differences. It relies on both {@link #fromFile(String)} calls having populated {@code arr}.
     */
    public static void main2(String[] args) {
        RandomForest f1 = RandomForest.fromFile("/tmp/RandomForest7400514883271338896Build");
        RandomForest f2 = RandomForest.fromFile("/tmp/RandomForest7510701929783488816Build");
        System.out.println(f1.matlabHashCode());
        System.out.println("---------");
        System.out.println(f2.matlabHashCode());
        System.out.println("f1var:" + Arrays.toString(f1.Trees[2].var));
        System.out.println("f2var:" + Arrays.toString(f2.Trees[2].var));
        System.out.println(Arrays.equals(f1.Trees[2].var, f2.Trees[2].var));
        System.out.println("f1catsplit:" + Arrays.deepToString((Object[])f1.Trees[2].catsplit));
        System.out.println("f2catsplit:" + Arrays.deepToString((Object[])f2.Trees[2].catsplit));
        System.out.println(Arrays.deepEquals((Object[])f1.Trees[2].catsplit, (Object[])f2.Trees[2].catsplit));
        System.out.println("f1cut:" + Arrays.toString(f1.Trees[2].cut));
        System.out.println("f2cut:" + Arrays.toString(f2.Trees[2].cut));
        System.out.println(Arrays.equals(f1.Trees[2].cut, f2.Trees[2].cut));
        for (int i = 0; i < arr[1].length; ++i) {
            System.out.println("a1:(" + i + ")" + RandomForest.print(arr[0][i]));
            System.out.println("a2:(" + i + ")" + RandomForest.print(arr[1][i]));
            System.out.println("=:(" + i + ")" + RandomForest.equalTest(arr[0][i], arr[1][i]));
            System.out.println("");
        }
        System.out.println(f1.equals(f2));
        double[][] configs = new double[][]{{2.0, 3.0, 2.0, 2.0, 4.0, 12.0, 5.0, 3.0, 5.0, 3.0, 19.0, 6.0, 4.0, 5.0, 2.0, 1.0, 1.0, 3.0, 16.0, 3.0, 5.0, 2.0, 1.0, 2.0, 7.0, 2.0}};
        int[] treesUsed = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        System.out.println(Arrays.deepToString((Object[])RandomForest.applyMarginal(f1, treesUsed, configs)));
        System.out.println(Arrays.deepToString((Object[])RandomForest.applyMarginal(f2, treesUsed, configs)));
    }

    /**
     * Reads {@code learnModel} arguments in the format written by the debug dump, rounds the
     * double inputs to single precision, records them in {@code arr}, and learns a forest.
     *
     * <p>Uses Java native deserialization; only load files you produced yourself.
     *
     * @throws RuntimeException wrapping any I/O, deserialization or learning failure
     */
    public static RandomForest fromFile(String s) {
        File f = new File(s);
        try {
            int numTrees;
            double[][] allTheta;
            double[][] allX;
            int[][] theta_inst_idxs;
            double[] y;
            int[][] dataIdxs;
            RegtreeBuildParams params;
            // Declare the file stream separately so it is closed even if the Object stream's constructor throws.
            try (FileInputStream fis = new FileInputStream(f);
                 ObjectInputStream in = new ObjectInputStream(fis)) {
                numTrees = in.readInt();
                allTheta = (double[][])in.readObject();
                allX = (double[][])in.readObject();
                theta_inst_idxs = (int[][])in.readObject();
                y = (double[])in.readObject();
                dataIdxs = (int[][])in.readObject();
                params = (RegtreeBuildParams)in.readObject();
            }
            RandomForest.fixInputs(allTheta);
            RandomForest.fixInputs(y);
            RandomForest.fixInputs(allX);
            Object[] obj = new Object[]{numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params};
            if (arr[0] == null) {
                RandomForest.arr[0] = obj;
            } else {
                RandomForest.arr[1] = obj;
            }
            return RandomForest.learnModel(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a forest previously written by {@link #save(RandomForest, File)}.
     *
     * <p>Uses Java native deserialization; only load files you produced yourself.
     *
     * @throws RuntimeException wrapping any I/O or deserialization failure
     */
    public static RandomForest fromForestFile(String s) {
        File f = new File(s);
        try (FileInputStream fis = new FileInputStream(f);
             ObjectInputStream in = new ObjectInputStream(fis)) {
            return (RandomForest)in.readObject();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Debug entry point: field-by-field comparison of two dumped {@code learnModel} argument files.
     * Failures (missing files, shape mismatches) are reported on stderr instead of being swallowed.
     */
    public static void main(String[] args) {
        String badParams = "/tmp/RandomForestParams3027553244461234091Build";
        String goodParams = "/tmp/RandomForestParams4017574993756479929Build";
        File f1 = new File(goodParams);
        File f2 = new File(badParams);
        try (FileInputStream fis = new FileInputStream(f1);
             ObjectInputStream in = new ObjectInputStream(fis);
             FileInputStream fis2 = new FileInputStream(f2);
             ObjectInputStream in2 = new ObjectInputStream(fis2)) {
            int numTrees = in.readInt();
            int numTrees2 = in2.readInt();
            System.out.println(numTrees - numTrees2);
            double[][] allTheta = (double[][])in.readObject();
            double[][] allTheta2 = (double[][])in2.readObject();
            for (int i = 0; i < allTheta.length; ++i) {
                for (int j = 0; j < allTheta[i].length; ++j) {
                    if (allTheta[i][j] - allTheta2[i][j] == 0.0) continue;
                    System.out.println(i + "," + j + ":" + allTheta[i][j] + " " + allTheta2[i][j]);
                }
            }
            System.out.println(Arrays.deepEquals((Object[])allTheta, (Object[])allTheta2));
            System.out.println("a");
            double[][] allX = (double[][])in.readObject();
            double[][] allX2 = (double[][])in2.readObject();
            System.out.println(Arrays.deepEquals((Object[])allX, (Object[])allX2));
            System.out.println("b");
            int[][] theta_inst_idxs = (int[][])in.readObject();
            int[][] theta_inst_idxs2 = (int[][])in2.readObject();
            System.out.println(Arrays.deepEquals((Object[])theta_inst_idxs, (Object[])theta_inst_idxs2));
            System.out.println("c");
            double[] y = (double[])in.readObject();
            double[] y2 = (double[])in2.readObject();
            System.out.println(Arrays.equals(y, y2));
            System.out.println("d");
            int[][] dataIdxs = (int[][])in.readObject();
            int[][] dataIdxs2 = (int[][])in2.readObject();
            System.out.println(Arrays.deepEquals((Object[])dataIdxs, (Object[])dataIdxs2));
            System.out.println("e");
        }
        catch (Exception e) {
            // Debug tool: surface the failure rather than exiting silently.
            e.printStackTrace();
        }
    }

    /** Saves {@code forest} to a new temp file named {@code RandomForest*Build}; I/O errors are printed. */
    public static void save(RandomForest forest) {
        try {
            File f = File.createTempFile("RandomForest", "Build");
            RandomForest.save(forest, f);
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Saves {@code forest} to {@code filename}; see {@link #save(RandomForest, File)}. */
    public static void save(RandomForest forest, String filename) {
        File f = new File(filename);
        RandomForest.save(forest, f);
    }

    /**
     * Serializes {@code forest} to {@code f} and prints the destination path. I/O errors are
     * printed to stderr, not thrown.
     */
    public static void save(RandomForest forest, File f) {
        try (FileOutputStream fos = new FileOutputStream(f);
             ObjectOutputStream o = new ObjectOutputStream(fos)) {
            o.writeObject(forest);
            System.out.println("Forest Saved To:" + f.getAbsolutePath());
        }
        catch (IOException e) {
            System.err.println(e);
        }
    }

    /**
     * Debug dump of learnModel* arguments to a temp file named {@code RandomForestParams*Build}.
     * I/O errors are printed to stderr. {@code y} is a {@code double[]} or {@code double[][]}
     * depending on the caller.
     */
    private static void writeCallArguments(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, Object y, int[][] dataIdxs, RegtreeBuildParams params) {
        try {
            File f = File.createTempFile("RandomForestParams", "Build");
            try (FileOutputStream fos = new FileOutputStream(f);
                 ObjectOutputStream o = new ObjectOutputStream(fos)) {
                o.writeInt(numTrees);
                o.writeObject(allTheta);
                o.writeObject(allX);
                o.writeObject(theta_inst_idxs);
                o.writeObject(y);
                o.writeObject(dataIdxs);
                o.writeObject(params);
                System.out.println("Calls written & deleted to:" + f.getAbsolutePath());
            }
        }
        catch (IOException e) {
            System.err.println(e);
        }
    }

    /**
     * Learns a forest in which tree {@code i} is fit on the rows selected by {@code dataIdxs[i]}.
     * Each index picks both a {@code theta_inst_idxs} row and the matching {@code y} value.
     *
     * @param dataIdxs one index array per tree; must have exactly {@code numTrees} rows
     * @throws RuntimeException if {@code dataIdxs.length != numTrees} or {@code numTrees <= 0}
     */
    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, int[][] dataIdxs, RegtreeBuildParams params) {
        if (WRITE_CALL_ARGUMENTS) {
            RandomForest.writeCallArguments(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
        }
        if (dataIdxs.length != numTrees) {
            throw new RuntimeException("length(dataIdxs) must be equal to numtrees.");
        }
        RandomForest rf = new RandomForest(numTrees, params);
        for (int i = 0; i < numTrees; ++i) {
            int[] sample = dataIdxs[i];
            int N = sample.length;
            int[][] this_theta_inst_idxs = new int[N][];
            double[] thisy = new double[N];
            for (int j = 0; j < N; ++j) {
                int idx = sample[j];
                this_theta_inst_idxs[j] = theta_inst_idxs[idx];
                thisy[j] = y[idx];
            }
            rf.Trees[i] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, thisy, params);
        }
        return rf;
    }

    /**
     * Imputed-values variant of learnModel. Tree {@code i} is fit on the {@code theta_inst_idxs}
     * rows selected by {@code dataIdxs[i]}, using {@code y[i]} directly as its targets.
     *
     * @param y per-tree target vectors; {@code y[i]} is passed to the fit for tree {@code i}
     * @throws RuntimeException if {@code dataIdxs.length != numTrees} or {@code numTrees <= 0}
     */
    public static RandomForest learnModelImputedValues(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[][] y, int[][] dataIdxs, RegtreeBuildParams params) {
        if (WRITE_CALL_ARGUMENTS) {
            RandomForest.writeCallArguments(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
        }
        if (dataIdxs.length != numTrees) {
            throw new RuntimeException("length(dataIdxs) must be equal to numtrees.");
        }
        RandomForest rf = new RandomForest(numTrees, params);
        for (int i = 0; i < numTrees; ++i) {
            int[] sample = dataIdxs[i];
            int[][] this_theta_inst_idxs = new int[sample.length][];
            for (int j = 0; j < sample.length; ++j) {
                this_theta_inst_idxs[j] = theta_inst_idxs[sample[j]];
            }
            rf.Trees[i] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, y[i], params);
        }
        return rf;
    }

    /**
     * Returns, for every tree, the node index {@link RegtreeFwd#fwd} reports for each row of {@code X}.
     *
     * @return array of shape {@code [numTrees][X.length]}
     */
    public static int[][] fwd(RandomForest forest, double[][] X) {
        int[][] retn = new int[forest.numTrees][X.length];
        for (int i = 0; i < forest.numTrees; ++i) {
            int[] result = RegtreeFwd.fwd(forest.Trees[i], X);
            System.arraycopy(result, 0, retn[i], 0, result.length);
        }
        return retn;
    }

    /**
     * Predicts a {@code [mean, variance]} pair for each row of {@code X}, combining all trees.
     *
     * <p>For each tree, the reached node's {@code nodepred}/{@code nodevar} are used. If
     * {@code logModel > 0} they are mapped to log10 space:
     * <ul>
     *   <li>With {@code brokenVarianceCalculation}, only the mean is log10'd and the variance is
     *       left untransformed.</li>
     *   <li>Otherwise the log-normal moment relations convert both mean and variance.</li>
     * </ul>
     * The result is the mean of the per-tree means, and the mixture variance with an
     * {@code n/(n-1)} correction, clamped below by {@link #minVariance}.
     *
     * @return array of shape {@code [X.length][2]}
     */
    public static double[][] apply(RandomForest forest, double[][] X) {
        double[][] retn = new double[X.length][2];
        for (int t = 0; t < forest.numTrees; ++t) {
            Regtree tree = forest.Trees[t];
            int[] node = RegtreeFwd.fwd(tree, X);
            for (int j = 0; j < X.length; ++j) {
                double pred = tree.nodepred[node[j]];
                double var = tree.nodevar[node[j]];
                if (forest.logModel > 0) {
                    if (forest.buildParams.brokenVarianceCalculation) {
                        pred = Math.log10(pred);
                    } else {
                        // Operation order is kept exactly as before so results stay bit-identical.
                        double var_ln = Math.log(var / (pred * pred) + 1.0);
                        double mu_ln = Math.log(pred) - var_ln / 2.0;
                        var = var_ln / Math.log(10.0) / Math.log(10.0);
                        pred = mu_ln / Math.log(10.0);
                    }
                }
                retn[j][0] += pred;
                retn[j][1] += var + pred * pred;
            }
        }
        RandomForest.finalizeMoments(retn, forest.numTrees, forest.numTrees, forest.minVariance);
        return retn;
    }

    /**
     * Turns accumulated {@code [sum(pred), sum(var + pred^2)]} rows into {@code [mean, variance]}.
     *
     * <p>Both sums are divided by {@code numAveraged} and the squared mean is subtracted. The result
     * is scaled by {@code n/max(1, n-1)} with {@code n = numForCorrection}, warned about if below
     * {@link #MIN_VARIANCE_RESULT}, and clamped to at least {@code minVariance}.
     */
    private static void finalizeMoments(double[][] retn, int numAveraged, int numForCorrection, double minVariance) {
        double correction = (double)numForCorrection / (double)Math.max(1, numForCorrection - 1);
        for (int i = 0; i < retn.length; ++i) {
            double[] row = retn[i];
            row[0] = row[0] / (double)numAveraged;
            row[1] = row[1] / (double)numAveraged;
            row[1] = row[1] - row[0] * row[0];
            row[1] = row[1] * correction;
            if (row[1] < MIN_VARIANCE_RESULT) {
                System.err.println("[WARN]: Variance is less than " + MIN_VARIANCE_RESULT + " > " + row[1]);
                // Deliberately fails when assertions are enabled, to surface numerical problems.
                assert (row[1] > MIN_VARIANCE_RESULT);
            }
            row[1] = Math.max(minVariance, row[1]);
        }
    }

    /**
     * Reduces {@code val} to single precision and clears the low 11 mantissa bits of the float.
     * Used to compare predictions at reduced precision.
     */
    public static double round(double val) {
        int bits = Float.floatToRawIntBits((float)val);
        return Float.intBitsToFloat(bits & 0xFFFFF800);
    }

    /** Applies {@link #round(double)} to every entry of {@code results} in place and returns it. */
    private static double[][] roundInPlace(double[][] results) {
        for (int i = 0; i < results.length; ++i) {
            for (int j = 0; j < results[i].length; ++j) {
                results[i][j] = RandomForest.round(results[i][j]);
            }
        }
        return results;
    }

    /** {@link #applyMarginal(RandomForest, int[], double[][])} with every entry passed through {@link #round(double)}. */
    public static double[][] applyRoundedMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.roundInPlace(RandomForest.applyMarginal(forest, tree_idxs_used, Theta));
    }

    /** {@link #applyMarginal(RandomForest, int[], double[][], double[][])} with every entry passed through {@link #round(double)}. */
    public static double[][] applyRoundedMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        return RandomForest.roundInPlace(RandomForest.applyMarginal(forest, tree_idxs_used, Theta, X));
    }

    /**
     * Classifies each row of {@code X} by majority vote over all trees.
     *
     * <p>Ties among the modes returned by {@link Utils#mode} are broken with {@link Math#random()},
     * so results are not reproducible from the build seed.
     */
    public static double[] classify(RandomForest forest, double[][] X) {
        double[][] votes = new double[X.length][forest.numTrees];
        for (int i = 0; i < forest.numTrees; ++i) {
            double[] res = Regtree.classify(forest.Trees[i], X);
            for (int j = 0; j < res.length; ++j) {
                votes[j][i] = res[j];
            }
        }
        double[] retn = new double[X.length];
        for (int i = 0; i < X.length; ++i) {
            double[] best = Utils.mode(votes[i]);
            retn[i] = best[(int)(Math.random() * (double)best.length)];
        }
        return retn;
    }

    /** {@link #marginalTreePredictions(RandomForest, int[], double[][], double[][])} with {@code X == null}. */
    public static double[][] marginalTreePredictions(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.marginalTreePredictions(forest, tree_idxs_used, Theta, null);
    }

    /** {@link #applyMarginal(RandomForest, int[], double[][], double[][])} with {@code X == null}. */
    public static double[][] applyMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.applyMarginal(forest, tree_idxs_used, Theta, null);
    }

    /**
     * Predicts a marginal {@code [mean, variance]} pair for each row of {@code Theta}, using only
     * the trees listed in {@code tree_idxs_used}.
     *
     * <p>Per-tree predictions and variances come from {@link RegtreeFwd#marginalFwd}. Predictions
     * (not variances) are log10'd when {@code logModel > 0}.
     *
     * @return array of shape {@code [Theta.length][2]}
     */
    public static double[][] applyMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        int nTheta = Theta.length;
        int nTrees = tree_idxs_used.length;
        double[][] retn = new double[nTheta][2];
        for (int i = 0; i < nTrees; ++i) {
            Object[] result = RegtreeFwd.marginalFwd(forest.Trees[tree_idxs_used[i]], Theta, X);
            double[] preds = (double[])result[0];
            double[] vars = (double[])result[1];
            for (int j = 0; j < nTheta; ++j) {
                double pred = preds[j];
                if (forest.logModel > 0) {
                    pred = Math.log10(pred);
                }
                retn[j][0] += pred;
                retn[j][1] += vars[j] + pred * pred;
            }
        }
        // NOTE(review): the moments are averaged over nTrees, but the n/(n-1) correction uses
        // forest.numTrees. These differ when only a subset of trees is used. Kept unchanged for
        // parity with existing (MATLAB-compared) results.
        RandomForest.finalizeMoments(retn, nTrees, forest.numTrees, forest.minVariance);
        return retn;
    }

    /**
     * Returns each selected tree's marginal prediction for each row of {@code Theta}, log10'd when
     * {@code logModel > 0}.
     *
     * @return array of shape {@code [Theta.length][tree_idxs_used.length]}
     */
    public static double[][] marginalTreePredictions(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        int nTheta = Theta.length;
        int nTrees = tree_idxs_used.length;
        double[][] retn = new double[nTheta][nTrees];
        for (int i = 0; i < nTrees; ++i) {
            Object[] result = RegtreeFwd.marginalFwd(forest.Trees[tree_idxs_used[i]], Theta, X);
            double[] preds = (double[])result[0];
            for (int j = 0; j < nTheta; ++j) {
                double pred = preds[j];
                if (forest.logModel > 0) {
                    pred = Math.log10(pred);
                }
                retn[j][i] = pred;
            }
        }
        return retn;
    }

    /**
     * Returns a new forest whose trees are {@link RegtreeFwd#preprocess_inst_splits} applied to
     * {@code forest}'s trees with {@code X}.
     *
     * <p>The public {@code logModel} and {@code minVariance} are copied from the source forest,
     * not re-derived from its build params. This keeps predictions consistent even if those
     * fields were changed after construction.
     */
    public static RandomForest preprocessForest(RandomForest forest, double[][] X) {
        RandomForest prepared = new RandomForest(forest.numTrees, forest.buildParams);
        prepared.logModel = forest.logModel;
        prepared.minVariance = forest.minVariance;
        for (int i = 0; i < forest.numTrees; ++i) {
            prepared.Trees[i] = RegtreeFwd.preprocess_inst_splits(forest.Trees[i], X);
        }
        return prepared;
    }

    /** Calls {@link RegtreeFwd#preprocess_for_classification} on every tree of {@code forest}. */
    public static void preprocessForestForClassification(RandomForest forest) {
        for (int i = 0; i < forest.numTrees; ++i) {
            RegtreeFwd.preprocess_for_classification(forest.Trees[i]);
        }
    }
}

