package lemming.lemma.toutanova;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;

import lemming.lemma.LemmaCandidateGenerator;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.Lemmatizer;
import lemming.lemma.LemmatizerGenerator;
import lemming.lemma.LemmatizerGeneratorTrainer;
import marmot.util.DynamicWeights;

/* loaded from: input_file:lemming/lemma/toutanova/ToutanovaTrainer.class */
public class ToutanovaTrainer implements LemmatizerGeneratorTrainer {
    private ToutanovaOptions options_ = new ToutanovaOptions();
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:lemming/lemma/toutanova/ToutanovaTrainer$ToutanovaOptions.class */
    public static class ToutanovaOptions extends LemmaOptions {
        private static final long serialVersionUID = 1;
        public static final String FILTER_ALPHABET = "filter-alphabet";
        public static final String ALIGNER_TRAINER = "aligner-trainer";
        public static final String DECODER = "decoder";
        public static final String MAX_COUNT = "max-count";
        public static final String NBEST_RANK = "nbest-rank";
        public static final String WINDOW_SIZE = "window-size";

        public ToutanovaOptions() {
            this.map_.put(FILTER_ALPHABET, 5);
            this.map_.put(ALIGNER_TRAINER, EditTreeAlignerTrainer.class);
            this.map_.put(DECODER, ZeroOrderDecoder.class);
            this.map_.put(MAX_COUNT, 1);
            this.map_.put(NBEST_RANK, 50);
            this.map_.put(WINDOW_SIZE, 2);
        }

        public static ToutanovaOptions newInstance() {
            return new ToutanovaOptions();
        }

        public int getFilterAlphabet() {
            return ((Integer) getOption(FILTER_ALPHABET)).intValue();
        }

        public AlignerTrainer getAligner() {
            return (AlignerTrainer) getInstance(ALIGNER_TRAINER);
        }

        public Decoder getDecoderInstance() {
            return (Decoder) getInstance(DECODER);
        }

        public int getMaxCount() {
            return ((Integer) getOption(MAX_COUNT)).intValue();
        }

        public int getNbestRank() {
            return ((Integer) getOption(NBEST_RANK)).intValue();
        }

        public int getMaxWindowSize() {
            return ((Integer) getOption(WINDOW_SIZE)).intValue();
        }
    }

    public static List<ToutanovaInstance> createToutanovaInstances(List<LemmaInstance> list, Aligner aligner) {
        LinkedList linkedList = new LinkedList();
        for (LemmaInstance lemmaInstance : list) {
            List<Integer> list2 = null;
            if (aligner != null) {
                list2 = aligner.align(lemmaInstance.getForm(), lemmaInstance.getLemma());
                if (!$assertionsDisabled && list2 == null) {
                    throw new AssertionError();
                }
            }
            linkedList.add(new ToutanovaInstance(lemmaInstance, list2));
        }
        return linkedList;
    }

    @Override // lemming.lemma.LemmatizerGeneratorTrainer, lemming.lemma.LemmatizerTrainer, lemming.lemma.LemmaCandidateGeneratorTrainer
    public LemmatizerGenerator train(List<LemmaInstance> list, List<LemmaInstance> list2) {
        List<ToutanovaInstance> createToutanovaInstances = createToutanovaInstances(list, this.options_.getAligner().train(list));
        List<ToutanovaInstance> list3 = null;
        if (list2 != null) {
            list3 = createToutanovaInstances(list2, null);
        }
        return trainAligned(createToutanovaInstances, list3);
    }

    public LemmatizerGenerator trainAligned(List<ToutanovaInstance> list, List<ToutanovaInstance> list2) {
        Logger logger = Logger.getLogger(getClass().getName());
        ToutanovaModel toutanovaModel = new ToutanovaModel();
        toutanovaModel.init(this.options_, list, list2);
        DynamicWeights weights = toutanovaModel.getWeights();
        DynamicWeights dynamicWeights = this.options_.getAveraging() ? new DynamicWeights(null) : null;
        Decoder decoderInstance = this.options_.getDecoderInstance();
        decoderInstance.init(toutanovaModel);
        LinkedList<ToutanovaInstance> linkedList = new LinkedList();
        for (ToutanovaInstance toutanovaInstance : list) {
            if (!toutanovaInstance.isRare()) {
                for (int i = 0; i < Math.min(this.options_.getMaxCount(), toutanovaInstance.getInstance().getCount()); i++) {
                    linkedList.add(toutanovaInstance);
                }
            }
        }
        for (int i2 = 0; i2 < this.options_.getNumIterations(); i2++) {
            logger.info(String.format("Iter: %3d / %3d", Integer.valueOf(i2 + 1), Integer.valueOf(this.options_.getNumIterations())));
            double d = 0.0d;
            double d2 = 0.0d;
            int i3 = 0;
            Collections.shuffle(linkedList, this.options_.getRandom());
            for (ToutanovaInstance toutanovaInstance2 : linkedList) {
                Result decode = decoderInstance.decode(toutanovaInstance2);
                if (decode.getOutput().equals(toutanovaInstance2.getInstance().getLemma())) {
                    d += 1.0d;
                } else {
                    toutanovaModel.update(toutanovaInstance2, decode, -1.0d);
                    toutanovaModel.update(toutanovaInstance2, toutanovaInstance2.getResult(), 1.0d);
                    if (dynamicWeights != null) {
                        double size = linkedList.size() - i3;
                        if (!$assertionsDisabled && size <= 0.0d) {
                            throw new AssertionError();
                        }
                        toutanovaModel.setWeights(dynamicWeights);
                        dynamicWeights = toutanovaModel.getWeights();
                        toutanovaModel.update(toutanovaInstance2, decode, -size);
                        toutanovaModel.update(toutanovaInstance2, toutanovaInstance2.getResult(), size);
                        toutanovaModel.setWeights(weights);
                        weights = toutanovaModel.getWeights();
                    }
                }
                d2 += 1.0d;
                i3++;
                if (i3 % 1000 == 0 && this.options_.getVerbosity() > 0) {
                    logger.info(String.format("Processed: %3d / %3d", Integer.valueOf(i3), Integer.valueOf(linkedList.size())));
                }
            }
            if (dynamicWeights != null) {
                double size2 = 1.0d / ((i2 + 1.0d) * linkedList.size());
                double d3 = (i2 + 2.0d) / (i2 + 1.0d);
                for (int i4 = 0; i4 < weights.getLength(); i4++) {
                    weights.set(i4, dynamicWeights.get(i4) * size2);
                    dynamicWeights.set(i4, dynamicWeights.get(i4) * d3);
                }
            }
            logger.info(String.format("Train Accuracy: %g / %g = %g", Double.valueOf(d), Double.valueOf(d2), Double.valueOf((d * 100.0d) / d2)));
        }
        return new ToutanovaLemmatizer(this.options_, toutanovaModel);
    }

    @Override // lemming.lemma.LemmatizerTrainer, lemming.lemma.LemmaCandidateGeneratorTrainer
    public LemmaOptions getOptions() {
        return this.options_;
    }

    @Override // lemming.lemma.LemmatizerGeneratorTrainer, lemming.lemma.LemmatizerTrainer, lemming.lemma.LemmaCandidateGeneratorTrainer
    public /* bridge */ /* synthetic */ Lemmatizer train(List list, List list2) {
        return train((List<LemmaInstance>) list, (List<LemmaInstance>) list2);
    }

    @Override // lemming.lemma.LemmatizerGeneratorTrainer, lemming.lemma.LemmaCandidateGeneratorTrainer
    public /* bridge */ /* synthetic */ LemmaCandidateGenerator train(List list, List list2) {
        return train((List<LemmaInstance>) list, (List<LemmaInstance>) list2);
    }

    static {
        $assertionsDisabled = !ToutanovaTrainer.class.desiredAssertionStatus();
    }
}
