package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

/**
 * Multi-class precision/recall statistics that additionally track, per label, how often each
 * label was guessed, how often it was the gold label, and how often a guess was correct,
 * together with token-level accuracy counts.
 *
 * <p>(guess, gold) label pairs can be supplied from a {@link Classifier} over a
 * {@link GeneralDataset}, from two parallel lists, or from a delimited text file. Labels equal
 * to the negative label are not counted as guessed or relevant "phrases". After scoring,
 * {@link #getConllEvalString()} renders a CoNLL-eval-style text report.
 *
 * @param <L> the label type
 */
public class MultiClassPrecisionRecallExtendedStats<L> extends MultiClassPrecisionRecallStats<L> {

    // NOTE: these fields intentionally have no initializers and the constructors do not reset
    // them. The classifier constructor hands the data to the superclass constructor; if that
    // constructor scores it through the overridable score(...) (NOTE(review): confirm), any
    // initializer or reset running afterwards would wipe the freshly computed counts. Java
    // already zero/null-initializes the fields before any constructor code runs.

    /** Per-label count of guesses that equal the gold label (negative label included). */
    protected IntCounter<L> correctGuesses;
    /** Per-label count of gold labels other than the negative label. */
    protected IntCounter<L> foundCorrect;
    /** Per-label count of guessed labels other than the negative label. */
    protected IntCounter<L> foundGuessed;
    /** Number of (guess, gold) pairs scored with a non-null gold label. */
    protected int tokensCount;
    /** Number of scored pairs whose guess equals the gold label. */
    protected int tokensCorrect;
    /** Number of pairs skipped because their gold label was {@code null}. */
    protected int noLabel;
    /** Converts string columns read from a file into labels; the base class never sets it. */
    protected Function<String, L> stringConverter;

    /** Specialization for {@code String} labels; file columns are used as labels verbatim. */
    public static class MultiClassStringLabelStats extends MultiClassPrecisionRecallExtendedStats<String> {

        /** Scores {@code data} with {@code classifier}, using {@code negLabel} as the negative label. */
        public <F> MultiClassStringLabelStats(Classifier<String, F> classifier, GeneralDataset<String, F> data, String negLabel) {
            super(classifier, data, negLabel);
            this.stringConverter = new StringStringConverter();
        }

        /** Creates empty stats with the given negative label. */
        public MultiClassStringLabelStats(String negLabel) {
            super(negLabel);
            this.stringConverter = new StringStringConverter();
        }

        /** Creates empty stats with the given negative label and a pre-built label index. */
        public MultiClassStringLabelStats(Index<String> labelIndex, String negLabel) {
            this(negLabel);
            setLabelIndex(labelIndex);
        }
    }

    /** Identity {@link Function} on strings. */
    public static class StringStringConverter implements Function<String, String> {
        @Override // edu.stanford.nlp.util.Function
        public String apply(String str) {
            return str;
        }
    }

    /**
     * Creates stats by passing the classifier, dataset and negative label to the superclass
     * constructor.
     */
    public <F> MultiClassPrecisionRecallExtendedStats(Classifier<L, F> classifier, GeneralDataset<L, F> data, L negLabel) {
        super(classifier, data, negLabel);
    }

    /** Creates empty stats; {@code negLabel} is passed to the superclass constructor. */
    public MultiClassPrecisionRecallExtendedStats(L negLabel) {
        super(negLabel);
    }

    /** Creates empty stats with the given negative label and a pre-built label index. */
    public MultiClassPrecisionRecallExtendedStats(Index<L> labelIndex, L negLabel) {
        this(negLabel);
        setLabelIndex(labelIndex);
    }

    /**
     * Installs {@code index} as the label index and recomputes the position of the negative
     * label in it.
     *
     * @param index the label index to use; must not be {@code null}
     */
    public void setLabelIndex(Index<L> index) {
        this.labelIndex = index;
        this.negIndex = this.labelIndex.indexOf(this.negLabel);
    }

    /**
     * Classifies every datum in {@code data} with {@code classifier} and scores the guesses
     * against the dataset's gold labels. The label index is rebuilt from the classifier's
     * labels followed by the dataset's labels, and all counts are cleared first.
     *
     * @return the F-measure reported by {@code getFMeasure()} after scoring
     */
    @Override // edu.stanford.nlp.stats.MultiClassPrecisionRecallStats
    public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
        this.labelIndex = new HashIndex<>();
        this.labelIndex.addAll(classifier.labels());
        this.labelIndex.addAll(data.labelIndex.objectsList());
        clearCounts();
        int[] goldLabelIds = data.getLabelsArray();
        for (int i = 0; i < data.size(); i++) {
            // The ids from getLabelsArray() refer to the dataset's own label index, not the
            // merged index built above (which starts with the classifier's labels).
            L gold = data.labelIndex.get(goldLabelIds[i]);
            addGuess(classifier.classOf(data.getRVFDatum(i)), gold);
        }
        finalizeCounts();
        return getFMeasure();
    }

    /**
     * Installs {@code index} as the label index, then scores {@code guesses} against
     * {@code gold} as {@link #score(List, List)} does.
     */
    public double score(List<L> guesses, List<L> gold, Index<L> index) {
        setLabelIndex(index);
        return score(guesses, gold);
    }

    /**
     * Clears all counts and scores {@code guesses} position by position against {@code gold}.
     *
     * @param guesses guessed labels
     * @param gold gold labels; must be at least as long as {@code guesses}
     * @return the F-measure reported by {@code getFMeasure()} after scoring
     */
    public double score(List<L> guesses, List<L> gold) {
        clearCounts();
        addGuesses(guesses, gold);
        finalizeCounts();
        return getFMeasure();
    }

    /**
     * Finalizes the counts accumulated so far (without clearing them).
     *
     * @return the F-measure reported by {@code getFMeasure()}
     */
    public double score() {
        finalizeCounts();
        return getFMeasure();
    }

    /** Resets all per-label counters, the tp/fn/fp arrays, and the token counts. */
    public void clearCounts() {
        if (this.foundCorrect != null) {
            this.foundCorrect.clear();
        } else {
            this.foundCorrect = new IntCounter<>();
        }
        if (this.foundGuessed != null) {
            this.foundGuessed.clear();
        } else {
            this.foundGuessed = new IntCounter<>();
        }
        if (this.correctGuesses != null) {
            this.correctGuesses.clear();
        } else {
            this.correctGuesses = new IntCounter<>();
        }
        if (this.tpCount != null) {
            Arrays.fill(this.tpCount, 0);
        }
        if (this.fnCount != null) {
            Arrays.fill(this.fnCount, 0);
        }
        if (this.fpCount != null) {
            Arrays.fill(this.fpCount, 0);
        }
        this.tokensCount = 0;
        this.tokensCorrect = 0;
        // Previously left untouched, so skipped pairs accumulated across scoring runs.
        this.noLabel = 0;
    }

    /**
     * Derives the superclass's per-label tp/fn/fp arrays from the accumulated counters,
     * (re)allocating the arrays when their size no longer matches the label index.
     */
    public void finalizeCounts() {
        if (this.labelIndex == null) {
            // Nothing was ever indexed (e.g. empty input); use an empty index instead of an NPE.
            this.labelIndex = new HashIndex<>();
        }
        this.negIndex = this.labelIndex.indexOf(this.negLabel);
        int size = this.labelIndex.size();
        if (this.tpCount == null || this.tpCount.length != size) {
            this.tpCount = new int[size];
        }
        if (this.fpCount == null || this.fpCount.length != size) {
            this.fpCount = new int[size];
        }
        if (this.fnCount == null || this.fnCount.length != size) {
            this.fnCount = new int[size];
        }
        for (int i = 0; i < size; i++) {
            L label = this.labelIndex.get(i);
            this.tpCount[i] = this.correctGuesses.getIntCount(label);
            this.fnCount[i] = this.foundCorrect.getIntCount(label) - this.tpCount[i];
            this.fpCount[i] = this.foundGuessed.getIntCount(label) - this.tpCount[i];
        }
    }

    /** Hook invoked at blank or boundary lines while scoring a file; a no-op here. */
    protected void markBoundary() {
    }

    /** Records one (guess, gold) pair, adding both labels to the label index. */
    protected void addGuess(L guess, L gold) {
        addGuess(guess, gold, true);
    }

    /**
     * Records one (guess, gold) pair.
     *
     * @param guess the guessed label; must not be {@code null}
     * @param gold the gold label; a {@code null} gold label is only counted in {@link #noLabel}
     * @param updateIndex when true, both labels are added to the label index (created if absent)
     */
    protected void addGuess(L guess, L gold, boolean updateIndex) {
        if (gold == null) {
            this.noLabel++;
            return;
        }
        if (updateIndex) {
            if (this.labelIndex == null) {
                this.labelIndex = new HashIndex<>();
            }
            this.labelIndex.add(guess);
            this.labelIndex.add(gold);
        }
        if (guess.equals(gold)) {
            this.correctGuesses.incrementCount(gold);
            this.tokensCorrect++;
        }
        // The negative label does not count as a retrieved or relevant phrase.
        if (!guess.equals(this.negLabel)) {
            this.foundGuessed.incrementCount(guess);
        }
        if (!gold.equals(this.negLabel)) {
            this.foundCorrect.incrementCount(gold);
        }
        this.tokensCount++;
    }

    /**
     * Records {@code guesses.get(i)} against {@code gold.get(i)} for every index of
     * {@code guesses}; {@code gold} must be at least as long.
     */
    public void addGuesses(List<L> guesses, List<L> gold) {
        for (int i = 0; i < guesses.size(); i++) {
            addGuess(guesses.get(i), gold.get(i));
        }
    }

    /** Returns the total number of correct guesses over all labels. */
    public int getCorrect() {
        return this.correctGuesses.totalIntCount();
    }

    /** Returns the number of correct guesses for {@code label}. */
    public int getCorrect(L label) {
        return this.correctGuesses.getIntCount(label);
    }

    /** Returns how often {@code label} was guessed (always 0 for the negative label). */
    public int getRetrieved(L label) {
        return this.foundGuessed.getIntCount(label);
    }

    /** Returns the total number of non-negative guesses. */
    public int getRetrieved() {
        return this.foundGuessed.totalIntCount();
    }

    /** Returns how often {@code label} was the gold label (always 0 for the negative label). */
    public int getRelevant(L label) {
        return this.foundCorrect.getIntCount(label);
    }

    /** Returns the total number of non-negative gold labels. */
    public int getRelevant() {
        return this.foundCorrect.totalIntCount();
    }

    /**
     * Returns (accuracy, correct token count, incorrect token count). Accuracy is
     * {@code NaN} when no tokens have been scored.
     */
    public Triple<Double, Integer, Integer> getAccuracyInfo() {
        int errors = this.tokensCount - this.tokensCorrect;
        // Floating-point division: integer division truncated every accuracy below 100% to 0
        // and threw ArithmeticException when no tokens had been scored.
        double accuracy = ((double) this.tokensCorrect) / this.tokensCount;
        return new Triple<>(Double.valueOf(accuracy), Integer.valueOf(this.tokensCorrect), Integer.valueOf(errors));
    }

    /** Returns the token-level accuracy (fraction of scored pairs with guess equal to gold). */
    public double getAccuracy() {
        return getAccuracyInfo().first().doubleValue();
    }

    /**
     * Formats the accuracy as {@code "acc  (correct/total)"} using the default-locale number
     * format, with at most {@code numDigits} fraction digits.
     */
    public String getAccuracyDescription(int numDigits) {
        NumberFormat numberInstance = NumberFormat.getNumberInstance();
        numberInstance.setMaximumFractionDigits(numDigits);
        Triple<Double, Integer, Integer> accuracyInfo = getAccuracyInfo();
        return numberInstance.format(accuracyInfo.first()) + "  (" + accuracyInfo.second() + "/" + (accuracyInfo.second().intValue() + accuracyInfo.third().intValue()) + ")";
    }

    /** Scores the file at {@code filename}; see {@link #score(BufferedReader, String, String)}. */
    public double score(String filename, String delimiterRegex) throws IOException {
        return score(filename, delimiterRegex, (String) null);
    }

    /**
     * Scores the file at {@code filename}; see {@link #score(BufferedReader, String, String)}.
     * The file is always closed before returning.
     */
    public double score(String filename, String delimiterRegex, String boundaryMarker) throws IOException {
        // The original never closed this reader, leaking a file handle per call.
        try (BufferedReader reader = IOUtils.getBufferedFileReader(filename)) {
            return score(reader, delimiterRegex, boundaryMarker);
        }
    }

    /** Scores lines from {@code reader} with no boundary marker. */
    public double score(BufferedReader reader, String delimiterRegex) throws IOException {
        return score(reader, delimiterRegex, (String) null);
    }

    /**
     * Clears all counts and scores every line read from {@code reader}. Each non-blank line is
     * split on {@code delimiterRegex}: column 1 is the gold label and column 2 the guess (both
     * passed through {@link #stringConverter}). Blank lines, and lines whose first column
     * equals {@code boundaryMarker}, call {@link #markBoundary()} instead. The reader is not
     * closed.
     *
     * @return the F-measure reported by {@code getFMeasure()} after scoring
     * @throws IllegalArgumentException if a scored line has fewer than three columns
     */
    public double score(BufferedReader reader, String delimiterRegex, String boundaryMarker) throws IOException {
        Pattern delimiter = Pattern.compile(delimiterRegex);
        clearCounts();
        String line;
        while ((line = reader.readLine()) != null) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                markBoundary();
                continue;
            }
            String[] fields = delimiter.split(trimmed);
            if (boundaryMarker != null && boundaryMarker.equals(fields[0])) {
                markBoundary();
                continue;
            }
            if (fields.length < 3) {
                throw new IllegalArgumentException("Expected at least 3 columns (token, gold, guess) but got " + fields.length + " in line: " + trimmed);
            }
            addGuess(this.stringConverter.apply(fields[2]), this.stringConverter.apply(fields[1]));
        }
        finalizeCounts();
        return getFMeasure();
    }

    /**
     * Returns the labels in label-index order, or an empty list if no index exists yet.
     * NOTE(review): this may be the index's live backing list; copy before mutating it.
     */
    public List<L> getLabels() {
        if (this.labelIndex == null) {
            return Collections.emptyList();
        }
        return this.labelIndex.objectsList();
    }

    /** Returns the CoNLL-style report, omitting the negative label's row. */
    public String getConllEvalString() {
        return getConllEvalString(true);
    }

    /**
     * Returns the CoNLL-style report with per-label rows sorted by natural order when the
     * labels are {@link Comparable}.
     *
     * @param ignoreNegLabel when true, no row is printed for the negative label
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public String getConllEvalString(boolean ignoreNegLabel) {
        // Sort a copy: getLabels() may return the label index's own backing list, and
        // reordering that in place would break the index's label/position mapping.
        List<L> labels = new ArrayList<>(getLabels());
        if (labels.size() > 1 && (labels.get(0) instanceof Comparable)) {
            // L has no Comparable bound, so the natural-order sort goes through a raw List.
            Collections.sort((List) labels);
        }
        return getConllEvalString(labels, ignoreNegLabel);
    }

    /** Builds the overall summary lines plus one precision/recall/FB1 row per label in {@code labels}. */
    private String getConllEvalString(List<L> labels, boolean ignoreNegLabel) {
        StringBuilder sb = new StringBuilder();
        int correct = getCorrect() - getCorrect(this.negLabel);
        Triple<Double, Integer, Integer> accuracyInfo = getAccuracyInfo();
        sb.append("processed " + (accuracyInfo.second().intValue() + accuracyInfo.third().intValue()) + " tokens with " + getRelevant() + " phrases; ");
        sb.append("found: " + getRetrieved() + " phrases; correct: " + correct + "\n");
        Formatter formatter = new Formatter(sb, Locale.US);
        formatter.format("accuracy: %6.2f%%; ", Double.valueOf(accuracyInfo.first().doubleValue() * 100.0d));
        formatter.format("precision: %6.2f%%; ", Double.valueOf(getPrecision() * 100.0d));
        formatter.format("recall: %6.2f%%; ", Double.valueOf(getRecall() * 100.0d));
        formatter.format("FB1: %6.2f\n", Double.valueOf(getFMeasure() * 100.0d));
        for (L label : labels) {
            if (ignoreNegLabel && label.equals(this.negLabel)) {
                continue;
            }
            formatter.format("%17s: ", label);
            formatter.format("precision: %6.2f%%; ", Double.valueOf(getPrecision(label) * 100.0d));
            formatter.format("recall: %6.2f%%; ", Double.valueOf(getRecall(label) * 100.0d));
            formatter.format("FB1: %6.2f  %d\n", Double.valueOf(getFMeasure(label) * 100.0d), Integer.valueOf(getRetrieved(label)));
        }
        return sb.toString();
    }
}
