package marmot.core.lattice;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.Transition;
import marmot.util.HashableIntArray;

/* loaded from: input_file:marmot/core/lattice/SequenceViterbiLattice.class */
/**
 * Viterbi lattice over a sequence of per-position candidate {@link State}s that keeps, for every
 * state, up to {@code beam_size} scored back-pointers into the previous position.
 *
 * <p>Position 0 is scored against a single boundary state; position {@code i > 0} is scored
 * against the candidate states of position {@code i - 1}. Decoding walks backwards starting from
 * state index 0 at the last position, so the last position is presumably a boundary position
 * holding a single state (NOTE(review): confirm against callers).
 *
 * <p>An individual path is identified by a <em>signature</em>: one int per position
 * {@code 1 .. size-1} choosing which back-pointer (0 = best) to follow at that position.
 *
 * <p>The lattice is built lazily on first use; instances are not synchronized.
 */
public class SequenceViterbiLattice implements ViterbiLattice {
    private LatticeEntry[][][] lattice_;
    private final List<List<State>> candidates_;
    private final State boundary_;
    private final int beam_size_;
    private boolean initialized_ = false;
    private final boolean marginalize_lemmas_;

    /**
     * @param candidates        candidate states per position
     * @param boundary          state that position 0 transitions from
     * @param beamSize          maximum number of back-pointers kept per state, and maximum number
     *                          of hypotheses returned by {@link #getNbestSequences()}
     * @param marginalizeLemmas if {@code false} and a state's zero-order state has lemma
     *                          candidates, the state score is maximized over those candidates
     */
    public SequenceViterbiLattice(List<List<State>> candidates, State boundary, int beamSize,
                                  boolean marginalizeLemmas) {
        this.candidates_ = candidates;
        this.boundary_ = boundary;
        this.beam_size_ = beamSize;
        this.marginalize_lemmas_ = marginalizeLemmas;
    }

    /**
     * Builds the lattice if it has not been built yet.
     *
     * <p>For each state, every non-null transition from a previous state yields a candidate entry
     * scoring {@code stateScore + transitionScore + bestScore(previousState)}; the first
     * {@code beam_size} entries polled from a priority queue (LatticeEntry's natural ordering,
     * index 0 treated as best everywhere else in this class) are kept.
     */
    public void init() {
        if (initialized_) {
            return;
        }
        int length = candidates_.size();
        LatticeEntry[][][] lattice = new LatticeEntry[length][][];
        PriorityQueue<LatticeEntry> queue = new PriorityQueue<>();
        List<State> previous = Collections.singletonList(boundary_);
        for (int position = 0; position < length; position++) {
            List<State> current = candidates_.get(position);
            lattice[position] = new LatticeEntry[current.size()][];
            int stateIndex = 0;
            for (State state : current) {
                queue.clear();
                double ownScore = stateScore(state);
                for (int prev = 0; prev < previous.size(); prev++) {
                    Transition transition = state.getTransition(prev);
                    if (transition == null) {
                        continue;
                    }
                    double score = ownScore + transition.getScore();
                    if (position > 0) {
                        LatticeEntry[] prevEntries = lattice[position - 1][prev];
                        if (prevEntries.length == 0) {
                            // The previous state has no incoming path, so it cannot be extended.
                            continue;
                        }
                        score += prevEntries[0].getScore();
                    }
                    queue.add(new LatticeEntry(score, prev));
                }
                int width = Math.min(beam_size_, queue.size());
                assert width > 0;
                LatticeEntry[] entries = new LatticeEntry[width];
                for (int k = 0; k < width; k++) {
                    entries[k] = queue.poll();
                }
                lattice[position][stateIndex++] = entries;
            }
            previous = current;
        }
        // Publish only a fully built lattice, so a failure above does not leave a partial one behind.
        lattice_ = lattice;
        initialized_ = true;
    }

    /**
     * Returns the score used for {@code state} itself (excluding transitions).
     *
     * <p>When lemmas are not marginalized and the zero-order state has lemma candidates, the
     * zero-order score is swapped for its real score and the best lemma candidate's score is added.
     */
    private double stateScore(State state) {
        double score = state.getScore();
        State zeroOrder = state.getZeroOrderState();
        if (zeroOrder.getLemmaCandidates() == null || marginalize_lemmas_) {
            return score;
        }
        double base = (score - zeroOrder.getScore()) + zeroOrder.getRealScore();
        double best = Double.NEGATIVE_INFINITY;
        for (RankerCandidate candidate : zeroOrder.getLemmaCandidates()) {
            best = Math.max(best, base + candidate.getScore());
        }
        return best;
    }

    /** Returns the best path (all-zero signature), or {@code null} if there is none. */
    @Override
    public Hypothesis getViterbiSequence() {
        init();
        if (candidates_.isEmpty()) {
            return null;
        }
        return getSequenceBySignature(new HashableIntArray(new int[candidates_.size() - 1]));
    }

    /**
     * Reconstructs the path selected by {@code signature} by following back-pointers from state 0
     * at the last position.
     *
     * <p>The path score is the score of the entry chosen at the last position plus, for every
     * earlier position where a non-best back-pointer is chosen, the difference between the chosen
     * entry and the best entry of that state.
     *
     * @return the hypothesis (states ordered from position 0 to the last position), or
     *         {@code null} if the signature selects a missing back-pointer or the lattice has
     *         fewer than two positions
     */
    public Hypothesis getSequenceBySignature(HashableIntArray signature) {
        init();
        int last = candidates_.size() - 1;
        if (last < 1) {
            return null;
        }
        int[] choices = signature.getArray();
        List<Integer> states = new ArrayList<>(last + 1);
        int stateIndex = 0;
        states.add(stateIndex);
        double score = 0.0;
        for (int position = last; position >= 1; position--) {
            LatticeEntry[] entries = lattice_[position][stateIndex];
            int choice = choices[position - 1];
            if (choice < 0 || choice >= entries.length || entries[choice] == null) {
                return null;
            }
            LatticeEntry entry = entries[choice];
            if (position == last) {
                // This entry's score already reflects the choice made here; adding the delta
                // against entries[0] as well would count this position's penalty twice.
                score = entry.getScore();
            } else if (choice != 0) {
                score += entry.getScore() - entries[0].getScore();
            }
            stateIndex = entry.getPreviousStateIndex();
            states.add(stateIndex);
        }
        Collections.reverse(states);
        return new Hypothesis(states, score, signature);
    }

    /**
     * Enumerates up to {@code beam_size} paths in priority-queue order (Hypothesis' natural
     * ordering) by repeatedly incrementing one signature component of each extracted hypothesis.
     *
     * @return the n-best hypotheses; empty if no path exists
     */
    @Override
    public List<Hypothesis> getNbestSequences() {
        init();
        List<Hypothesis> nbest = new ArrayList<>();
        if (candidates_.isEmpty()) {
            return nbest;
        }
        HashableIntArray start = new HashableIntArray(new int[candidates_.size() - 1]);
        Hypothesis best = getSequenceBySignature(start);
        if (best == null) {
            return nbest;
        }
        PriorityQueue<Hypothesis> agenda = new PriorityQueue<>();
        Set<HashableIntArray> seen = new HashSet<>();
        agenda.add(best);
        seen.add(start);
        while (nbest.size() < beam_size_ && !agenda.isEmpty()) {
            Hypothesis hypothesis = agenda.poll();
            nbest.add(hypothesis);
            int[] signature = hypothesis.getSignature().getArray();
            for (int i = 0; i < signature.length; i++) {
                int[] successor = Arrays.copyOf(signature, signature.length);
                successor[i]++;
                HashableIntArray key = new HashableIntArray(successor);
                if (seen.add(key)) {
                    Hypothesis next = getSequenceBySignature(key);
                    if (next != null) {
                        agenda.add(next);
                    }
                }
            }
        }
        return nbest;
    }

    /**
     * Diagnostic: for every position of the gold path, prints to stderr when the gold previous
     * state is not among the kept back-pointers of the gold state.
     *
     * @param goldStates gold state index per position; must have one entry per position
     */
    public void findGoldSequence(List<Integer> goldStates) {
        init();
        assert goldStates.size() == candidates_.size();
        assert goldStates.size() == lattice_.length;
        for (int position = goldStates.size() - 1; position > 0; position--) {
            int state = goldStates.get(position);
            int previous = goldStates.get(position - 1);
            LatticeEntry[] entries = lattice_[position][state];
            if (!hasBackPointer(entries, previous)) {
                System.err.format("%s index = %d p_index = %d lattice entries = %s\n",
                        candidates_.get(position).get(state), position, previous,
                        Arrays.toString(entries));
            }
        }
    }

    /** Returns whether {@code entries}, up to the first null, points back to {@code previous}. */
    private static boolean hasBackPointer(LatticeEntry[] entries, int previous) {
        for (LatticeEntry entry : entries) {
            if (entry == null) {
                return false;
            }
            if (entry.getPreviousStateIndex() == previous) {
                return true;
            }
        }
        return false;
    }

    /**
     * Restricts the candidates to the states and transitions used by the n-best paths.
     *
     * <p>States at positions after 0 are copied and given a transition array sized to the pruned
     * previous position, re-indexed to the pruned previous states; states at position 0 are reused.
     */
    @Override
    public List<List<State>> prune() {
        init();
        List<List<State>> candidates = getCandidates();
        int length = candidates.size();

        // For each position: the (state, previous state) pairs used by some n-best path,
        // encoded as state * previousWidth + previousState.
        List<Set<Integer>> usedPairs = new ArrayList<>(length);
        for (int i = 0; i < length; i++) {
            usedPairs.add(new HashSet<>());
        }
        for (Hypothesis hypothesis : getNbestSequences()) {
            int position = 0;
            int previousState = 0;
            for (int state : hypothesis.getStates()) {
                usedPairs.get(position).add(state * previousWidth(candidates, position) + previousState);
                previousState = state;
                position++;
            }
        }

        List<List<State>> pruned = new ArrayList<>(length);
        int[] previousMapping = null; // old previous-position index -> pruned index
        for (int position = 0; position < length; position++) {
            List<State> original = candidates.get(position);
            int width = previousWidth(candidates, position);
            Set<Integer> pairs = usedPairs.get(position);
            int[] mapping = new int[original.size()];
            Arrays.fill(mapping, -1);
            List<State> kept = new ArrayList<>(pairs.size());
            for (int pair : pairs) {
                int state = pair / width;
                int previousState = pair % width;
                int newIndex = mapping[state];
                if (newIndex < 0) {
                    newIndex = kept.size();
                    mapping[state] = newIndex;
                    State keptState = original.get(state);
                    if (position > 0) {
                        keptState = keptState.copy();
                        keptState.setTransitions(new Transition[pruned.get(position - 1).size()]);
                    }
                    kept.add(keptState);
                }
                if (position > 0) {
                    kept.get(newIndex).getTransitions()[previousMapping[previousState]] =
                            original.get(state).getTransitions()[previousState];
                }
            }
            pruned.add(kept);
            previousMapping = mapping;
        }
        return pruned;
    }

    /** Number of states at the position before {@code position}; 1 (the boundary) for position 0. */
    private static int previousWidth(List<List<State>> candidates, int position) {
        return position > 0 ? candidates.get(position - 1).size() : 1;
    }

    @Override
    public List<List<State>> getCandidates() {
        return this.candidates_;
    }
}
