/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.Regtree;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import ca.ubc.cs.beta.models.fastrf.RegtreeFit;
import ca.ubc.cs.beta.models.fastrf.RegtreeFwd;
import ca.ubc.cs.beta.models.fastrf.utils.CsvToDataConverter;
import ca.ubc.cs.beta.models.fastrf.utils.RfData;
import ca.ubc.cs.beta.models.fastrf.utils.Utils;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;

/**
 * A forest of regression trees ({@link Regtree}) with helpers to train it
 * ({@code learnModel*}), push inputs through it ({@link #fwd}), predict a
 * per-row mean and variance ({@link #apply}, {@link #applyMarginal}), and
 * classify by majority vote ({@link #classify}).
 *
 * <p>The class is declared {@code strictfp}, so its floating-point expressions
 * use strict IEEE 754 semantics.
 */
public strictfp class RandomForest
implements Serializable {
    private static final long serialVersionUID = 5204746081205095705L;
    /** Number of trees held in {@link #Trees}. */
    public int numTrees;
    /** The trees of this forest; allocated with length {@link #numTrees}. */
    public Regtree[] Trees;
    /** When {@code > 0}, per-tree predictions are mapped to log10 space before aggregation. */
    public int logModel;
    /** Variances below this value (-1e-6) trigger a warning in the prediction methods. */
    public static final double MIN_VARIANCE_RESULT = -1.0 * Math.pow(10.0, -6.0);
    /** Lower bound applied to every predicted variance. */
    public double minVariance;
    private RegtreeBuildParams buildParams;

    /** Modulus used by {@link #matlabHashCode()}. */
    private static final int MATLAB_HASH_MODULUS = 32452867;

    /**
     * Creates a forest with room for {@code numtrees} (not yet fitted) trees.
     * {@link #logModel} and {@link #minVariance} are copied from {@code buildParams}.
     *
     * @param numtrees    number of trees; must be positive
     * @param buildParams tree-building parameters; must not be null
     * @throws IllegalArgumentException if {@code numtrees <= 0}
     * @throws NullPointerException     if {@code buildParams} is null
     */
    public RandomForest(int numtrees, RegtreeBuildParams buildParams) {
        if (numtrees <= 0) {
            throw new IllegalArgumentException("Invalid number of regression trees in forest: " + numtrees);
        }
        if (buildParams == null) {
            throw new NullPointerException("buildParams");
        }
        this.logModel = buildParams.logModel;
        this.numTrees = numtrees;
        this.Trees = new Regtree[numtrees];
        this.minVariance = buildParams.minVariance;
        this.buildParams = buildParams;
    }

    /**
     * Trains a 10-tree forest on {@code trainData}.
     * It uses {@code new RegtreeBuildParams(true, 10, catDomainSizes)}.
     *
     * @param trainData training data
     * @return the trained forest
     */
    public static RandomForest buildRf(RfData trainData) {
        return RandomForest.buildRf(trainData, 10);
    }

    /**
     * Trains a forest of {@code numTrees} trees on {@code trainData}.
     * It uses {@code new RegtreeBuildParams(true, 10, catDomainSizes)}.
     *
     * @param trainData training data
     * @param numTrees  number of trees; must be positive
     * @return the trained forest
     */
    public static RandomForest buildRf(RfData trainData, int numTrees) {
        RegtreeBuildParams regTreeBuildParams = new RegtreeBuildParams(true, 10, trainData.getCatDomainSizes());
        return RandomForest.learnModel(numTrees, trainData.getTheta(), trainData.getX(), trainData.getTheta_inst_idxs(), trainData.getY(), regTreeBuildParams);
    }

    /**
     * Trains a single-tree forest on {@code trainData}.
     * It uses {@code new RegtreeBuildParams(false, 1, 1.0, catDomainSizes)}.
     *
     * @param trainData training data
     * @return the trained one-tree forest
     */
    public static RandomForest buildDeterministicRf(RfData trainData) {
        RegtreeBuildParams regTreeBuildParams = new RegtreeBuildParams(false, 1, 1.0, trainData.getCatDomainSizes());
        return RandomForest.learnModel(1, trainData.getTheta(), trainData.getX(), trainData.getTheta_inst_idxs(), trainData.getY(), regTreeBuildParams);
    }

    /**
     * Reads training data from a CSV file and trains a forest with {@link #buildRf(RfData)}.
     *
     * @param csvFileName  path of the CSV file
     * @param thetaColIdxs column indices of the configuration features
     * @param xColIdxs     column indices of the instance features
     * @param yColIdx      column index of the response
     * @param catColIdxs   column indices of categorical features
     * @return the trained forest
     * @throws IOException if the file cannot be read
     */
    public static RandomForest buildRfFromCSV(String csvFileName, int[] thetaColIdxs, int[] xColIdxs, int yColIdx, int[] catColIdxs) throws IOException {
        CsvToDataConverter converter = new CsvToDataConverter(csvFileName, thetaColIdxs, xColIdxs, yColIdx, catColIdxs);
        RfData trainData = converter.readDataFromCsvFile(csvFileName);
        return RandomForest.buildRf(trainData);
    }

    /**
     * Two forests are equal when they have the same tree count, the same
     * {@link #logModel} flag and element-wise equal trees.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof RandomForest)) {
            return false;
        }
        RandomForest rf = (RandomForest) o;
        return this.numTrees == rf.numTrees
                && this.logModel == rf.logModel
                && Arrays.equals(this.Trees, rf.Trees);
    }

    /** @return the parameters this forest was constructed with */
    public RegtreeBuildParams getBuildParams() {
        return this.buildParams;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        // Same value as before; the multiplication already bound tighter than '^'.
        return this.logModel ^ (2 * this.numTrees) ^ Arrays.deepHashCode(this.Trees);
    }

    /**
     * Returns {@link #hashCode()} folded into {@code [0, 32452867)}.
     *
     * <p>The absolute value is taken in {@code long} arithmetic. {@code Math.abs(Integer.MIN_VALUE)}
     * is still negative in {@code int}, which previously produced a negative result. For every
     * other hash value the result is unchanged.
     *
     * @return a non-negative hash below 32452867
     */
    public int matlabHashCode() {
        return (int) (Math.abs((long) this.hashCode()) % MATLAB_HASH_MODULUS);
    }

    /**
     * Trains a forest, generating one row-index vector of length {@code y.length} per tree.
     * The indices are drawn with replacement when {@code params.doBootstrapping} is set, and are
     * the identity {@code 0..N-1} otherwise.
     *
     * <p>Randomness comes from {@code params.random} if set. Otherwise a new {@link Random} is
     * created, seeded with {@code params.seed} unless that is {@code -1}.
     *
     * @return the trained forest
     */
    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, RegtreeBuildParams params) {
        Random r = RandomForest.resolveRandom(params);
        int N = y.length;
        int[][] dataIdxs = params.doBootstrapping
                ? RandomForest.bootstrapIndices(numTrees, N, r)
                : RandomForest.identityIndices(numTrees, N);
        return RandomForest.learnModel(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
    }

    /**
     * Imputed-values variant of {@link #learnModel(int, double[][], double[][], int[][], double[], RegtreeBuildParams)}.
     * It always draws bootstrap indices in {@code [0, y.length)}, using the same random-source
     * rules as that method.
     *
     * @return the trained forest
     */
    public static RandomForest learnModelImputedValues(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[][] y, RegtreeBuildParams params) {
        Random r = RandomForest.resolveRandom(params);
        // NOTE(review): N is the outer length of y, which the 7-arg overload indexes per tree
        // (y[i]); confirm this is the intended number of data rows.
        int N = y.length;
        int[][] dataIdxs = RandomForest.bootstrapIndices(numTrees, N, r);
        return RandomForest.learnModelImputedValues(numTrees, allTheta, allX, theta_inst_idxs, y, dataIdxs, params);
    }

    /**
     * Trains tree {@code i} on the rows selected by {@code dataIdxs[i]}. Both the
     * (theta, instance) index pairs and the responses are resampled.
     *
     * @throws IllegalArgumentException if {@code dataIdxs.length != numTrees}
     */
    public static RandomForest learnModel(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[] y, int[][] dataIdxs, RegtreeBuildParams params) {
        if (dataIdxs.length != numTrees) {
            throw new IllegalArgumentException("length(dataIdxs) must be equal to numtrees.");
        }
        RandomForest rf = new RandomForest(numTrees, params);
        for (int i = 0; i < numTrees; ++i) {
            int[] rows = dataIdxs[i];
            int[][] this_theta_inst_idxs = new int[rows.length][];
            double[] thisy = new double[rows.length];
            for (int j = 0; j < rows.length; ++j) {
                int idx = rows[j];
                this_theta_inst_idxs[j] = theta_inst_idxs[idx];
                thisy[j] = y[idx];
            }
            rf.Trees[i] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, thisy, params);
        }
        return rf;
    }

    /**
     * Trains tree {@code i} on the (theta, instance) pairs selected by {@code dataIdxs[i]},
     * with response vector {@code y[i]}.
     *
     * @throws IllegalArgumentException if {@code dataIdxs.length != numTrees}
     */
    public static RandomForest learnModelImputedValues(int numTrees, double[][] allTheta, double[][] allX, int[][] theta_inst_idxs, double[][] y, int[][] dataIdxs, RegtreeBuildParams params) {
        if (dataIdxs.length != numTrees) {
            throw new IllegalArgumentException("length(dataIdxs) must be equal to numtrees.");
        }
        RandomForest rf = new RandomForest(numTrees, params);
        for (int i = 0; i < numTrees; ++i) {
            int[] rows = dataIdxs[i];
            int[][] this_theta_inst_idxs = new int[rows.length][];
            for (int j = 0; j < rows.length; ++j) {
                this_theta_inst_idxs[j] = theta_inst_idxs[rows[j]];
            }
            // NOTE(review): unlike learnModel, y[i] is passed as-is rather than resampled with
            // dataIdxs[i]; confirm this matches what RegtreeFit.fit expects here.
            rf.Trees[i] = RegtreeFit.fit(allTheta, allX, this_theta_inst_idxs, y[i], params);
        }
        return rf;
    }

    /**
     * Runs every tree's {@code RegtreeFwd.fwd} on {@code X}.
     *
     * @return {@code [numTrees][X.length]} array of the per-tree results (copied)
     */
    public static int[][] fwd(RandomForest forest, double[][] X) {
        int[][] retn = new int[forest.numTrees][X.length];
        for (int i = 0; i < forest.numTrees; ++i) {
            int[] result = RegtreeFwd.fwd(forest.Trees[i], X);
            System.arraycopy(result, 0, retn[i], 0, result.length);
        }
        return retn;
    }

    /**
     * Predicts, for each row of {@code X}, the mean and variance across trees.
     * Each tree contributes {@code nodepred} and {@code nodevar} at the node returned by
     * {@code RegtreeFwd.fwd}.
     *
     * <p>When {@link #logModel} {@code > 0}, each tree's (mean, variance) is mapped to log10
     * space:
     * <ul>
     * <li>if {@code buildParams.brokenVarianceCalculation} is set, only the mean is transformed
     * ({@code log10(mean)});</li>
     * <li>otherwise the log-normal moment formulas are applied.</li>
     * </ul>
     *
     * @return {@code [X.length][2]}, where column 0 is the mean and column 1 is the variance
     *         (bias-corrected by {@code n/(n-1)} and floored at {@link #minVariance})
     */
    public static double[][] apply(RandomForest forest, double[][] X) {
        double[][] retn = new double[X.length][2];
        double ln10 = Math.log(10.0);
        for (int i = 0; i < forest.numTrees; ++i) {
            Regtree tree = forest.Trees[i];
            int[] result = RegtreeFwd.fwd(tree, X);
            for (int j = 0; j < X.length; ++j) {
                double pred = tree.nodepred[result[j]];
                double variance = tree.nodevar[result[j]];
                if (forest.logModel > 0) {
                    if (forest.buildParams.brokenVarianceCalculation) {
                        pred = Math.log10(pred);
                    } else {
                        // Log-normal moments in natural-log space, then rescaled to log10.
                        double varLn = Math.log(variance / (pred * pred) + 1.0);
                        double muLn = Math.log(pred) - varLn / 2.0;
                        variance = varLn / ln10 / ln10;
                        pred = muLn / ln10;
                    }
                }
                retn[j][0] += pred;
                // Accumulate E[x^2] per tree (variance + mean^2) for the law of total variance.
                retn[j][1] += variance + pred * pred;
            }
        }
        RandomForest.finishMeanAndVariance(retn, forest.numTrees, forest.numTrees, forest.minVariance);
        return retn;
    }

    /**
     * Truncates {@code val} to float precision and clears the lowest 11 mantissa bits.
     *
     * @return the truncated value, widened back to double
     */
    public static double round(double val) {
        int bits = Float.floatToRawIntBits((float) val);
        return Float.intBitsToFloat(bits & 0xFFFFF800);
    }

    /** {@link #applyMarginal(RandomForest, int[], double[][])} with every entry passed through {@link #round}. */
    public static double[][] applyRoundedMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.applyRoundedMarginal(forest, tree_idxs_used, Theta, null);
    }

    /** {@link #applyMarginal(RandomForest, int[], double[][], double[][])} with every entry passed through {@link #round}. */
    public static double[][] applyRoundedMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        double[][] results = RandomForest.applyMarginal(forest, tree_idxs_used, Theta, X);
        RandomForest.roundInPlace(results);
        return results;
    }

    /**
     * Classifies each row of {@code X}. Every tree votes via {@code Regtree.classify}, and the
     * result is one element of {@code Utils.mode} of the votes, chosen with {@link Math#random()}.
     * The output is therefore not deterministic when that method returns several values.
     *
     * @return one label per row of {@code X}
     */
    public static double[] classify(RandomForest forest, double[][] X) {
        double[][] votes = new double[X.length][forest.numTrees];
        for (int i = 0; i < forest.numTrees; ++i) {
            double[] res = Regtree.classify(forest.Trees[i], X);
            for (int j = 0; j < res.length; ++j) {
                votes[j][i] = res[j];
            }
        }
        double[] retn = new double[X.length];
        for (int i = 0; i < X.length; ++i) {
            double[] best = Utils.mode(votes[i]);
            retn[i] = best[(int)(Math.random() * (double)best.length)];
        }
        return retn;
    }

    /** Same as {@link #marginalTreePredictions(RandomForest, int[], double[][], double[][])} with {@code X = null}. */
    public static double[][] marginalTreePredictions(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.marginalTreePredictions(forest, tree_idxs_used, Theta, null);
    }

    /** Same as {@link #applyMarginal(RandomForest, int[], double[][], double[][])} with {@code X = null}. */
    public static double[][] applyMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta) {
        return RandomForest.applyMarginal(forest, tree_idxs_used, Theta, null);
    }

    /**
     * Predicts, for each row of {@code Theta}, mean and variance across the trees listed in
     * {@code tree_idxs_used}, using {@code RegtreeFwd.marginalFwd} (which returns
     * {@code {double[] preds, double[] vars}}). When {@link #logModel} {@code > 0}, only the
     * mean is mapped through {@code log10}.
     *
     * @return {@code [Theta.length][2]}, where column 0 is the mean and column 1 is the
     *         variance, floored at {@link #minVariance}
     */
    public static double[][] applyMarginal(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        int nTheta = Theta.length;
        int nTrees = tree_idxs_used.length;
        double[][] retn = new double[nTheta][2];
        for (int treeIdx : tree_idxs_used) {
            Object[] result = RegtreeFwd.marginalFwd(forest.Trees[treeIdx], Theta, X);
            double[] preds = (double[]) result[0];
            double[] vars = (double[]) result[1];
            for (int j = 0; j < nTheta; ++j) {
                // NOTE(review): unlike apply(), the variance is not transformed to log space here.
                double pred = forest.logModel > 0 ? Math.log10(preds[j]) : preds[j];
                retn[j][0] += pred;
                retn[j][1] += vars[j] + pred * pred;
            }
        }
        // NOTE(review): the mean uses nTrees, but the n/(n-1) correction uses forest.numTrees
        // rather than the number of trees actually aggregated; kept as-is for compatibility.
        RandomForest.finishMeanAndVariance(retn, nTrees, forest.numTrees, forest.minVariance);
        return retn;
    }

    /**
     * Returns each selected tree's marginal mean prediction (log10-transformed when
     * {@link #logModel} {@code > 0}).
     *
     * @return {@code [Theta.length][tree_idxs_used.length]}
     */
    public static double[][] marginalTreePredictions(RandomForest forest, int[] tree_idxs_used, double[][] Theta, double[][] X) {
        int nTheta = Theta.length;
        int nTrees = tree_idxs_used.length;
        double[][] retn = new double[nTheta][nTrees];
        for (int i = 0; i < nTrees; ++i) {
            Object[] result = RegtreeFwd.marginalFwd(forest.Trees[tree_idxs_used[i]], Theta, X);
            double[] preds = (double[]) result[0];
            for (int j = 0; j < nTheta; ++j) {
                retn[j][i] = forest.logModel > 0 ? Math.log10(preds[j]) : preds[j];
            }
        }
        return retn;
    }

    /**
     * Returns a new forest whose trees are {@code RegtreeFwd.preprocess_inst_splits} applied
     * to the trees of {@code forest}.
     *
     * <p>{@link #logModel} and {@link #minVariance} are copied from {@code forest} itself. They
     * are public, and rebuilding them from {@code buildParams} alone would silently drop any
     * later change to them.
     */
    public static RandomForest preprocessForest(RandomForest forest, double[][] X) {
        RandomForest prepared = new RandomForest(forest.numTrees, forest.buildParams);
        prepared.logModel = forest.logModel;
        prepared.minVariance = forest.minVariance;
        for (int i = 0; i < forest.numTrees; ++i) {
            prepared.Trees[i] = RegtreeFwd.preprocess_inst_splits(forest.Trees[i], X);
        }
        return prepared;
    }

    /** Applies {@code RegtreeFwd.preprocess_for_classification} to every tree, in place. */
    public static void preprocessForestForClassification(RandomForest forest) {
        for (int i = 0; i < forest.numTrees; ++i) {
            RegtreeFwd.preprocess_for_classification(forest.Trees[i]);
        }
    }

    /** Returns {@code params.random}, or a new Random seeded with {@code params.seed} unless it is -1. */
    private static Random resolveRandom(RegtreeBuildParams params) {
        Random r = params.random;
        if (r == null) {
            r = new Random();
            if (params.seed != -1L) {
                r.setSeed(params.seed);
            }
        }
        return r;
    }

    /** Draws {@code numTrees} rows of {@code n} indices uniformly from {@code [0, n)}, tree by tree. */
    private static int[][] bootstrapIndices(int numTrees, int n, Random r) {
        int[][] idxs = new int[numTrees][n];
        for (int[] row : idxs) {
            for (int j = 0; j < n; ++j) {
                row[j] = r.nextInt(n);
            }
        }
        return idxs;
    }

    /** Builds {@code numTrees} rows each holding {@code 0..n-1}. */
    private static int[][] identityIndices(int numTrees, int n) {
        int[][] idxs = new int[numTrees][n];
        for (int[] row : idxs) {
            for (int j = 0; j < n; ++j) {
                row[j] = j;
            }
        }
        return idxs;
    }

    /**
     * Converts accumulated {@code {sum(pred), sum(var + pred^2)}} rows into {@code {mean, variance}}
     * in place. The steps are:
     * <ol>
     * <li>divide both sums by {@code numAveraged};</li>
     * <li>subtract {@code mean^2} from the second column;</li>
     * <li>scale it by {@code c/max(1, c-1)} with {@code c = numTreesForCorrection};</li>
     * <li>warn when the result is below {@link #MIN_VARIANCE_RESULT};</li>
     * <li>floor it at {@code minVariance}.</li>
     * </ol>
     */
    private static void finishMeanAndVariance(double[][] sums, int numAveraged, int numTreesForCorrection, double minVariance) {
        double correction = (double) numTreesForCorrection / (double) Math.max(1, numTreesForCorrection - 1);
        for (double[] row : sums) {
            row[0] = row[0] / (double) numAveraged;
            row[1] = row[1] / (double) numAveraged;
            row[1] = row[1] - row[0] * row[0];
            row[1] = row[1] * correction;
            if (row[1] < MIN_VARIANCE_RESULT) {
                System.err.println("[WARN]: Variance is less than " + MIN_VARIANCE_RESULT + " > " + row[1]);
                // Deliberately fails when assertions are enabled (-ea).
                assert (row[1] > MIN_VARIANCE_RESULT);
            }
            row[1] = Math.max(minVariance, row[1]);
        }
    }

    /** Replaces every entry of {@code values} with {@link #round(double)} of itself. */
    private static void roundInPlace(double[][] values) {
        for (double[] row : values) {
            for (int j = 0; j < row.length; ++j) {
                row[j] = RandomForest.round(row[j]);
            }
        }
    }
}

