package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Constituent;
import edu.stanford.nlp.trees.Tree;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/* loaded from: input_file:edu/stanford/nlp/parser/metrics/EvalbByCat.class */
/**
 * Labeled bracketing precision / recall / F1 computed separately for every constituent
 * category (label), using the constituents extracted by an internal {@link Evalb}.
 *
 * <p>Two aggregates are kept per category:
 * <ul>
 *   <li>"sent ave": the sum of per-sentence scores divided by the number of sentences
 *       evaluated so far ({@code num} is incremented once per {@link #evaluate} call,
 *       whether or not the category occurs in that sentence);</li>
 *   <li>"evalb": a size-weighted aggregate, i.e. the sum of
 *       {@code |constituents| * score} divided by the total number of guessed
 *       (for precision) or gold (for recall) constituents of that category.</li>
 * </ul>
 */
public class EvalbByCat extends AbstractEval {
    /** Used only to turn a tree into its set of labeled constituents. */
    private final Evalb evalb;

    // Per-category sums of per-sentence scores; divided by num for the "sent ave" figures.
    private final ClassicCounter<Label> precisions;
    private final ClassicCounter<Label> recalls;
    private final ClassicCounter<Label> f1s;

    // Per-category size-weighted sums (|guessed| * P, |gold| * R) and their denominators
    // (total guessed / gold constituents), used for the "evalb" figures.
    private final ClassicCounter<Label> precisions2;
    private final ClassicCounter<Label> recalls2;
    private final ClassicCounter<Label> pnums2;
    private final ClassicCounter<Label> rnums2;

    /**
     * @param str name of this evaluation; also used as the prefix of the per-category F1 lines
     *            printed by {@link #evaluate}
     * @param z   forwarded unchanged to the {@link AbstractEval} constructor
     */
    public EvalbByCat(String str, boolean z) {
        super(str, z);
        this.evalb = new Evalb(str, false);
        this.precisions = new ClassicCounter<>();
        this.recalls = new ClassicCounter<>();
        this.f1s = new ClassicCounter<>();
        this.precisions2 = new ClassicCounter<>();
        this.recalls2 = new ClassicCounter<>();
        this.pnums2 = new ClassicCounter<>();
        this.rnums2 = new ClassicCounter<>();
    }

    /** Delegates constituent extraction to the wrapped {@link Evalb}. */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval
    protected Set<Constituent> makeObjects(Tree tree) {
        return this.evalb.makeObjects(tree);
    }

    /**
     * Groups the constituents of {@code tree} by their label.
     *
     * @return a fresh mutable map from label to the (non-empty) set of constituents with it
     */
    private Map<Label, Set<Constituent>> makeObjectsByCat(Tree tree) {
        Map<Label, Set<Constituent>> byCat = new HashMap<>();
        for (Constituent constituent : makeObjects(tree)) {
            Label label = constituent.label();
            Set<Constituent> bucket = byCat.get(label);
            if (bucket == null) {
                bucket = new HashSet<>();
                byCat.put(label, bucket);
            }
            bucket.add(constituent);
        }
        return byCat;
    }

    /**
     * Scores one sentence per category and adds the scores to the running totals.
     *
     * @param guess tree whose constituents are counted as "guessed" (precision side)
     * @param gold  tree whose constituents are counted as "gold" (recall side)
     * @param pw    if non-null and running averages are enabled, per-category scores for this
     *              sentence and the running aggregates are printed to it
     */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval
    public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
        Map<Label, Set<Constituent>> guessedByCat = makeObjectsByCat(guess);
        Map<Label, Set<Constituent>> goldByCat = makeObjectsByCat(gold);
        Set<Label> cats = new HashSet<>();
        cats.addAll(guessedByCat.keySet());
        cats.addAll(goldByCat.keySet());

        boolean print = pw != null && this.runningAverages;
        if (print) {
            pw.println("========================================");
            pw.println("Labeled Bracketed Evaluation by Category");
            pw.println("========================================");
        }
        this.num += 1.0d;
        for (Label cat : cats) {
            // A category present in only one tree gets an empty set on the other side.
            Set<Constituent> guessed = orEmpty(guessedByCat.get(cat));
            Set<Constituent> goldSet = orEmpty(goldByCat.get(cat));
            double p = precision(guessed, goldSet);
            double r = precision(goldSet, guessed);
            // F1 is defined as 0 when either side is 0, avoiding 1/0 in the harmonic mean.
            double f1 = (p <= 0.0d || r <= 0.0d) ? 0.0d : harmonicMean(p, r);
            this.precisions.incrementCount(cat, p);
            this.recalls.incrementCount(cat, r);
            this.f1s.incrementCount(cat, f1);
            this.precisions2.incrementCount(cat, guessed.size() * p);
            this.pnums2.incrementCount(cat, guessed.size());
            this.recalls2.incrementCount(cat, goldSet.size() * r);
            this.rnums2.incrementCount(cat, goldSet.size());
            if (print) {
                printRunningScores(pw, cat, p, r, f1);
            }
        }
        if (print) {
            pw.println("========================================");
        }
    }

    /**
     * Prints this sentence's P/R/F1 for {@code cat} next to the "sent ave" and "evalb"
     * running aggregates. Values are truncated (not rounded) to two decimal places.
     */
    private void printRunningScores(PrintWriter pw, Label cat, double p, double r, double f1) {
        double evalbF1 = 2.0d / ((this.rnums2.getCount(cat) / this.recalls2.getCount(cat))
                + (this.pnums2.getCount(cat) / this.precisions2.getCount(cat)));
        pw.println(cat + "\tP: " + truncatedPercent(p, 1.0d)
                + " (sent ave " + truncatedPercent(this.precisions.getCount(cat), this.num)
                + ") (evalb " + truncatedPercent(this.precisions2.getCount(cat), this.pnums2.getCount(cat))
                + ")");
        pw.println("\tR: " + truncatedPercent(r, 1.0d)
                + " (sent ave " + truncatedPercent(this.recalls.getCount(cat), this.num)
                + ") (evalb " + truncatedPercent(this.recalls2.getCount(cat), this.rnums2.getCount(cat))
                + ")");
        pw.println(this.str + " F1: " + truncatedPercent(f1, 1.0d)
                + " (sent ave " + truncatedPercent(this.f1s.getCount(cat), this.num)
                + ", evalb " + truncatedPercent(evalbF1, 1.0d) + ")");
    }

    /**
     * Prints the final per-category "evalb" precision, recall and F1, ordered by ascending F1.
     * Categories whose F1 is undefined (NaN) are ordered as if their F1 were -1.
     *
     * @param verbose unused by this implementation
     * @param pw      destination of the report
     */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval
    public void display(boolean verbose, PrintWriter pw) {
        DecimalFormat nf = new DecimalFormat("0.00");
        Set<Label> cats = new HashSet<>();
        cats.addAll(this.precisions.keySet());
        cats.addAll(this.recalls.keySet());

        final Map<Label, Double> f1ByCat = new HashMap<>();
        for (Label cat : cats) {
            double f1 = harmonicMean(this.precisions2.getCount(cat) / this.pnums2.getCount(cat),
                    this.recalls2.getCount(cat) / this.rnums2.getCount(cat));
            f1ByCat.put(cat, Double.isNaN(f1) ? -1.0d : f1);
        }
        // Rank a list instead of keying a map by score: a score-keyed map keeps only one
        // category per distinct F1, silently dropping ties from the report.
        List<Label> ranked = new ArrayList<>(cats);
        Collections.sort(ranked, new Comparator<Label>() {
            @Override
            public int compare(Label a, Label b) {
                int byF1 = Double.compare(f1ByCat.get(a), f1ByCat.get(b));
                return byF1 != 0 ? byF1 : String.valueOf(a).compareTo(String.valueOf(b));
            }
        });

        pw.println("============================================================");
        pw.println("Labeled Bracketed Evaluation by Category -- final statistics");
        pw.println("============================================================");
        for (Label cat : ranked) {
            double guessed = this.pnums2.getCount(cat);
            double gold = this.rnums2.getCount(cat);
            double lp = this.precisions2.getCount(cat) / guessed;
            double lr = this.recalls2.getCount(cat) / gold;
            String lpText = guessed == 0.0d ? " N/A" : nf.format(lp);
            String lrText = gold == 0.0d ? " N/A" : nf.format(lr);
            String f1Text = (guessed == 0.0d || gold == 0.0d) ? " N/A" : nf.format(harmonicMean(lp, lr));
            pw.println(cat + "\tLP: " + lpText + "\tguessed: " + ((int) guessed)
                    + "\tLR: " + lrText + "\tgold:  " + ((int) gold) + "\tF1: " + f1Text);
        }
        pw.println("============================================================");
    }

    /** Returns {@code set}, or a new empty set when it is null. */
    private static Set<Constituent> orEmpty(Set<Constituent> set) {
        return set != null ? set : new HashSet<Constituent>();
    }

    /** Returns {@code 2 / (1/a + 1/b)}; callers handle zero / NaN inputs themselves. */
    private static double harmonicMean(double a, double b) {
        return 2.0d / ((1.0d / a) + (1.0d / b));
    }

    /**
     * Returns {@code numerator / denominator} as a percentage truncated toward zero to two
     * decimal places (e.g. 0.87654 -> 87.65). NaN yields 0.0 via the int cast.
     */
    private static double truncatedPercent(double numerator, double denominator) {
        return ((int) ((numerator * 10000.0d) / denominator)) / 100.0d;
    }
}
