package lemming.test.lemma.toutanova;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import junit.framework.AssertionFailedError;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.LemmaResult;
import lemming.lemma.Lemmatizer;
import lemming.lemma.LemmatizerTrainer;
import lemming.lemma.SimpleLemmatizerTrainer;
import marmot.morph.io.SentenceReader;
import marmot.util.Copy;
import marmot.util.Numerics;
import org.junit.Test;

/* loaded from: input_file:lemming/test/lemma/toutanova/SimpleTrainerTest.class */
public class SimpleTrainerTest {
    /** Reader options selecting the form, lemma and tag columns. */
    private static final String pos_indexes = "form-index=4,lemma-index=5,tag-index=2,";
    /** Reader options selecting the form, lemma, tag and morphology columns. */
    private static final String morph_indexes = "form-index=4,lemma-index=5,tag-index=2,morph-index=3,";

    @Test
    public void moderateTest() {
        runModerateTest(new SimpleLemmatizerTrainer(), 98.41d, 64.47d);
    }

    @Test
    public void moderateUnseenTest() {
        SimpleLemmatizerTrainer simpleLemmatizerTrainer = new SimpleLemmatizerTrainer();
        simpleLemmatizerTrainer.getOptions().setOption(SimpleLemmatizerTrainer.SimpleLemmatizerTrainerOptions.HANDLE_UNSEEN, true);
        simpleLemmatizerTrainer.getOptions().setOption(LemmaOptions.USE_POS, false);
        runModerateTest(simpleLemmatizerTrainer, 99.48d, 86.63d);
    }

    @Test
    public void moderateUnseenPosTest() {
        SimpleLemmatizerTrainer simpleLemmatizerTrainer = new SimpleLemmatizerTrainer();
        simpleLemmatizerTrainer.getOptions().setOption(SimpleLemmatizerTrainer.SimpleLemmatizerTrainerOptions.HANDLE_UNSEEN, true);
        simpleLemmatizerTrainer.getOptions().setOption(LemmaOptions.USE_POS, true);
        runModerateTest(simpleLemmatizerTrainer, 99.96d, 86.84d);
    }

    /**
     * Builds the resource URI of a lemma test file.
     *
     * @param fileName name of the file inside the {@code marmot/test/lemma} resource directory
     * @return a {@code res:///} URI pointing at that file
     */
    public String getResourceFile(String fileName) {
        return String.format("res:///%s/%s", "marmot/test/lemma", fileName);
    }

    /**
     * Selects the instances whose surface form is identical to their lemma.
     *
     * @param list instances to filter
     * @return a new list with the matching instances, in input order
     */
    public List<LemmaInstance> getCopyInstances(List<LemmaInstance> list) {
        List<LemmaInstance> copyInstances = new ArrayList<LemmaInstance>();
        for (LemmaInstance instance : list) {
            if (instance.getForm().equals(instance.getLemma())) {
                copyInstances.add(instance);
            }
        }
        return copyInstances;
    }

    /**
     * Runs {@link #runTest} on {@code trn_sml.tsv}, reading only POS columns.
     */
    public void runSmallTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy) {
        runSmallTest(lemmatizerTrainer, trainAccuracy, devAccuracy, false);
    }

    /**
     * Runs {@link #runTest} on {@code trn_sml.tsv}.
     */
    protected void runSmallTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy, boolean useMorph) {
        runTest(lemmatizerTrainer, trainAccuracy, devAccuracy, "trn_sml.tsv", useMorph);
    }

    /**
     * Runs {@link #runTest} on {@code trn_mod.tsv}, reading only POS columns.
     */
    public void runModerateTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy) {
        runModerateTest(lemmatizerTrainer, trainAccuracy, devAccuracy, false);
    }

    /**
     * Runs {@link #runTest} on {@code trn_mod.tsv}.
     */
    public void runModerateTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy, boolean useMorph) {
        runTest(lemmatizerTrainer, trainAccuracy, devAccuracy, "trn_mod.tsv", useMorph);
    }

    /**
     * Runs {@link #runTest} reading only POS columns.
     */
    protected void runTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy, String fileName) {
        runTest(lemmatizerTrainer, trainAccuracy, devAccuracy, fileName, false);
    }

    /**
     * Trains a lemmatizer on a resource file and checks its token accuracy.
     *
     * @param lemmatizerTrainer trainer under test
     * @param trainAccuracy minimum token accuracy expected on the training data
     * @param devAccuracy minimum token accuracy expected on {@code dev.tsv}
     * @param fileName training resource file name
     * @param useMorph whether to also read the morphology column
     */
    protected void runTest(LemmatizerTrainer lemmatizerTrainer, double trainAccuracy, double devAccuracy, String fileName, boolean useMorph) {
        String indexes = useMorph ? morph_indexes : pos_indexes;
        List<LemmaInstance> trainInstances = readInstances(indexes, fileName);
        Lemmatizer lemmatizer = lemmatizerTrainer.train(trainInstances, null);
        assertAccuracy(lemmatizer, trainInstances, trainAccuracy);
        assertAccuracy(lemmatizer, readInstances(indexes, "dev.tsv"), devAccuracy);
        // NOTE(review): 1.0 is a very lenient threshold; presumably a smoke check that the
        // Morfette dev file can be read and lemmatized at all — confirm intent.
        assertAccuracy(lemmatizer, readInstances(indexes, "dev.tsv.morfette"), 1.0d);
    }

    /**
     * Reads lemma instances from a test resource.
     *
     * @param indexes reader options prepended to the resource URI
     * @param fileName resource file name
     * @return the instances read from the file
     */
    private List<LemmaInstance> readInstances(String indexes, String fileName) {
        return LemmaInstance.getInstances(new SentenceReader(indexes + getResourceFile(fileName)));
    }

    /**
     * Trains on {@code trn_sml.tsv} and clones the resulting lemmatizer with {@link Copy#clone}.
     * A failure inside the clone indicates the lemmatizer is not serializable.
     */
    public void testIfLemmatizerIsSerializable(LemmatizerTrainer lemmatizerTrainer) {
        Copy.clone(lemmatizerTrainer.train(LemmaInstance.getInstances(pos_indexes + getResourceFile("trn_sml.tsv")), null));
    }

    /**
     * Logs the lemmatizer's accuracy on a collection and asserts a minimum token accuracy.
     *
     * @param lemmatizer lemmatizer to evaluate
     * @param collection gold instances to evaluate on
     * @param minAccuracy minimum token accuracy, compared with
     *     {@link Numerics#approximatelyGreaterEqual}
     * @throws AssertionFailedError if the token accuracy is below {@code minAccuracy}
     */
    public void assertAccuracy(Lemmatizer lemmatizer, Collection<LemmaInstance> collection, double minAccuracy) {
        LemmaResult result = LemmaResult.test(lemmatizer, collection);
        double tokenAccuracy = result.getTokenAccuracy();
        result.logAccuracy();
        if (!Numerics.approximatelyGreaterEqual(tokenAccuracy, minAccuracy)) {
            // The old message "%g > %g" stated the opposite of the failing condition.
            throw new AssertionFailedError(String.format(
                    "token accuracy %g is below the expected minimum %g", tokenAccuracy, minAccuracy));
        }
    }
}
