package chipmunk.test.segmenter;

import chipmunk.segmenter.SegmentationInstance;
import chipmunk.segmenter.SegmentationReading;
import chipmunk.segmenter.SegmentationResult;
import chipmunk.segmenter.SegmentationSumLattice;
import chipmunk.segmenter.SegmenterModel;
import chipmunk.segmenter.SegmenterOptions;
import chipmunk.segmenter.Word;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import marmot.util.DynamicWeights;
import marmot.util.Numerics;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:chipmunk/test/segmenter/SumLatticeTest.class */
/**
 * Checks {@link SegmentationSumLattice#update} against a brute-force reference.
 *
 * <p>The reference ({@link #explicit_update}) enumerates every segmentation explicitly. For each
 * training word, under random scorer weights, both implementations must return approximately the
 * same value. They must also write approximately the same values into the updater weight vector.
 */
public class SumLatticeTest {
    /** Number of random scorer-weight draws to test. */
    private static final int NUM_TRIALS = 10;
    /** Length of the random scorer / updater weight vectors. */
    private static final int NUM_WEIGHTS = 50;
    /** Tolerance used when comparing the two accumulated weight-update vectors. */
    private static final double GRADIENT_TOLERANCE = 1.0E-5d;

    /**
     * Brute-force counterpart of the lattice update.
     *
     * <p>Enumerates all candidate segmentations and accumulates {@code logPartition}, the log-space
     * sum ({@link Numerics#sumLogProb}) of their scores. Each candidate is then passed to
     * {@link SegmenterModel#update} with weight {@code -exp(score - logPartition)}. Finally the
     * single gold segmentation is passed with weight {@code +1}.
     *
     * @param segmentationInstance instance to update on; must carry exactly one gold result
     * @param segmenterModel model whose scorer weights are read and whose update method is invoked
     * @return the gold segmentation's score minus {@code logPartition}
     */
    double explicit_update(SegmentationInstance segmentationInstance, SegmenterModel segmenterModel) {
        List<SegmentationResult> candidates = new LinkedList<>();
        addAllResults(segmentationInstance, segmenterModel, segmenterModel.getMaxSegmentLength(), candidates, 0);

        // Log-space sum over the scores of every candidate segmentation.
        double logPartition = Double.NEGATIVE_INFINITY;
        for (SegmentationResult candidate : candidates) {
            logPartition = Numerics.sumLogProb(segmenterModel.getScore(segmentationInstance, candidate), logPartition);
        }

        // Subtract each candidate, weighted by its normalized probability exp(score - logPartition).
        for (SegmentationResult candidate : candidates) {
            double probability = Math.exp(segmenterModel.getScore(segmentationInstance, candidate) - logPartition);
            segmenterModel.update(segmentationInstance, candidate, -probability);
        }

        // The reference assumes a single gold reading; check it unconditionally instead of
        // relying on the JVM's -ea flag.
        Assert.assertEquals("explicit_update expects exactly one gold segmentation",
                1, segmentationInstance.getResults().size());
        SegmentationResult gold = segmentationInstance.getResults().iterator().next();
        double goldLogProbability = segmenterModel.getScore(segmentationInstance, gold) - logPartition;
        segmenterModel.update(segmentationInstance, gold, 1.0d);
        return goldLogProbability;
    }

    /**
     * Appends to {@code results} every segmentation of the word's suffix starting at {@code start}.
     *
     * <p>Each segment has length 1..{@code maxSegmentLength} and is labelled with each tag in
     * {@code [0, model.getNumTags())}. A result's input indexes are the (exclusive) end positions of
     * its segments, in order. Its tags are the matching segment labels.
     *
     * @param instance instance whose word is segmented
     * @param model supplies the number of tags
     * @param maxSegmentLength longest allowed segment
     * @param results output list that receives the enumerated segmentations
     * @param start character offset at which the suffix to segment begins
     */
    private void addAllResults(SegmentationInstance instance, SegmenterModel model, int maxSegmentLength,
            List<SegmentationResult> results, int start) {
        String form = instance.getWord().getWord();
        int lastEnd = Math.min(start + maxSegmentLength, form.length());
        for (int end = start + 1; end <= lastEnd; end++) {
            if (end == form.length()) {
                // Segment [start, end) reaches the end of the word: it is the final segment.
                for (int tag = 0; tag < model.getNumTags(); tag++) {
                    List<Integer> tags = new LinkedList<>();
                    tags.add(tag);
                    List<Integer> ends = new LinkedList<>();
                    ends.add(end);
                    results.add(new SegmentationResult(tags, ends));
                }
            } else {
                // Segment the remainder recursively, then prefix every suffix with [start, end).
                List<SegmentationResult> suffixes = new LinkedList<>();
                addAllResults(instance, model, maxSegmentLength, suffixes, end);
                for (SegmentationResult suffix : suffixes) {
                    for (int tag = 0; tag < model.getNumTags(); tag++) {
                        List<Integer> tags = new LinkedList<>();
                        tags.add(tag);
                        tags.addAll(suffix.getTags());
                        List<Integer> ends = new LinkedList<>();
                        ends.add(end);
                        ends.addAll(suffix.getInputIndexes());
                        results.add(new SegmentationResult(tags, ends));
                    }
                }
            }
        }
    }

    @Test
    public void test() {
        List<Word> words = new LinkedList<>();
        words.add(toWord(Arrays.asList("b"), Arrays.asList("B")));
        words.add(toWord(Arrays.asList("aa"), Arrays.asList("A")));
        words.add(toWord(Arrays.asList("a", "bb"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "bb"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("a", "b"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "b"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "c"), Arrays.asList("A", "C")));

        SegmenterModel model = new SegmenterModel();
        SegmenterOptions options = new SegmenterOptions();
        options.setOption(SegmenterOptions.USE_CHARACTER_FEATURE, false);
        options.setOption(SegmenterOptions.USE_SEGMENT_CONTEXT, false);
        model.init(options, words);

        SegmentationSumLattice lattice = new SegmentationSumLattice(model);
        Random random = new Random(42L);
        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            double[] scorerWeights = new double[NUM_WEIGHTS];
            for (int k = 0; k < scorerWeights.length; k++) {
                scorerWeights[k] = random.nextGaussian();
            }
            // Both implementations write their updates into this array; it is snapshotted and
            // reset after each one so the two results can be compared.
            double[] updaterWeights = new double[scorerWeights.length];
            model.setScorerWeights(new DynamicWeights(scorerWeights, false, false));
            model.setUpdaterWeights(new DynamicWeights(updaterWeights, false, false));

            for (Word word : words) {
                SegmentationInstance instance = model.getInstance(word);

                double latticeValue = lattice.update(instance, true);
                double[] latticeUpdate = updaterWeights.clone();
                Arrays.fill(updaterWeights, 0.0d);

                double explicitValue = explicit_update(instance, model);
                double[] explicitUpdate = updaterWeights.clone();
                Arrays.fill(updaterWeights, 0.0d);

                if (!Numerics.approximatelyEqual(latticeUpdate, explicitUpdate, GRADIENT_TOLERANCE)) {
                    Assert.fail("Weight update mismatch for " + word + " (trial " + trial + "):\n"
                            + Arrays.toString(latticeUpdate) + "\n" + Arrays.toString(explicitUpdate));
                }
                if (!Numerics.approximatelyEqual(latticeValue, explicitValue)) {
                    Assert.fail("Return value mismatch for " + word + " (trial " + trial + "): "
                            + latticeValue + " vs " + explicitValue);
                }
            }
        }
    }

    /**
     * Builds a {@link Word} whose form is the concatenation of {@code segments}. The word gets a
     * single {@link SegmentationReading} built from {@code segments} and {@code tags}.
     */
    private Word toWord(List<String> segments, List<String> tags) {
        StringBuilder form = new StringBuilder();
        for (String segment : segments) {
            form.append(segment);
        }
        Word word = new Word(form.toString());
        word.add(new SegmentationReading(segments, tags));
        return word;
    }
}
