package lemming.test.lemma.toutanova;

import java.util.List;
import java.util.logging.Logger;
import junit.framework.AssertionFailedError;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaLemmatizer;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ToutanovaTrainer;
import lemming.lemma.toutanova.ZeroOrderDecoder;
import lemming.lemma.toutanova.ZeroOrderNbestDecoder;
import marmot.morph.io.SentenceReader;
import marmot.util.Numerics;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:lemming/test/lemma/toutanova/NbestDecoderTest.class */
/**
 * Consistency test for {@link ZeroOrderNbestDecoder}.
 *
 * <p>Trains a {@link ToutanovaLemmatizer} and then, for every instance of a test file, checks that
 * the one-best {@link ZeroOrderDecoder} and the n-best decoder report scores that agree with
 * {@link ToutanovaModel#getScore}. It also checks that the n-best list starts with the one-best
 * result and is ordered by non-increasing score. One-best and n-best hit rates against the gold
 * lemma are logged.
 */
public class NbestDecoderTest {
    /** Absolute tolerance used when comparing decoder-reported scores with model scores. */
    private static final double DELTA = 0.01d;

    /**
     * Trains a lemmatizer on {@code trainConfig} and runs the decoder checks on {@code testConfig}.
     *
     * @param trainConfig {@link SentenceReader} specification of the training data
     * @param testConfig {@link SentenceReader} specification of the data to decode
     * @param ignored currently not used by this method; kept so existing callers still compile
     * @param nbestSize passed to the {@link ZeroOrderNbestDecoder} constructor (presumably the
     *     maximum n-best list size; confirm against that class)
     */
    public void trainDecodeTest(String trainConfig, String testConfig, int ignored, int nbestSize) {
        ToutanovaLemmatizer lemmatizer = (ToutanovaLemmatizer) new ToutanovaTrainer().train(
                LemmaInstance.getInstances(new SentenceReader(trainConfig)),
                (List<LemmaInstance>) null);
        testDecoder(lemmatizer, testConfig, nbestSize);
    }

    /**
     * Decodes every instance read from {@code testConfig} with both decoders, asserts their
     * consistency and logs the weighted one-best and n-best hit rates.
     */
    private void testDecoder(ToutanovaLemmatizer lemmatizer, String testConfig, int nbestSize) {
        ToutanovaModel model = lemmatizer.getModel();

        ZeroOrderDecoder oneBestDecoder = new ZeroOrderDecoder();
        oneBestDecoder.init(model);
        ZeroOrderNbestDecoder nbestDecoder = new ZeroOrderNbestDecoder(nbestSize);
        nbestDecoder.init(model);

        // Counts are weighted by LemmaInstance.getCount().
        int oneBestCorrect = 0;
        int nbestCorrect = 0;
        int total = 0;

        for (LemmaInstance lemmaInstance : LemmaInstance.getInstances(new SentenceReader(testConfig))) {
            ToutanovaInstance instance = new ToutanovaInstance(lemmaInstance, null);
            model.addIndexes(instance, false);

            Result oneBest = oneBestDecoder.decode(instance);
            double modelScore = model.getScore(instance, oneBest);
            // The decoder's own score must match the model's rescoring of the same output.
            Assert.assertEquals(modelScore, oneBest.getScore(), DELTA);

            List<Result> nbest = nbestDecoder.decode(instance);
            boolean goldInNbest = checkNbestList(model, instance, nbest, oneBest, lemmaInstance);

            if (goldInNbest) {
                nbestCorrect = (int) (nbestCorrect + lemmaInstance.getCount());
            }
            if (oneBest.getOutput().equals(lemmaInstance.getLemma())) {
                oneBestCorrect = (int) (oneBestCorrect + lemmaInstance.getCount());
            }
            total = (int) (total + lemmaInstance.getCount());
        }

        Logger logger = Logger.getLogger(getClass().getName());
        logger.info(String.format("One-best : %5d %5d = %g",
                oneBestCorrect, total, percentage(oneBestCorrect, total)));
        logger.info(String.format("N-best : %5d %5d = %g",
                nbestCorrect, total, percentage(nbestCorrect, total)));
    }

    /**
     * Validates one n-best list and reports whether it contains the gold lemma.
     *
     * <p>Asserts that the list is non-empty, that its first entry has the same output and score as
     * {@code oneBest}, and that every entry's score matches the model. It also asserts that scores
     * do not increase along the list (within {@link Numerics#approximatelyLesserEqual}).
     *
     * @return {@code true} if some entry's output equals {@code lemmaInstance.getLemma()}
     */
    private static boolean checkNbestList(ToutanovaModel model, ToutanovaInstance instance,
            List<Result> nbest, Result oneBest, LemmaInstance lemmaInstance) {
        Assert.assertFalse("n-best decoder returned an empty list", nbest.isEmpty());

        Result top = nbest.get(0);
        // The best n-best entry must coincide with the one-best decoder's result.
        Assert.assertEquals(oneBest.getOutput(), top.getOutput());
        Assert.assertEquals(oneBest.getScore(), top.getScore(), DELTA);

        Result previous = null;
        boolean containsGold = false;
        for (Result candidate : nbest) {
            Assert.assertEquals(model.getScore(instance, candidate), candidate.getScore(), DELTA);
            // Fails only when a candidate scores higher than its predecessor.
            if (previous != null
                    && !Numerics.approximatelyLesserEqual(candidate.getScore(), previous.getScore())) {
                throw new AssertionFailedError(String.format(
                        "n-best list not sorted by descending score: %g follows %g",
                        candidate.getScore(), previous.getScore()));
            }
            previous = candidate;
            if (candidate.getOutput().equals(lemmaInstance.getLemma())) {
                containsGold = true;
            }
        }
        return containsGold;
    }

    /** Returns {@code part} as a percentage of {@code total}, or 0 when {@code total} is 0. */
    private static double percentage(int part, int total) {
        return total == 0 ? 0.0d : (part * 100.0d) / total;
    }

    /**
     * Trains on {@code trn_mod.tsv} twice. The first run decodes the training file itself with
     * {@code nbestSize} 5. The second run decodes {@code dev.tsv} with {@code nbestSize} 10.
     */
    @Test
    public void test() {
        // Column layout shared by both resource files.
        String columns = "form-index=4,lemma-index=5,tag-index=2,";
        String trainConfig = columns + getResourceFile("trn_mod.tsv");
        String devConfig = columns + getResourceFile("dev.tsv");
        trainDecodeTest(trainConfig, trainConfig, 1, 5);
        trainDecodeTest(trainConfig, devConfig, 10, 10);
    }

    /** Returns the resource URI {@code res:///marmot/test/lemma/<name>} for a test file. */
    protected String getResourceFile(String name) {
        return String.format("res:///%s/%s", "marmot/test/lemma", name);
    }
}
