package lemming.test.lemma.toutanova;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.FirstOrderDecoder;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ToutanovaTrainer;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:lemming/test/lemma/toutanova/DecoderTest.class */
/**
 * Exercises {@link FirstOrderDecoder} against a {@link ToutanovaModel} that was
 * initialized on a handful of toy form/lemma pairs.
 */
public class DecoderTest {
    /**
     * Initializes a model on six instances ("aaae", "bbbe", "ccce", "aaaf",
     * "bbbf", "cccf", paired with "aaa", "bbb" and "ccc") and checks that
     * decoding each instance yields three copies of the output-table index of
     * the corresponding letter, together with the inputs {@code [1, 2, 4]}.
     */
    @Test
    public void test() {
        ToutanovaModel model = new ToutanovaModel();

        // Typed list: a raw LinkedList would make get(i) return Object, which
        // cannot be handed to decode(...) without a cast.
        List<ToutanovaInstance> instances = new LinkedList<ToutanovaInstance>();
        instances.add(newInstance("aaae", "aaa"));
        instances.add(newInstance("bbbe", "bbb"));
        instances.add(newInstance("ccce", "ccc"));
        instances.add(newInstance("aaaf", "aaa"));
        instances.add(newInstance("bbbf", "bbb"));
        instances.add(newInstance("cccf", "ccc"));

        model.init(ToutanovaTrainer.ToutanovaOptions.newInstance(), instances, null);

        FirstOrderDecoder decoder = new FirstOrderDecoder();
        decoder.init(model);

        int indexOfA = model.getOutputTable().toIndex("a");
        int indexOfB = model.getOutputTable().toIndex("b");
        int indexOfC = model.getOutputTable().toIndex("c");

        // Instances i and i + 3 share a lemma, so they are checked pairwise.
        assertDecodesTo(indexOfA, decoder, instances.get(0));
        assertDecodesTo(indexOfA, decoder, instances.get(3));
        assertDecodesTo(indexOfB, decoder, instances.get(1));
        assertDecodesTo(indexOfB, decoder, instances.get(4));
        assertDecodesTo(indexOfC, decoder, instances.get(2));
        assertDecodesTo(indexOfC, decoder, instances.get(5));
    }

    /**
     * Builds one test instance. The trailing two {@link LemmaInstance}
     * arguments are left {@code null}, and every instance gets its own copy of
     * the list {@code [1, 1, 1, 1, 2, 1]}.
     *
     * NOTE(review): the integer list presumably encodes an input/output
     * alignment as (input length, output length) pairs -- confirm against
     * ToutanovaInstance.
     *
     * @param form first string passed to {@link LemmaInstance}
     * @param lemma second string passed to {@link LemmaInstance}
     * @return a new instance wrapping the given pair
     */
    private ToutanovaInstance newInstance(String form, String lemma) {
        return new ToutanovaInstance(new LemmaInstance(form, lemma, null, null), Arrays.asList(1, 1, 1, 1, 2, 1));
    }

    /**
     * Decodes {@code instance} and asserts that the outputs are three copies
     * of {@code outputIndex} and that the inputs are {@code [1, 2, 4]}.
     *
     * NOTE(review): the inputs look like segment end positions within the
     * form -- confirm against Result.getInputs().
     *
     * @param outputIndex output-table index expected at every output position
     * @param decoder decoder under test, already initialized with the model
     * @param instance instance to decode
     */
    private void assertDecodesTo(int outputIndex, FirstOrderDecoder decoder, ToutanovaInstance instance) {
        List<Integer> expectedOutputs = Arrays.asList(outputIndex, outputIndex, outputIndex);
        List<Integer> expectedInputs = Arrays.asList(1, 2, 4);
        assertResultEquals(expectedOutputs, expectedInputs, decoder.decode(instance));
    }

    /**
     * Asserts that a decoding result carries the expected outputs and inputs.
     *
     * @param expectedOutputs value {@code result.getOutputs()} must equal
     * @param expectedInputs value {@code result.getInputs()} must equal
     * @param result decoding result to inspect
     */
    private void assertResultEquals(List<Integer> expectedOutputs, List<Integer> expectedInputs, Result result) {
        Assert.assertEquals(expectedOutputs, result.getOutputs());
        Assert.assertEquals(expectedInputs, result.getInputs());
    }
}
