/*
 * Decompiled with CFR 0.152.
 */
package de.unifreiburg.cs.junit;

import ca.ubc.cs.beta.models.fastrf.RandomForest;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import ca.ubc.cs.beta.models.fastrf.utils.CsvToDataConverter;
import ca.ubc.cs.beta.models.fastrf.utils.RfData;
import java.io.IOException;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;

/**
 * JUnit tests for loading fastrf data from CSV files ({@link CsvToDataConverter},
 * {@link RfData}) and for {@link RandomForest} models learned from that data.
 *
 * <p>Fixtures are read from paths relative to the working directory ({@code test_files/...}).
 */
public class DataTester {
    /** Not referenced in this class; presumably set/read by external code — TODO confirm. */
    public static String rfDeployment = null;

    /**
     * A random non-negative offset, drawn once when the class is initialized.
     *
     * <p>The sign bit is masked off rather than using {@code Math.abs}, because
     * {@code Math.abs(Long.MIN_VALUE)} is still {@code Long.MIN_VALUE} (negative).
     */
    public static Long seedOffset = new Random().nextLong() & Long.MAX_VALUE;

    /** Absolute tolerance used when a forest is expected to reproduce its training responses. */
    private static final double EXACT_FIT_TOLERANCE = 1.0E-10;

    /** Column index of the response in the LDOF CSV files. */
    private static final int LDOF_Y_COL_IDX = 101;

    @Test
    public void testMakeUnique() throws IOException {
        // The assertions expect makeUnique(true) to shrink the theta rows and
        // makeUnique(false) to shrink the X rows of this fixture.
        String filename = "test_files/mini_with_repetitions.csv";
        RfData data = miniConverter(filename).readDataFromCsvFile(filename);
        assertDataSizes(data, 5, 5);
        data.makeUnique(true);
        assertDataSizes(data, 2, 5);
        data.makeUnique(false);
        assertDataSizes(data, 2, 3);
        // Repeating either call on already de-duplicated data must not change the sizes.
        data.makeUnique(true);
        assertDataSizes(data, 2, 3);
        data.makeUnique(false);
        assertDataSizes(data, 2, 3);
    }

    /**
     * A one-tree forest trained on {@code mini.csv} must reproduce every training
     * response to within {@link #EXACT_FIT_TOLERANCE}.
     */
    @Test
    public void testRfFromCSV() throws IOException {
        String filename = "test_files/mini.csv";
        RfData data = miniConverter(filename).readDataFromCsvFile(filename);
        RandomForest rf = RandomForest.learnModel(1, data.getTheta(), data.getX(),
                data.getTheta_inst_idxs(), data.getY(), new RegtreeBuildParams(2, false, 1));
        double[][] meanvar = RandomForest.apply(rf, data.buildMatrixForApply());
        assertReproducesResponses(data, meanvar);
    }

    /**
     * A one-tree forest trained on {@code mini_ldof.csv} (LDOF column layout) must
     * reproduce every training response to within {@link #EXACT_FIT_TOLERANCE}.
     */
    @Test
    public void testRfFromMiniLDOF() throws IOException {
        String trainFilename = "test_files/mini_ldof.csv";
        CsvToDataConverter converter = ldofConverter(trainFilename);
        RfData trainData = converter.readDataFromCsvFile(trainFilename);
        RegtreeBuildParams regTreeBuildParams =
                new RegtreeBuildParams(false, 1, converter.getCatDomainSizes());
        RandomForest rf = RandomForest.learnModel(1, trainData.getTheta(), trainData.getX(),
                trainData.getTheta_inst_idxs(), trainData.getY(), regTreeBuildParams);
        double[][] meanvar = RandomForest.apply(rf, trainData.buildMatrixForApply());
        assertReproducesResponses(trainData, meanvar);
    }

    /**
     * A 10-tree forest trained on the first 1000 LDOF training rows must achieve a lower
     * RMSE on the (de-duplicated) LDOF test file than a constant predictor that always
     * returns the mean training response.
     */
    @Test
    public void testBetterThanMeanPredictions() throws IOException {
        String trainFilename = "test_files/train_ldof_data_shuffled_first1000.csv";
        CsvToDataConverter converter = ldofConverter(trainFilename);
        RfData trainData = converter.readDataFromCsvFile(trainFilename);
        RegtreeBuildParams regTreeBuildParams =
                new RegtreeBuildParams(true, 10, converter.getCatDomainSizes());
        RandomForest rf = RandomForest.learnModel(10, trainData.getTheta(), trainData.getX(),
                trainData.getTheta_inst_idxs(), trainData.getY(), regTreeBuildParams);

        // Baseline: the mean of the training responses, used as a constant prediction.
        double meanPred = 0.0;
        for (int i = 0; i < trainData.getY().length; ++i) {
            meanPred += trainData.getY()[i];
        }
        meanPred /= trainData.getY().length;

        // The test file is read with the converter built for the training file, so both
        // share the same column layout.
        String testFilename = "test_files/test_ldof_data_shuffled_first10000.csv";
        RfData testData = converter.readDataFromCsvFile(testFilename);
        testData.makeUnique(true);
        testData.makeUnique(false);
        double[][] meanvar = RandomForest.apply(rf, testData.buildMatrixForApply());

        double sumSqErr = 0.0;
        double sumSqErrOfMeanPred = 0.0;
        for (int i = 0; i < testData.getY().length; ++i) {
            double err = testData.getY()[i] - meanvar[i][0];
            double errOfMeanPred = testData.getY()[i] - meanPred;
            sumSqErr += err * err;
            sumSqErrOfMeanPred += errOfMeanPred * errOfMeanPred;
        }
        int numTestRows = testData.getY().length;
        double rmse = Math.sqrt(sumSqErr / numTestRows);
        double rmseOfMeanPred = Math.sqrt(sumSqErrOfMeanPred / numTestRows);
        System.out.println("RMSE = " + rmse + "; rmse of mean pred: " + rmseOfMeanPred);
        Assert.assertTrue("RMSE = " + rmse + " is not below rmse of mean pred: " + rmseOfMeanPred,
                rmse < rmseOfMeanPred);
    }

    /** Builds a converter for the small test CSVs: theta in column 0, X in column 1, y in column 2. */
    private static CsvToDataConverter miniConverter(String filename) throws IOException {
        return new CsvToDataConverter(filename, new int[]{0}, new int[]{1}, 2, new int[]{});
    }

    /** Builds a converter for the LDOF CSVs: theta in columns 2-5, X in 6-99, y in column 101. */
    private static CsvToDataConverter ldofConverter(String filename) throws IOException {
        return new CsvToDataConverter(filename, new int[]{2, 3, 4, 5}, ldofXColIdxs(),
                LDOF_Y_COL_IDX, new int[]{});
    }

    /** Returns a fresh array of the LDOF X column indices: 6 through 99 inclusive. */
    private static int[] ldofXColIdxs() {
        int[] xColIdxs = new int[94];
        for (int i = 0; i < xColIdxs.length; ++i) {
            xColIdxs[i] = i + 6;
        }
        return xColIdxs;
    }

    /** Asserts the number of theta rows and X rows currently held by {@code data}. */
    private static void assertDataSizes(RfData data, int expectedThetaRows, int expectedXRows) {
        Assert.assertEquals("number of theta rows", expectedThetaRows, data.getTheta().length);
        Assert.assertEquals("number of X rows", expectedXRows, data.getX().length);
    }

    /** Asserts that column 0 of {@code meanvar} matches each response of {@code data}. */
    private static void assertReproducesResponses(RfData data, double[][] meanvar) {
        for (int i = 0; i < data.getY().length; ++i) {
            Assert.assertEquals("prediction for row " + i, data.getY()[i], meanvar[i][0],
                    EXACT_FIT_TOLERANCE);
        }
    }
}

