package cmu.arktweetnlp;

import cmu.arktweetnlp.impl.Model;
import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.OWLQN;
import cmu.arktweetnlp.impl.Sentence;
import cmu.arktweetnlp.impl.features.FeatureExtractor;
import cmu.arktweetnlp.io.CoNLLReader;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;

/* loaded from: input_file:cmu/arktweetnlp/Train.class */
/**
 * Command-line trainer for {@link Model}.
 *
 * <p>Steps:
 * <ol>
 *   <li>Reads labeled sentences with {@link CoNLLReader}.</li>
 *   <li>Builds the label vocabulary.</li>
 *   <li>Extracts features with {@link FeatureExtractor}.</li>
 *   <li>Optionally warm-starts from an existing model file.</li>
 *   <li>Fits the weights with {@link OWLQN}.</li>
 *   <li>Saves the resulting model as text.</li>
 * </ol>
 *
 * <p>With {@code --dump-feat} it prints extracted features instead of training.
 */
public class Train {
    /** Strength of the L2 penalty {@code 0.5 * l2penalty * ||w||^2} added by {@link GradientCalculator}. */
    public double l2penalty = 2.0d;
    /** L1 weight passed to {@link OWLQN#minimize}; the L1 term is not part of {@link GradientCalculator}. */
    public double l1penalty = 0.25d;
    /** Convergence tolerance passed to {@link OWLQN#minimize}. */
    public double tol = 1.0E-7d;
    /** Maximum number of optimizer iterations. */
    public int maxIter = 500;
    /** Optional model file whose coefficients initialize training (warm start); {@code null} disables it. */
    public String modelLoadFilename = null;
    /** Training examples file, read with {@link CoNLLReader#readFile}. */
    public String examplesFilename = null;
    /** Destination of the trained model (written with {@code Model.saveModelAsText}). */
    public String modelSaveFilename = null;
    /** When true, {@link #main} dumps features instead of training. */
    public boolean dumpFeatures = false;
    /** Total token count of {@link #lSentences}; used to normalize the log-likelihood printout. */
    private int numTokens = 0;
    private ArrayList<Sentence> lSentences = new ArrayList<>();
    private ArrayList<ModelSentence> mSentences = new ArrayList<>();
    private Model model = new Model();

    /**
     * Objective handed to the optimizer.
     *
     * <p>The objective is the negated total log-likelihood of the training sentences plus
     * {@code 0.5 * l2penalty * ||w||^2}. The L1 term is handled by OWL-QN itself (see
     * {@link #optimizationLoop()}).
     */
    public class GradientCalculator implements DiffFunction {
        private GradientCalculator() {
        }

        @Override
        public int domainDimension() {
            return model.flatIDsize();
        }

        @Override
        public double valueAt(double[] weights) {
            model.setCoefsFromFlat(weights);
            return -totalLogLik() + regularizerValue(weights);
        }

        @Override
        public double[] derivativeAt(double[] weights) {
            double[] gradient = new double[model.flatIDsize()];
            model.setCoefsFromFlat(weights);
            // The same array is passed for every sentence, so computeGradient presumably
            // accumulates into it (log-likelihood gradient) -- confirm in Model.
            for (ModelSentence sentence : mSentences) {
                model.computeGradient(sentence, gradient);
            }
            // We minimize the negative log-likelihood, so flip the sign before adding the penalty.
            ArrayMath.multiplyInPlace(gradient, -1.0d);
            addL2regularizerGradient(gradient, weights);
            return gradient;
        }
    }

    /** Progress hook for OWL-QN: prints the average per-token log-likelihood at the current weights. */
    public class MyWeightsPrinter implements OWLQN.WeightsPrinter {
        public MyWeightsPrinter() {
        }

        @Override // cmu.arktweetnlp.impl.OWLQN.WeightsPrinter
        public void printWeights() {
            System.out.printf("\tTokLL %.6f\t", totalLogLik() / numTokens);
        }
    }

    Train() {
    }

    /**
     * Reads {@link #examplesFilename}, extracts features and prints them in dump mode.
     *
     * @throws IOException if the examples file cannot be read
     */
    public void doFeatureDumping() throws IOException {
        prepareTrainingData();
        dumpFeatures();
    }

    /**
     * Runs the full training pipeline and writes the model to {@link #modelSaveFilename}.
     *
     * <p>Feature extraction is followed by locking the model, an optional warm start from
     * {@link #modelLoadFilename}, and optimization.
     *
     * @throws IOException if reading examples/warm-start model or writing the output fails
     */
    public void doTraining() throws IOException {
        prepareTrainingData();
        model.lockdownAfterFeatureExtraction();
        if (modelLoadFilename != null) {
            readWarmStartModel();
        }
        optimizationLoop();
        model.saveModelAsText(modelSaveFilename);
    }

    /** Shared prefix of training and feature dumping: read sentences, build label vocab, extract features. */
    private void prepareTrainingData() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
    }

    /**
     * Loads labeled sentences from {@code filename} and recounts the tokens.
     *
     * <p>The token count is reset so it always matches the sentence list that was just
     * loaded, even if this method is called more than once.
     *
     * @param filename path of the examples file
     * @throws IOException if the file cannot be read
     */
    public void readTrainingSentences(String filename) throws IOException {
        lSentences = CoNLLReader.readFile(filename);
        numTokens = 0;
        for (Sentence sentence : lSentences) {
            numTokens += sentence.T();
        }
    }

    /** Registers every label seen in the training sentences, then locks the label vocabulary. */
    public void constructLabelVocab() {
        for (Sentence sentence : lSentences) {
            for (String label : sentence.labels) {
                model.labelVocab.num(label);
            }
        }
        model.labelVocab.lock();
        model.numLabels = model.labelVocab.size();
    }

    /**
     * Re-runs feature extraction over the training sentences with the extractor in dump mode.
     *
     * <p>The per-sentence results are discarded.
     */
    public void dumpFeatures() throws IOException {
        FeatureExtractor extractor = new FeatureExtractor(model, true);
        extractor.dumpMode = true;
        for (Sentence sentence : lSentences) {
            extractor.computeFeatures(sentence, new ModelSentence(sentence.T()));
        }
    }

    /** Computes a {@link ModelSentence} for every training sentence and appends it to {@link #mSentences}. */
    public void extractFeatures() throws IOException {
        System.out.println("Extracting features");
        FeatureExtractor extractor = new FeatureExtractor(model, true);
        for (Sentence sentence : lSentences) {
            ModelSentence modelSentence = new ModelSentence(sentence.T());
            extractor.computeFeatures(sentence, modelSentence);
            mSentences.add(modelSentence);
        }
    }

    /**
     * Copies coefficients for features shared with the model stored in {@link #modelLoadFilename}.
     *
     * <p>Must be called after the feature vocabulary is locked; this is checked with
     * {@code assert} (only active under {@code -ea}).
     *
     * @throws IOException if the warm-start model cannot be read
     */
    public void readWarmStartModel() throws IOException {
        assert model.featureVocab.isLocked() : "feature vocabulary must be locked before warm start";
        Model.copyCoefsForIntersectingFeatures(Model.loadModelFromText(modelLoadFilename), model);
    }

    /**
     * Minimizes the objective with OWL-QN and stores the resulting weights in the model.
     *
     * <p>Starting weights are the model's current coefficients, which may have been
     * warm-started.
     */
    public void optimizationLoop() {
        OWLQN owlqn = new OWLQN();
        owlqn.setMaxIters(maxIter);
        owlqn.setQuiet(false);
        owlqn.setWeightsPrinting(new MyWeightsPrinter());
        // NOTE(review): the trailing 5 is passed through unchanged; presumably the L-BFGS history size -- confirm.
        double[] optimum = owlqn.minimize(new GradientCalculator(), model.convertCoefsToFlat(), l1penalty, tol, 5);
        model.setCoefsFromFlat(optimum);
    }

    /** Sum of the model's log-likelihood over all training sentences at the current coefficients. */
    private double totalLogLik() {
        double total = 0.0d;
        for (ModelSentence sentence : mSentences) {
            total += model.computeLogLik(sentence);
        }
        return total;
    }

    /**
     * Adds the L2 penalty gradient to {@code gradient} in place.
     *
     * <p>For each index {@code i}, {@code gradient[i]} is increased by
     * {@code l2penalty * weights[i]}.
     *
     * @param gradient accumulated gradient, modified in place
     * @param weights  current weights; must have the same length as {@code gradient}
     */
    public void addL2regularizerGradient(double[] gradient, double[] weights) {
        assert gradient.length == weights.length : "gradient/weights length mismatch";
        for (int i = 0; i < weights.length; i++) {
            gradient[i] += l2penalty * weights[i];
        }
    }

    /**
     * Returns the L2 penalty {@code 0.5 * l2penalty * sum(w_i^2)}.
     *
     * @param weights current weights
     * @return the penalty value
     */
    public double regularizerValue(double[] weights) {
        double sumOfSquares = 0.0d;
        for (double w : weights) {
            sumOfSquares += w * w;
        }
        return 0.5d * l2penalty * sumOfSquares;
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code Train [options] <ExamplesFilename> <ModelOutputFilename>}.
     *
     * <p>With {@code --dump-feat} only the examples filename is required. Any malformed
     * command line prints usage and exits with status 1.
     */
    public static void main(String[] strArr) throws IOException {
        Train train = new Train();
        if (strArr.length < 2 || strArr[0].equals("-h") || strArr[0].equals("--help")) {
            usage();
        }
        int i = 0;
        while (i < strArr.length && strArr[i].startsWith("-")) {
            String option = strArr[i];
            if (option.equals("--dump-feat")) {
                train.dumpFeatures = true;
                i++;
                continue;
            }
            // Every other option consumes the following argument as its value.
            if (i + 1 >= strArr.length) {
                usage();
            }
            String value = strArr[i + 1];
            if (option.equals("--warm-start")) {
                train.modelLoadFilename = value;
            } else if (option.equals("--max-iter")) {
                train.maxIter = Integer.parseInt(value);
            } else if (option.equals("--l2")) {
                train.l2penalty = Double.parseDouble(value);
            } else if (option.equals("--l1")) {
                train.l1penalty = Double.parseDouble(value);
            } else {
                usage();
            }
            i += 2;
        }
        if (train.dumpFeatures) {
            if (i >= strArr.length) {
                usage();
            }
            train.examplesFilename = strArr[i];
            train.doFeatureDumping();
            System.exit(0);
        }
        if (strArr.length - i < 2) {
            usage();
        }
        train.examplesFilename = strArr[i];
        train.modelSaveFilename = strArr[i + 1];
        train.doTraining();
    }

    /** Prints command-line help and terminates the JVM with exit status 1. */
    public static void usage() {
        System.out.println("Train [options] <ExamplesFilename> <ModelOutputFilename>\n"
                + "Options:\n"
                + "  --max-iter <n>\n"
                + "  --l2 <penalty>              L2 regularization strength (default 2.0).\n"
                + "  --l1 <penalty>              L1 regularization strength (default 0.25).\n"
                + "  --warm-start <modelfile>    Initializes at weights of this model.  discards base features that aren't in training set.\n"
                + "  --dump-feat                 Show extracted features, instead of training. Useful for debugging/analyzing feature extractors.\n");
        System.exit(1);
    }
}
