package LBJ2.classify;

import LBJ2.learn.WekaWrapper;
import LBJ2.parse.Parser;
import java.io.PrintStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

/* loaded from: input_file:LBJ2/classify/TestDiscrete.class */
/**
 * Evaluates a classifier that produces discrete predictions against an oracle
 * classifier that produces the gold labels.
 *
 * <p>For every example, {@link #reportPrediction(String, String)} records the
 * gold label, the prediction, and (when the two match) a correct prediction in
 * three per-label histograms. Per-label precision, recall and F-measure, as well
 * as overall accuracy, are derived from those counts.
 *
 * <p>Labels registered with {@link #addNull(String)} are "null" labels: when any
 * are registered, the overall precision/recall/F figures are computed over the
 * non-null labels only, and {@link #printPerformance(PrintStream)} lists the
 * null labels in a separate section of its table.
 */
public class TestDiscrete {
    /** Classifier under test; set from the command line by {@link #main(String[])}. */
    private static Classifier classifier;
    /** Classifier supplying the gold labels; set from the command line. */
    private static Classifier oracle;
    /** Source of test examples; set from the command line. */
    private static Parser parser;
    /** Examples between progress lines ({@code -t <n>}); values {@code <= 0} disable them. */
    private static int outputGranularity;

    /** Gold label (String) to the number of examples carrying it (Integer). */
    protected HashMap goldHistogram = new HashMap();
    /** Predicted label (String) to the number of times it was predicted (Integer). */
    protected HashMap predictionHistogram = new HashMap();
    /** Label (String) to the number of times it was predicted correctly (Integer). */
    protected HashMap correctHistogram = new HashMap();
    /** Labels treated as "null" labels (see the class documentation). */
    protected HashSet nullLabels = new HashSet();

    /**
     * Command-line entry point.
     *
     * <pre>
     * java LBJ2.classify.TestDiscrete [-t &lt;n&gt;] &lt;classifier&gt; &lt;oracle&gt; &lt;parser&gt; &lt;input file&gt;
     *                                 [&lt;null label&gt; ...]
     * </pre>
     *
     * Loads the named classes reflectively (the parser is constructed with the
     * input file name), then runs a verbose test and prints the performance table.
     *
     * @param args command-line arguments as described above
     */
    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        TestDiscrete tester = instantiate(args);
        System.out.println("Classifier loaded in " + (System.currentTimeMillis() - start) / 1000.0
                + " seconds.");
        testDiscrete(tester, classifier, oracle, parser, true, outputGranularity);
    }

    /**
     * Silently tests {@code classifier2} against {@code classifier3} on every
     * example produced by {@code parser2}.
     *
     * @param classifier2 the classifier under test
     * @param classifier3 the oracle supplying gold labels
     * @param parser2     source of examples; consumed until it returns {@code null}
     * @return a new tester holding the accumulated statistics
     */
    public static TestDiscrete testDiscrete(Classifier classifier2, Classifier classifier3, Parser parser2) {
        return testDiscrete(new TestDiscrete(), classifier2, classifier3, parser2, false, 0);
    }

    /**
     * Tests a classifier against an oracle on every example produced by a parser,
     * accumulating the results into {@code tester}.
     *
     * @param tester            accumulator for the results; returned as-is
     * @param testedClassifier  the classifier under test
     * @param oracleClassifier  the classifier supplying gold labels
     * @param exampleParser     source of examples; consumed until it returns {@code null}
     * @param verbose           when true, print timing information and, at the end,
     *                          the performance table to {@code System.out}
     * @param granularity       when {@code verbose} and positive, print a progress line
     *                          every {@code granularity} examples
     * @return {@code tester}
     * @throws AssertionError (with assertions enabled) if the classifier returns a
     *                        {@code null} prediction
     */
    public static TestDiscrete testDiscrete(TestDiscrete tester, Classifier testedClassifier,
            Classifier oracleClassifier, Parser exampleParser, boolean verbose, int granularity) {
        boolean showProgress = verbose && granularity > 0;
        if (showProgress) {
            System.out.println("0 examples tested at " + new Date());
        }

        Object example = exampleParser.next();
        if (example == null) {
            // Nothing to evaluate: skip timing output and the (empty) table.
            return tester;
        }

        long totalTime = 0; // milliseconds spent inside the tested classifier only
        int processed = 0;
        while (example != null) {
            if (showProgress && processed > 0 && processed % granularity == 0) {
                System.out.println(processed + " examples tested at " + new Date());
            }

            long start = System.currentTimeMillis();
            String prediction = testedClassifier.discreteValue(example);
            totalTime += System.currentTimeMillis() - start;

            if (processed == 0 && verbose) {
                System.out.println("First example processed in " + totalTime / 1000.0 + " seconds.");
            }
            assert prediction != null : "Classifier returned null prediction for example " + example;

            tester.reportPrediction(prediction, oracleClassifier.discreteValue(example));
            processed++;
            example = exampleParser.next();
        }

        if (showProgress) {
            System.out.println(processed + " examples tested at " + new Date() + "\n");
        }
        if (verbose) {
            System.out.println("Average evaluation time: " + totalTime / (1000.0 * processed) + " seconds\n");
            tester.printPerformance(System.out);
        }
        return tester;
    }

    /**
     * Parses the command line, stores the loaded classifier, oracle, parser and
     * progress granularity in the static fields, and returns a tester with the
     * requested null labels registered. Exits with status 1 on any error.
     */
    private static TestDiscrete instantiate(String[] args) {
        TestDiscrete tester = new TestDiscrete();
        String classifierName = null;
        String oracleName = null;
        String parserName = null;
        String inputFile = null;

        try {
            int first = 0;
            if (args[0].charAt(0) == '-') {
                if (!args[0].equals("-t")) {
                    throw new IllegalArgumentException("Unknown option: " + args[0]);
                }
                outputGranularity = Integer.parseInt(args[1]);
                first = 2;
            }
            classifierName = args[first];
            oracleName = args[first + 1];
            parserName = args[first + 2];
            inputFile = args[first + 3];
            for (int k = first + 4; k < args.length; k++) {
                tester.addNull(args[k]);
            }
        } catch (Exception e) {
            // Missing arguments, an unknown option or a non-numeric -t value all
            // mean the command line is malformed.
            System.err.println("usage:\n  java LBJ2.classify.TestDiscrete [-t <n>] <classifier> <oracle> \\\n                                  <parser> <input file> \\\n                                  [<null label> [<null label> ...]]");
            System.exit(1);
        }

        classifier = newClassifierOrExit(classifierName);
        oracle = newClassifierOrExit(oracleName);
        parser = newParserOrExit(parserName, inputFile);
        return tester;
    }

    /** Loads the named class, or prints an error and exits with status 1. */
    private static Class<?> loadClassOrExit(String className) {
        try {
            return Class.forName(className);
        } catch (Exception e) {
            System.err.println("Can't get class for '" + className + "': " + e);
            System.exit(1);
            return null; // not reached
        }
    }

    /** Instantiates the named class as a Classifier via its no-arg constructor, or exits. */
    private static Classifier newClassifierOrExit(String className) {
        Class<?> cls = loadClassOrExit(className);
        try {
            // The cast stays inside the try so a non-Classifier class is reported, not thrown.
            return (Classifier) cls.newInstance();
        } catch (Exception e) {
            System.err.println("Can't instantiate '" + className + "': " + e);
            System.exit(1);
            return null; // not reached
        }
    }

    /** Instantiates the named Parser via its single-String constructor, or exits. */
    private static Parser newParserOrExit(String className, String inputFile) {
        Class<?> cls = loadClassOrExit(className);
        Constructor<?> constructor = null;
        try {
            constructor = cls.getConstructor(String.class);
        } catch (Exception e) {
            System.err.println("Can't get the constructor with a single String parameter for '"
                    + className + "': " + e);
            System.exit(1);
        }
        try {
            return (Parser) constructor.newInstance(inputFile);
        } catch (InvocationTargetException e) {
            // The constructor itself failed: show the underlying stack trace.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            System.err.println("Can't instantiate '" + className + "':");
            cause.printStackTrace();
            System.exit(1);
        } catch (Exception e) {
            System.err.println("Can't instantiate '" + className + "': " + e);
            System.exit(1);
        }
        return null; // not reached
    }

    /**
     * Records one example.
     *
     * @param prediction the classifier's prediction; must not be {@code null}
     * @param gold       the oracle's label
     */
    public void reportPrediction(String prediction, String gold) {
        histogramAdd(this.goldHistogram, gold, 1);
        histogramAdd(this.predictionHistogram, prediction, 1);
        if (prediction.equals(gold)) {
            histogramAdd(this.correctHistogram, prediction, 1);
        }
    }

    /** Adds every count accumulated by {@code other} into this tester. */
    public void reportAll(TestDiscrete other) {
        histogramAddAll(this.goldHistogram, other.goldHistogram);
        histogramAddAll(this.predictionHistogram, other.predictionHistogram);
        histogramAddAll(this.correctHistogram, other.correctHistogram);
    }

    /** @return every gold label seen so far, in no particular order */
    public String[] getLabels() {
        return (String[]) this.goldHistogram.keySet().toArray(new String[0]);
    }

    /** @return every predicted label seen so far, in no particular order */
    public String[] getPredictions() {
        return (String[]) this.predictionHistogram.keySet().toArray(new String[0]);
    }

    /** @return the union of gold and predicted labels, in no particular order */
    @SuppressWarnings("unchecked")
    public String[] getAllClasses() {
        HashSet classes = new HashSet(this.goldHistogram.keySet());
        classes.addAll(this.predictionHistogram.keySet());
        return (String[]) classes.toArray(new String[0]);
    }

    /** Registers {@code label} as a null label. */
    public void addNull(String label) {
        this.nullLabels.add(label);
    }

    /** Unregisters {@code label} as a null label. */
    public void removeNull(String label) {
        this.nullLabels.remove(label);
    }

    /** @return whether {@code label} is registered as a null label */
    public boolean isNull(String label) {
        return this.nullLabels.contains(label);
    }

    /** @return whether any null label is registered */
    public boolean hasNulls() {
        return !this.nullLabels.isEmpty();
    }

    /** Adds {@code amount} to the count stored under {@code key} (missing counts start at 0). */
    @SuppressWarnings("unchecked")
    protected void histogramAdd(HashMap hashMap, String key, int amount) {
        Integer count = (Integer) hashMap.get(key);
        int current = count == null ? 0 : count.intValue();
        hashMap.put(key, Integer.valueOf(current + amount));
    }

    /** @return the count stored under {@code key}, or 0 if absent */
    protected int histogramGet(HashMap hashMap, String key) {
        Integer count = (Integer) hashMap.get(key);
        return count == null ? 0 : count.intValue();
    }

    /** Adds every count in {@code source} into {@code target}. */
    protected void histogramAddAll(HashMap target, HashMap source) {
        for (Object o : source.entrySet()) {
            Map.Entry entry = (Map.Entry) o;
            histogramAdd(target, (String) entry.getKey(), ((Integer) entry.getValue()).intValue());
        }
    }

    /** @return how many examples carry gold label {@code label} */
    public int getLabeled(String label) {
        return histogramGet(this.goldHistogram, label);
    }

    /** @return how many times {@code label} was predicted */
    public int getPredicted(String label) {
        return histogramGet(this.predictionHistogram, label);
    }

    /** @return how many times {@code label} was predicted correctly */
    public int getCorrect(String label) {
        return histogramGet(this.correctHistogram, label);
    }

    /** @return correct / predicted for {@code label}, or 0 if it was never predicted */
    public double getPrecision(String label) {
        return ratio(getCorrect(label), getPredicted(label));
    }

    /** @return correct / labeled for {@code label}, or 0 if it never occurred as a gold label */
    public double getRecall(String label) {
        return ratio(getCorrect(label), getLabeled(label));
    }

    /** @return the F1 score of {@code label} (see {@link #getF(double, String)}) */
    public double getF1(String label) {
        return getF(1.0, label);
    }

    /**
     * @param beta  relative weight of recall versus precision
     * @param label the label to score
     * @return the F-beta score of {@code label}, or 0 when precision and recall are both 0
     */
    public double getF(double beta, String label) {
        return fMeasure(beta, getPrecision(label), getRecall(label));
    }

    /** @return {@link #getOverallStats(double)} with beta = 1 */
    public double[] getOverallStats() {
        return getOverallStats(1.0);
    }

    /**
     * Computes overall statistics.
     *
     * @param beta weight used for the F-measure
     * @return {@code {precision, recall, F-beta, accuracy}}. Accuracy is total
     *         correct over total predicted. When null labels are registered, the
     *         first three are computed over non-null labels only; otherwise all
     *         four entries equal the accuracy. Ratios with a zero denominator are 0.
     */
    public double[] getOverallStats(double beta) {
        String[] classes = getAllClasses();
        boolean withNulls = hasNulls();
        int totalCorrect = 0;
        int totalPredicted = 0;
        int keptCorrect = 0;
        int keptPredicted = 0;
        int keptLabeled = 0;
        for (int k = 0; k < classes.length; k++) {
            int correct = getCorrect(classes[k]);
            int predicted = getPredicted(classes[k]);
            int labeled = getLabeled(classes[k]);
            totalCorrect += correct;
            totalPredicted += predicted;
            if (withNulls && !isNull(classes[k])) {
                keptCorrect += correct;
                keptPredicted += predicted;
                keptLabeled += labeled;
            }
        }

        double accuracy = ratio(totalCorrect, totalPredicted);
        if (!withNulls) {
            return new double[] {accuracy, accuracy, accuracy, accuracy};
        }
        double precision = ratio(keptCorrect, keptPredicted);
        double recall = ratio(keptCorrect, keptLabeled);
        return new double[] {precision, recall, fMeasure(beta, precision, recall), accuracy};
    }

    /**
     * Prints a table of per-label precision, recall, F1, gold count (LCount) and
     * prediction count (PCount), followed by an "Overall" row over non-null labels
     * (only when null labels are registered) and an "Accuracy" row. Percentages
     * are rendered by {@link #format(double)}.
     *
     * @param out where the table is written
     */
    public void printPerformance(PrintStream out) {
        String[] classes = getAllClasses();
        // Non-null labels first, then null labels; alphabetical within each group.
        Arrays.sort(classes, new Comparator<String>() {
            @Override
            public int compare(String a, String b) {
                boolean aIsNull = nullLabels.contains(a);
                boolean bIsNull = nullLabels.contains(b);
                if (aIsNull != bIsNull) {
                    return aIsNull ? 1 : -1;
                }
                return a.compareTo(b);
            }
        });

        boolean withNulls = hasNulls();
        int totalCorrect = 0;
        int totalPredicted = 0;
        int keptCorrect = 0;
        int keptPredicted = 0;
        int keptLabeled = 0;
        // Count null labels actually present; registered-but-unseen ones get no row.
        int nullClassCount = 0;

        String[][] classRows = new String[classes.length][];
        for (int k = 0; k < classes.length; k++) {
            String label = classes[k];
            int correct = getCorrect(label);
            int predicted = getPredicted(label);
            int labeled = getLabeled(label);
            totalCorrect += correct;
            totalPredicted += predicted;
            if (nullLabels.contains(label)) {
                nullClassCount++;
            } else {
                keptCorrect += correct;
                keptPredicted += predicted;
                keptLabeled += labeled;
            }
            double precision = ratio(correct, predicted);
            double recall = ratio(correct, labeled);
            classRows[k] = new String[] {
                label, format(precision), format(recall), format(fMeasure(1.0, precision, recall)),
                blank() + labeled, blank() + predicted
            };
        }

        String[] header = {"Label", "Precision", "Recall", "F1", "LCount", "PCount"};
        String[] overallRow = null;
        if (withNulls) {
            double precision = ratio(keptCorrect, keptPredicted);
            double recall = ratio(keptCorrect, keptLabeled);
            overallRow = new String[] {
                "Overall", format(precision), format(recall), format(fMeasure(1.0, precision, recall)),
                blank() + keptLabeled, blank() + keptPredicted
            };
        }
        String[] accuracyRow = {
            "Accuracy", format(ratio(totalCorrect, totalPredicted)), blank(), blank(), blank(),
            blank() + totalPredicted
        };

        int[] widths = new int[header.length];
        widen(widths, header);
        if (overallRow != null) {
            widen(widths, overallRow);
        }
        widen(widths, accuracyRow);
        for (String[] row : classRows) {
            widen(widths, row);
        }

        StringBuilder headerLine = new StringBuilder(center(header[0], widths[0]));
        for (int c = 1; c < header.length; c++) {
            headerLine.append(' ').append(center(header[c], widths[c]));
        }
        StringBuilder rule = new StringBuilder(blank());
        for (int c = 0; c < headerLine.length(); c++) {
            rule.append('-');
        }

        out.println(headerLine);
        out.println(rule);
        int firstNullRow = classes.length - nullClassCount;
        for (int k = 0; k < firstNullRow; k++) {
            printRow(out, classRows[k], widths);
        }
        out.println(rule);
        if (withNulls) {
            for (int k = firstNullRow; k < classes.length; k++) {
                printRow(out, classRows[k], widths);
            }
            out.println(rule);
            printRow(out, overallRow, widths);
        }
        printRow(out, accuracyRow, widths);
    }

    /** Grows each column width to fit the corresponding cell of {@code row}. */
    private static void widen(int[] widths, String[] row) {
        for (int c = 0; c < row.length; c++) {
            widths[c] = Math.max(widths[c], row[c].length());
        }
    }

    /** Prints one table row: first cell left-justified, the rest right-justified, space-separated. */
    private static void printRow(PrintStream out, String[] row, int[] widths) {
        StringBuilder line = new StringBuilder(ljust(row[0], widths[0]));
        for (int c = 1; c < row.length; c++) {
            line.append(' ').append(rjust(row[c], widths[c]));
        }
        out.println(line);
    }

    /** @return numerator / denominator in floating point, or 0 when the denominator is 0 */
    private static double ratio(int numerator, int denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / denominator;
    }

    /** @return the F-beta combination of precision and recall, or 0 when its denominator is 0 */
    private static double fMeasure(double beta, double precision, double recall) {
        double betaSquared = beta * beta;
        double denominator = betaSquared * precision + recall;
        return denominator == 0.0 ? 0.0 : (betaSquared + 1.0) * precision * recall / denominator;
    }

    /**
     * Stand-in used wherever the report needs an empty prefix or initial string.
     * NOTE(review): presumably {@code WekaWrapper.defaultAttributeString} is
     * {@code ""} (this looks like a decompiler inlining an empty-string constant);
     * confirm and replace with a literal to drop the dependency on WekaWrapper.
     */
    private static String blank() {
        return WekaWrapper.defaultAttributeString;
    }

    /**
     * Renders a fraction as a percentage with three decimals: the value is
     * multiplied by 100000, rounded, zero-padded to at least four digits, and a
     * decimal point is inserted before the last three digits (0.5 becomes
     * "50.000", 0 becomes "0.000").
     */
    protected static String format(double d) {
        StringBuilder digits = new StringBuilder(blank() + Math.round(d * 100000.0));
        while (digits.length() < 4) {
            digits.insert(0, '0');
        }
        digits.insert(digits.length() - 3, '.');
        return digits.toString();
    }

    /** Pads {@code str} on the right with spaces to at least {@code width} characters. */
    protected static String ljust(String str, int width) {
        if (str.length() >= width) {
            return str;
        }
        StringBuilder padded = new StringBuilder(width).append(str);
        while (padded.length() < width) {
            padded.append(' ');
        }
        return padded.toString();
    }

    /** Pads {@code str} on the left with spaces to at least {@code width} characters. */
    protected static String rjust(String str, int width) {
        if (str.length() >= width) {
            return str;
        }
        StringBuilder padded = new StringBuilder(width);
        for (int k = str.length(); k < width; k++) {
            padded.append(' ');
        }
        return padded.append(str).toString();
    }

    /**
     * Centers {@code str} in {@code width} characters; when the padding is odd the
     * extra space goes on the right. Strings already that wide are returned as-is.
     */
    protected static String center(String str, int width) {
        int padding = width - str.length();
        if (padding <= 0) {
            return str;
        }
        int left = padding / 2;
        StringBuilder padded = new StringBuilder(width);
        for (int k = 0; k < left; k++) {
            padded.append(' ');
        }
        padded.append(str);
        for (int k = left; k < padding; k++) {
            padded.append(' ');
        }
        return padded.toString();
    }
}
