/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.aeatk.model.builder;

import ca.ubc.cs.beta.aeatk.misc.math.distribution.TruncatedNormalDistribution;
import ca.ubc.cs.beta.aeatk.misc.model.SMACRandomForestHelper;
import ca.ubc.cs.beta.aeatk.misc.watch.AutoStartStopWatch;
import ca.ubc.cs.beta.aeatk.misc.watch.StopWatch;
import ca.ubc.cs.beta.aeatk.model.builder.ModelBuilder;
import ca.ubc.cs.beta.aeatk.model.data.SanitizedModelData;
import ca.ubc.cs.beta.aeatk.options.RandomForestOptions;
import ca.ubc.cs.beta.models.fastrf.RandomForest;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ModelBuilder} that fits a {@link RandomForest} to run data containing censored
 * responses (as flagged by {@link SanitizedModelData#getCensoredResponses()}).
 *
 * <p>A forest is first learned from the non-censored responses only. Then, for up to
 * {@code imputationIterations} rounds, every tree position holding a censored response is
 * overwritten ("hallucinated") with a value drawn from a {@link TruncatedNormalDistribution}
 * built from the current forest's prediction for that point and the point's observed
 * (censored) response, and a new forest is learned from the imputed data. Iteration stops
 * early, from the second round on, once the average per-point increase of the imputed values
 * falls below {@link #CONVERGENCE_THRESHOLD}.
 */
public class AdaptiveCappingModelBuilder
implements ModelBuilder {
    /** Average per-point increase of imputed values below which imputation is considered converged. */
    private static final double CONVERGENCE_THRESHOLD = Math.pow(10.0, -10.0);
    private static final Logger log = LoggerFactory.getLogger(AdaptiveCappingModelBuilder.class);

    /** Forest from the last imputation round, or the non-censored-only forest if no round ran. */
    protected final RandomForest forest;
    /** {@link #forest} preprocessed against the PCA features when {@code rfOptions.preprocessMarginal} is set, else {@code null}. */
    protected final RandomForest preprocessedForest;

    /**
     * Builds the model using all data points in the imputed trees (no subsampling).
     *
     * @see #AdaptiveCappingModelBuilder(SanitizedModelData, RandomForestOptions, Random, int, double, double, double)
     */
    public AdaptiveCappingModelBuilder(SanitizedModelData mds, RandomForestOptions rfOptions, Random rand, int imputationIterations, double cutoffTime, double penaltyFactor) {
        this(mds, rfOptions, rand, imputationIterations, cutoffTime, penaltyFactor, 1.0);
    }

    /**
     * Builds the model, imputing censored responses for up to {@code imputationIterations} rounds.
     *
     * @param mds                  run data; its (config, instance) index pairs are decremented in place
     * @param rfOptions            forest options (number of trees, bootstrap/imputation flags, ...)
     * @param rand                 randomness for tree sampling, build parameters and imputation
     * @param imputationIterations maximum number of impute-and-refit rounds
     * @param cutoffTime           untransformed cutoff time; transformed via {@code mds.transformResponseValue}
     * @param penaltyFactor        multiplier on {@code cutoffTime} giving the cap for every imputed value
     * @param subsamplePercentage  when {@code < 1.0}, the fraction of data points (compared against 1.0,
     *                             so a fraction despite the name) used per tree
     */
    public AdaptiveCappingModelBuilder(SanitizedModelData mds, RandomForestOptions rfOptions, Random rand, int imputationIterations, double cutoffTime, double penaltyFactor, double subsamplePercentage) {
        // Upper bound for every imputed value: the transformed penalized cutoff.
        double maxPenalizedValue = mds.transformResponseValue(cutoffTime * penaltyFactor);
        // Imputed values at or above the transformed (unpenalized) cutoff are treated as timeouts when
        // penalizeImputedValues is set. This must not include penaltyFactor: otherwise it equals
        // maxPenalizedValue and the penalization is a no-op, since every sample is clamped to it anyway.
        double transformedCutoffTime = mds.transformResponseValue(cutoffTime);

        int[][] theta_inst_idxs = mds.getThetaInstIdxs();
        boolean[] censoringIndicators = mds.getCensoredResponses();
        double[] responseValues = mds.getResponseValues();

        // Decrement every (config, instance) index pair in place; presumably converts 1-based indices to
        // the 0-based ones used below to index getConfigs()/getPCAFeatures().
        // NOTE(review): this mutates the array returned by mds.getThetaInstIdxs(); if that is the model
        // data's own array, building twice from the same mds would double-decrement -- confirm.
        for (int[] thetaInst : theta_inst_idxs) {
            thetaInst[0] -= 1;
            thetaInst[1] -= 1;
        }

        // Initial forest: trained on the non-censored runs only.
        List<int[]> nonCensoredThetaInst = new ArrayList<int[]>(responseValues.length);
        List<Double> nonCensoredResponses = new ArrayList<Double>(responseValues.length);
        int censoredCount = 0;
        for (int i = 0; i < responseValues.length; ++i) {
            if (censoringIndicators[i]) {
                ++censoredCount;
            } else {
                nonCensoredThetaInst.add(theta_inst_idxs[i]);
                nonCensoredResponses.add(responseValues[i]);
            }
        }
        int[][] non_cens_theta_inst_idxs = nonCensoredThetaInst.toArray(new int[0][]);
        double[] non_cens_responses = this.convertToPrimitiveArray(nonCensoredResponses.toArray(new Double[0]));
        log.debug("Building Random Forest with {} censored runs out of {} total ", censoredCount, censoringIndicators.length);
        RandomForest rf = buildRandomForest(mds, rfOptions, non_cens_theta_inst_idxs, non_cens_responses, false, subsamplePercentage, rand);

        // Only the first numDataPointsInTree runs take part in the imputed forests.
        int numTrees = rfOptions.numTrees;
        int numDataPointsInTree = responseValues.length;
        if (subsamplePercentage < 1.0) {
            numDataPointsInTree = (int) (numDataPointsInTree * subsamplePercentage);
            log.debug("Subsampling number in points in imputed trees to {} out of {} ({} %)", new Object[]{numDataPointsInTree, responseValues.length, subsamplePercentage});
        }

        int[][] dataIdxs = createTreeDataIdxs(numTrees, numDataPointsInTree, rfOptions.fullTreeBootstrap, rand);
        LinkedHashMap<Integer, Map<Integer, List<Integer>>> censoredSampleIdxs = indexCensoredSamples(censoringIndicators, numDataPointsInTree, dataIdxs);

        // yHallucinated[tree][k] is the response that tree `tree` is trained on at position k. Censored
        // entries start at their observed value and are overwritten by each imputation round.
        double[][] yHallucinated = new double[numTrees][numDataPointsInTree];
        for (int tree = 0; tree < numTrees; ++tree) {
            for (int k = 0; k < numDataPointsInTree; ++k) {
                yHallucinated[tree][k] = responseValues[dataIdxs[tree][k]];
            }
        }

        if (imputationIterations > 0 && !censoredSampleIdxs.isEmpty()) {
            // The censored points (and their predictors) never change between rounds, so build them once.
            // Row r corresponds to the r-th key of censoredSampleIdxs (LinkedHashMap: stable order).
            double[][] predictors = buildPredictors(mds, theta_inst_idxs, censoredSampleIdxs);
            for (int i = 0; i < imputationIterations; ++i) {
                double[][] prediction = RandomForest.apply(rf, predictors);
                double differenceFromLastMean = imputeCensoredResponses(censoredSampleIdxs, prediction, responseValues, yHallucinated, rfOptions, transformedCutoffTime, maxPenalizedValue, rand);
                log.debug("Building random forest with imputed values iteration {}", i);
                rf = buildImputedRandomForest(mds, rfOptions, theta_inst_idxs, dataIdxs, yHallucinated, false, rand);
                if (differenceFromLastMean < CONVERGENCE_THRESHOLD && i >= 1) {
                    log.trace("Means of imputed values stopped increasing in imputation iteration {} (increase {})", i, differenceFromLastMean);
                    break;
                }
                log.trace("Mean increase in imputed values in imputation iteration {} is {}", i, differenceFromLastMean);
            }
        }

        this.forest = rf;
        if (rfOptions.preprocessMarginal) {
            log.trace("Preprocessing marginal for Random Forest");
            this.preprocessedForest = RandomForest.preprocessForest(this.forest, mds.getPCAFeatures());
        } else {
            this.preprocessedForest = null;
        }
    }

    @Override
    public RandomForest getRandomForest() {
        return this.forest;
    }

    /** Returns the preprocessed forest, or {@code null} if {@code preprocessMarginal} was not set. */
    @Override
    public RandomForest getPreparedRandomForest() {
        return this.preprocessedForest;
    }

    /** Unboxes {@code arr} into a new {@code double[]} (throws NPE on a {@code null} element). */
    private double[] convertToPrimitiveArray(Double[] arr) {
        double[] d = new double[arr.length];
        for (int i = 0; i < d.length; ++i) {
            d[i] = arr[i];
        }
        return d;
    }

    /**
     * Returns, per tree, the indices of the data points it is trained on: the identity order when
     * {@code fullTreeBootstrap} is set, otherwise {@code numDataPoints} draws with replacement from
     * {@code [0, numDataPoints)} (tree-major order, so the random stream is consumed deterministically).
     */
    private static int[][] createTreeDataIdxs(int numTrees, int numDataPoints, boolean fullTreeBootstrap, Random rand) {
        int[][] dataIdxs = new int[numTrees][numDataPoints];
        for (int tree = 0; tree < numTrees; ++tree) {
            for (int k = 0; k < numDataPoints; ++k) {
                dataIdxs[tree][k] = fullTreeBootstrap ? k : rand.nextInt(numDataPoints);
            }
        }
        return dataIdxs;
    }

    /**
     * Indexes where each censored data point (among the first {@code numDataPoints}) occurs in the
     * trees' training samples.
     *
     * @return map, ordered by data-point index, from each censored point to (tree index -> positions
     *         in that tree's sample); points that no tree sampled map to an empty inner map
     */
    private static LinkedHashMap<Integer, Map<Integer, List<Integer>>> indexCensoredSamples(boolean[] censoringIndicators, int numDataPoints, int[][] dataIdxs) {
        LinkedHashMap<Integer, Map<Integer, List<Integer>>> censoredSampleIdxs = new LinkedHashMap<Integer, Map<Integer, List<Integer>>>();
        for (int i = 0; i < numDataPoints; ++i) {
            if (censoringIndicators[i]) {
                // The inner map's iteration order decides which tree receives which imputed sample.
                censoredSampleIdxs.put(i, new HashMap<Integer, List<Integer>>());
            }
        }
        for (int tree = 0; tree < dataIdxs.length; ++tree) {
            for (int k = 0; k < dataIdxs[tree].length; ++k) {
                Map<Integer, List<Integer>> locationsByTree = censoredSampleIdxs.get(dataIdxs[tree][k]);
                if (locationsByTree == null) {
                    continue; // not a censored point
                }
                List<Integer> locations = locationsByTree.get(tree);
                if (locations == null) {
                    locations = new ArrayList<Integer>();
                    locationsByTree.put(tree, locations);
                }
                locations.add(k);
            }
        }
        return censoredSampleIdxs;
    }

    /** Builds one predictor row (configuration values followed by PCA features) per key of {@code samples}, in key order. */
    private static double[][] buildPredictors(SanitizedModelData mds, int[][] theta_inst_idxs, Map<Integer, ?> samples) {
        double[][] configs = mds.getConfigs();
        double[][] features = mds.getPCAFeatures();
        double[][] predictors = new double[samples.size()][configs[0].length + features[0].length];
        int row = 0;
        for (int sampleIdx : samples.keySet()) {
            double[] configArray = configs[theta_inst_idxs[sampleIdx][0]];
            double[] featureArray = features[theta_inst_idxs[sampleIdx][1]];
            System.arraycopy(configArray, 0, predictors[row], 0, configArray.length);
            System.arraycopy(featureArray, 0, predictors[row], configArray.length, featureArray.length);
            ++row;
        }
        return predictors;
    }

    /**
     * Runs one imputation round: overwrites, in {@code yHallucinated}, every tree position of every
     * censored point with a freshly imputed value capped at {@code maxPenalizedValue}.
     *
     * @param prediction forest output with one row per key of {@code censoredSampleIdxs}, in key order
     * @return the average, over censored points, of the mean change of that point's imputed values
     *         (points with no tree positions contribute 0)
     */
    private static double imputeCensoredResponses(Map<Integer, Map<Integer, List<Integer>>> censoredSampleIdxs, double[][] prediction, double[] responseValues, double[][] yHallucinated, RandomForestOptions rfOptions, double transformedCutoffTime, double maxPenalizedValue, Random rand) {
        double differenceFromLastMean = 0.0;
        int predictionRow = 0;
        for (Map.Entry<Integer, Map<Integer, List<Integer>>> ent : censoredSampleIdxs.entrySet()) {
            int sampleIdxToUse = ent.getKey();
            Map<Integer, List<Integer>> treeDataIdxsMap = ent.getValue();
            // Column 0 is used as the predicted mean; column 1 is passed on as the distribution's second
            // parameter (NOTE(review): presumably the predictive variance -- confirm).
            double[] samplePrediction = prediction[predictionRow++];

            int numSamplesToGet = 0;
            for (List<Integer> locations : treeDataIdxsMap.values()) {
                numSamplesToGet += locations.size();
            }

            double[] samples;
            if (rfOptions.imputeMean) {
                // This point's own predicted mean (the row read before the index advanced).
                samples = new double[numSamplesToGet];
                for (int k = 0; k < samples.length; ++k) {
                    samples[k] = samplePrediction[0];
                }
            } else {
                TruncatedNormalDistribution tNorm = new TruncatedNormalDistribution(samplePrediction[0], samplePrediction[1], responseValues[sampleIdxToUse], rand);
                samples = rfOptions.shuffleImputedValues ? tNorm.getValuesAtStratifiedShuffledIntervals(numSamplesToGet) : tNorm.getValuesAtStratifiedIntervals(numSamplesToGet);
            }

            for (int k = 0; k < samples.length; ++k) {
                if (rfOptions.penalizeImputedValues && samples[k] >= transformedCutoffTime) {
                    samples[k] = maxPenalizedValue;
                }
                samples[k] = Math.min(samples[k], maxPenalizedValue);
            }

            // Hand the samples out to the tree positions in map iteration order, tracking the total change.
            int count = 0;
            double increaseThisDataPoint = 0.0;
            for (Map.Entry<Integer, List<Integer>> treeEntry : treeDataIdxsMap.entrySet()) {
                double[] treeResponses = yHallucinated[treeEntry.getKey()];
                for (int responseLocationInTree : treeEntry.getValue()) {
                    increaseThisDataPoint += samples[count] - treeResponses[responseLocationInTree];
                    treeResponses[responseLocationInTree] = samples[count++];
                }
            }
            if (count > 0) {
                differenceFromLastMean += increaseThisDataPoint / count;
            }
        }
        return differenceFromLastMean / censoredSampleIdxs.size();
    }

    /** Creates the tree build parameters used by both forest builders. */
    private static RegtreeBuildParams createBuildParams(SanitizedModelData mds, RandomForestOptions rfOptions, double[][] features, Random rand) {
        return SMACRandomForestHelper.getRandomForestBuildParams(rfOptions, features[0].length, mds.getCategoricalSize(), mds.getNameConditionsMapParentsArray(), mds.getNameConditionsMapParentsValues(), mds.getNameConditionsMapOp(), rand);
    }

    /**
     * Learns a forest on the given (config, instance) pairs and responses.
     *
     * <p>With {@code fullTreeBootstrap} every tree sees all points in order; otherwise, when
     * {@code subsamplePercentage < 1.0}, each tree gets explicit subsampled indices; otherwise
     * the forest's own default sampling is used.
     *
     * @param preprocessed unused
     */
    private static RandomForest buildRandomForest(SanitizedModelData mds, RandomForestOptions rfOptions, int[][] theta_inst_idxs, double[] responseValues, boolean preprocessed, double subsamplePercentage, Random rand) {
        double[][] features = mds.getPCAFeatures();
        double[][] configs = mds.getConfigs();
        int numTrees = rfOptions.numTrees;
        RegtreeBuildParams buildParams = createBuildParams(mds, rfOptions, features, rand);
        log.trace("Building Random Forest with {} data points ", responseValues.length);
        StopWatch sw = new StopWatch();
        RandomForest forest;
        if (rfOptions.fullTreeBootstrap) {
            int N = responseValues.length;
            int[][] dataIdxs = new int[numTrees][N];
            for (int i = 0; i < numTrees; ++i) {
                for (int j = 0; j < N; ++j) {
                    dataIdxs[i][j] = j;
                }
            }
            sw.start();
            forest = RandomForest.learnModel(numTrees, configs, features, theta_inst_idxs, responseValues, dataIdxs, buildParams);
        } else if (subsamplePercentage < 1.0) {
            int N = (int) (subsamplePercentage * responseValues.length);
            log.debug("Subsampling {} points out of {} total for random forest construction", N, responseValues.length);
            // NOTE(review): indices are drawn from [0, N), so only the first N points are ever used
            // (mirroring the imputed trees); confirm whether drawing from all points was intended.
            int[][] dataIdxs = new int[numTrees][N];
            for (int i = 0; i < numTrees; ++i) {
                for (int j = 0; j < N; ++j) {
                    dataIdxs[i][j] = buildParams.random.nextInt(N);
                }
            }
            sw.start();
            forest = RandomForest.learnModel(numTrees, configs, features, theta_inst_idxs, responseValues, dataIdxs, buildParams);
        } else {
            sw.start();
            forest = RandomForest.learnModel(numTrees, configs, features, theta_inst_idxs, responseValues, buildParams);
        }
        log.debug("Building Random Forest took {} seconds ", (double) sw.stop() / 1000.0);
        return forest;
    }

    /**
     * Learns a forest where tree {@code t} is trained on the points {@code dataIdxs[t]} with the
     * per-tree responses {@code responseValues[t]} (which include imputed values).
     *
     * @param preprocessed unused
     */
    private static RandomForest buildImputedRandomForest(SanitizedModelData mds, RandomForestOptions rfOptions, int[][] theta_inst_idxs, int[][] dataIdxs, double[][] responseValues, boolean preprocessed, Random rand) {
        double[][] features = mds.getPCAFeatures();
        double[][] configs = mds.getConfigs();
        int numTrees = rfOptions.numTrees;
        RegtreeBuildParams buildParams = createBuildParams(mds, rfOptions, features, rand);
        // Guard: with zero trees there is no responseValues[0] to measure.
        log.trace("Building Random Forest with {} data points ", responseValues.length == 0 ? 0 : responseValues[0].length);
        AutoStartStopWatch sw = new AutoStartStopWatch();
        RandomForest forest = RandomForest.learnModelImputedValues(numTrees, configs, features, theta_inst_idxs, responseValues, dataIdxs, buildParams);
        log.trace("Building Random Forest took {} seconds ", (double) sw.stop() / 1000.0);
        return forest;
    }
}

