package chipmunk.segmenter;

import java.util.Arrays;
import marmot.util.Numerics;

/* loaded from: input_file:chipmunk/segmenter/SegmentationSumLattice.class */
/**
 * Forward-backward (sum) lattice over all segmentations of a {@link SegmentationInstance}.
 *
 * <p>A segmentation splits the input positions {@code [0, length)} into consecutive segments
 * {@code [start, end)} of at most {@code max_segment_length_} positions, each labelled with one of
 * {@code num_tags_} tags. Segment scores come from {@code SegmenterModel.getPairScore} and scores
 * between consecutive segments from {@code SegmenterModel.getTransitionScore}. They are combined
 * in log space with {@link Numerics#sumLogProb} (presumably {@code log(e^a + e^b)}).
 *
 * <p>Per-call state (score buffers, current input length) lives in instance fields and is reused
 * across calls, so a single instance should not be used from several threads at once.
 */
public class SegmentationSumLattice {
    private final SegmenterModel model_;
    private final int num_tags_;
    private final int max_segment_length_;
    // Both buffers are addressed through getIndex(tag, position); grown on demand, never shrunk.
    private double[] forward_score_array_;
    private double[] backward_score_array_;
    // Length of the instance currently being processed; doubles as the stride used by getIndex.
    private int input_length_;

    /**
     * Creates a lattice bound to {@code segmenterModel}.
     *
     * <p>The number of tags and the maximum segment length are read once, here.
     *
     * @param segmenterModel model providing segment/transition scores and receiving updates
     */
    public SegmentationSumLattice(SegmenterModel segmenterModel) {
        this.model_ = segmenterModel;
        this.num_tags_ = segmenterModel.getNumTags();
        this.max_segment_length_ = segmenterModel.getMaxSegmentLength();
    }

    /**
     * Computes the log-likelihood of the instance's reference segmentations and optionally pushes
     * the corresponding gradient to the model.
     *
     * <p>When {@code z} is {@code true}, the model receives two kinds of updates. First it gets the
     * negated lattice marginals of every segment / segment pair. Then it gets a weight of
     * {@code 1 / results.size()} for every reference segmentation in
     * {@code segmentationInstance.getResults()}. When {@code z} is {@code false}, the model is not
     * modified.
     *
     * <p>The returned value is computed before any update is applied, so it and log Z refer to
     * the same model parameters.
     *
     * <p>NOTE(review): for an instance of length 0 the lattice is empty and log Z is read from
     * position 0 of the backward buffer. That either throws (fresh buffer of size 0) or yields
     * {@code -Infinity}. Callers presumably never pass empty instances; confirm.
     *
     * @param segmentationInstance instance to process
     * @param z whether to apply updates to the model
     * @return sum over reference segmentations {@code r} of {@code getScore(r) - log Z}
     */
    public double update(SegmentationInstance segmentationInstance, boolean z) {
        this.input_length_ = segmentationInstance.getLength();
        checkArraySize(this.num_tags_ * this.input_length_);
        computeForwardScores(segmentationInstance);
        computeBackwardScores(segmentationInstance);
        double logPartition = sumTag(this.backward_score_array_, 0);

        // Score the references before touching the model so the result is consistent with
        // logPartition even if SegmenterModel.update modifies weights in place.
        double logLikelihood = 0.0d;
        for (SegmentationResult result : segmentationInstance.getResults()) {
            logLikelihood += this.model_.getScore(segmentationInstance, result) - logPartition;
        }

        if (z) {
            addExpectedCounts(segmentationInstance, logPartition);
            double weight = 1.0d / segmentationInstance.getResults().size();
            for (SegmentationResult result : segmentationInstance.getResults()) {
                this.model_.update(segmentationInstance, result, weight);
            }
        }
        return logLikelihood;
    }

    /**
     * Fills forward_score_array_ so that entry {@code (tag, end - 1)} holds the log-sum of all
     * segmentations of the prefix {@code [0, end)} whose last segment is labelled {@code tag}.
     */
    private void computeForwardScores(SegmentationInstance instance) {
        Arrays.fill(this.forward_score_array_, Double.NEGATIVE_INFINITY);
        for (int end = 1; end <= this.input_length_; end++) {
            for (int tag = 0; tag < this.num_tags_; tag++) {
                double score = Double.NEGATIVE_INFINITY;
                for (int start = Math.max(0, end - this.max_segment_length_); start < end; start++) {
                    double pairScore = this.model_.getPairScore(instance, start, end, tag);
                    if (start == 0) {
                        // Segment opens the input: no incoming transition.
                        score = Numerics.sumLogProb(pairScore, score);
                    } else {
                        for (int prevTag = 0; prevTag < this.num_tags_; prevTag++) {
                            // Transition is scored with the bounds of the *later* segment.
                            score = Numerics.sumLogProb(
                                    pairScore
                                            + this.model_.getTransitionScore(instance, prevTag, tag, start, end)
                                            + this.forward_score_array_[getIndex(prevTag, start - 1)],
                                    score);
                        }
                    }
                }
                this.forward_score_array_[getIndex(tag, end - 1)] = score;
            }
        }
    }

    /**
     * Fills backward_score_array_ so that entry {@code (tag, start)} holds the log-sum of all
     * segmentations of the suffix {@code [start, length)} whose first segment begins at
     * {@code start} with label {@code tag}.
     *
     * <p>NOTE(review): here the transition out of {@code [start, end)} is scored with the bounds
     * of the *earlier* segment. The forward pass uses the bounds of the later segment. Both agree
     * only if {@code getTransitionScore} ignores its position arguments; confirm against
     * SegmenterModel.
     */
    private void computeBackwardScores(SegmentationInstance instance) {
        Arrays.fill(this.backward_score_array_, Double.NEGATIVE_INFINITY);
        for (int start = this.input_length_ - 1; start >= 0; start--) {
            for (int tag = 0; tag < this.num_tags_; tag++) {
                double score = Double.NEGATIVE_INFINITY;
                for (int end = Math.min(this.input_length_, start + this.max_segment_length_); end > start; end--) {
                    double pairScore = this.model_.getPairScore(instance, start, end, tag);
                    if (end == this.input_length_) {
                        // Segment closes the input: no outgoing transition.
                        score = Numerics.sumLogProb(pairScore, score);
                    } else {
                        for (int nextTag = 0; nextTag < this.num_tags_; nextTag++) {
                            score = Numerics.sumLogProb(
                                    pairScore
                                            + this.model_.getTransitionScore(instance, tag, nextTag, start, end)
                                            + this.backward_score_array_[getIndex(nextTag, end)],
                                    score);
                        }
                    }
                }
                this.backward_score_array_[getIndex(tag, start)] = score;
            }
        }
    }

    /**
     * Pushes negated lattice marginals to the model, computed as
     * {@code -exp(inside + outside - logPartition)}.
     *
     * <p>For a segment {@code [start, end)} with {@code tag}, the pair update is
     * {@code update(instance, start, end, tag, value)}. When {@code start > 0}, the transition
     * update for each previous tag is {@code update(instance, start, end, prevTag, tag, value)}.
     * In that case the pair update receives the sum of the transition values.
     */
    private void addExpectedCounts(SegmentationInstance instance, double logPartition) {
        for (int end = 1; end <= this.input_length_; end++) {
            for (int tag = 0; tag < this.num_tags_; tag++) {
                for (int start = Math.max(0, end - this.max_segment_length_); start < end; start++) {
                    double pairScore = this.model_.getPairScore(instance, start, end, tag);
                    // Log-score of everything to the right of the segment; log(1) if it ends the input.
                    double outsideScore = 0.0d;
                    if (end < this.input_length_) {
                        outsideScore = Double.NEGATIVE_INFINITY;
                        for (int nextTag = 0; nextTag < this.num_tags_; nextTag++) {
                            outsideScore = Numerics.sumLogProb(
                                    this.backward_score_array_[getIndex(nextTag, end)]
                                            + this.model_.getTransitionScore(instance, tag, nextTag, start, end),
                                    outsideScore);
                        }
                    }
                    if (start == 0) {
                        double value = -Math.exp((outsideScore + pairScore) - logPartition);
                        this.model_.update(instance, start, end, tag, value);
                    } else {
                        double pairValue = 0.0d;
                        for (int prevTag = 0; prevTag < this.num_tags_; prevTag++) {
                            double value = -Math.exp(
                                    (((this.forward_score_array_[getIndex(prevTag, start - 1)] + pairScore)
                                            + this.model_.getTransitionScore(instance, prevTag, tag, start, end))
                                            + outsideScore)
                                            - logPartition);
                            this.model_.update(instance, start, end, prevTag, tag, value);
                            pairValue += value;
                        }
                        this.model_.update(instance, start, end, tag, pairValue);
                    }
                }
            }
        }
    }

    /** Returns the log-sum over all tags of {@code dArr[getIndex(tag, i)]}. */
    private double sumTag(double[] dArr, int i) {
        double total = Double.NEGATIVE_INFINITY;
        for (int tag = 0; tag < this.num_tags_; tag++) {
            total = Numerics.sumLogProb(dArr[getIndex(tag, i)], total);
        }
        return total;
    }

    /** Maps {@code (tag, position)} to a flat buffer index, tag-major with stride input_length_. */
    private int getIndex(int i, int i2) {
        return (i * this.input_length_) + i2;
    }

    /** Ensures both score buffers hold at least {@code i} entries, reallocating them if not. */
    private void checkArraySize(int i) {
        if (this.forward_score_array_ == null || this.forward_score_array_.length < i) {
            this.forward_score_array_ = new double[i];
            this.backward_score_array_ = new double[i];
        }
    }
}
