package cmu.arktweetnlp.impl;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import gnu.trove.set.hash.THashSet;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Set;

/* loaded from: input_file:cmu/arktweetnlp/impl/OWLQN.class */
/**
 * Orthant-wise limited-memory quasi-Newton (OWL-QN) minimizer for L1-regularized objectives.
 *
 * <p>The per-iteration work (direction update, backtracking line search, shifting the state) is
 * delegated to {@code OptimizerState}; this class drives the outer loop and decides when to stop.
 * It also hosts static helpers that project a weight vector onto the probability simplex.
 *
 * <p>NOTE(review): the constraint configuration ({@link #setConstrained(boolean)},
 * {@link #setConstrained(int)}, {@link #biasParameters}) lives in mutable static fields and is
 * therefore shared by every instance in the JVM.
 */
public class OWLQN {
    /** Upper bound on outer iterations; unbounded ({@link Integer#MAX_VALUE}) by default. */
    private int maxIters;
    /** When true, progress output to {@code System.err} is suppressed (warnings still print). */
    boolean quiet;
    /**
     * True when this instance created its own termination criterion. Not read in this class;
     * presumably kept for ownership bookkeeping — confirm against other users.
     */
    boolean responsibleForTermCrit;
    /** Decides convergence; its value is compared against the tolerance given to minimize. */
    TerminationCriterion termCrit;
    /** Optional hook invoked after each non-quiet iteration to dump the current weights. */
    WeightsPrinter printer;
    /** Whether simplex projection is enabled; always equal to {@code numUnconstrainedWeights >= 0}. */
    private static boolean constrained = false;
    /**
     * Indices of bias parameters. Not read in this file; presumably consulted elsewhere
     * (e.g. by {@code OptimizerState}) — confirm.
     */
    public static Set<Integer> biasParameters = new THashSet<Integer>();
    /**
     * Number of leading weights exempt from projection; {@code -1} means projection is disabled.
     */
    private static int numUnconstrainedWeights = -1;

    /**
     * Stops when the objective's mean per-iteration improvement, relative to the current
     * objective value, becomes small.
     *
     * <p>The criterion keeps the last {@code numItersToAvg} objective values. Once that many have
     * been seen it reports {@code ((oldest - current) / span) / |current|}. Here {@code span} is the
     * number of iterations between the oldest stored value and the current one.
     */
    static class RelativeMeanImprovementCriterion implements TerminationCriterion {
        /** Window length, in iterations, over which the improvement is averaged (at least 1). */
        int numItersToAvg;
        /** Recent objective values, oldest first. */
        Queue<Double> prevVals;

        /** Creates a criterion averaging over a 10-iteration window. */
        RelativeMeanImprovementCriterion() {
            this(10);
        }

        /**
         * Creates a criterion averaging over {@code i} iterations.
         *
         * @param i window length; must be at least 1
         * @throws IllegalArgumentException if {@code i < 1}
         */
        RelativeMeanImprovementCriterion(int i) {
            if (i < 1) {
                // With i <= 0, getValue would peek an empty queue (NPE) on the very first call.
                throw new IllegalArgumentException("numItersToAvg must be >= 1, got " + i);
            }
            this.numItersToAvg = i;
            this.prevVals = new LinkedList<Double>();
        }

        /**
         * Records the current objective value and returns the relative mean improvement.
         *
         * @param optimizerState source of the current objective value
         * @param sb receives a short human-readable fragment describing the result
         * @return the relative mean improvement, or {@link Double#POSITIVE_INFINITY} while fewer
         *     than {@code numItersToAvg} previous values have been recorded
         */
        @Override
        public double getValue(OptimizerState optimizerState, StringBuilder sb) {
            double current = optimizerState.getValue();
            double result = Double.POSITIVE_INFINITY;
            if (this.prevVals.size() >= this.numItersToAvg) {
                // Measure the span BEFORE dropping the oldest entry. The oldest value is exactly
                // this many iterations behind the current one. Dividing by the post-poll size
                // instead would overstate the average by n/(n-1), and divide by zero when n == 1.
                int span = this.prevVals.size();
                double oldest = this.prevVals.peek();
                if (span == this.numItersToAvg) {
                    this.prevVals.poll();
                }
                result = ((oldest - current) / span) / Math.abs(current);
                sb.append("  (").append(String.format("%.4e", result)).append(") ");
            } else {
                sb.append("  (wait for " + this.numItersToAvg + " iters) ");
            }
            this.prevVals.offer(current);
            return result;
        }
    }

    /** Convergence test evaluated once per iteration. */
    interface TerminationCriterion {
        /**
         * Evaluates the criterion for the current state.
         *
         * @param optimizerState current optimizer state
         * @param sb sink for a human-readable progress fragment
         * @return a value that {@link OWLQN#minimize} compares against its tolerance
         */
        double getValue(OptimizerState optimizerState, StringBuilder sb);
    }

    /** Callback used to print the current weights after each (non-quiet) iteration. */
    public interface WeightsPrinter {
        void printWeights();
    }

    /**
     * Creates a minimizer using the default {@link RelativeMeanImprovementCriterion}.
     *
     * @param z whether to suppress progress output
     */
    public OWLQN(boolean z) {
        this.maxIters = Integer.MAX_VALUE;
        this.quiet = z;
        this.termCrit = new RelativeMeanImprovementCriterion();
        this.responsibleForTermCrit = true;
    }

    /** Creates a non-quiet minimizer with the default termination criterion. */
    public OWLQN() {
        this(false);
    }

    /**
     * Creates a minimizer with a caller-supplied termination criterion.
     *
     * @param terminationCriterion criterion deciding convergence
     * @param z whether to suppress progress output
     */
    public OWLQN(TerminationCriterion terminationCriterion, boolean z) {
        this.maxIters = Integer.MAX_VALUE;
        this.quiet = z;
        this.termCrit = terminationCriterion;
        this.responsibleForTermCrit = false;
    }

    /** Enables or disables progress output. */
    public void setQuiet(boolean z) {
        this.quiet = z;
    }

    /**
     * Minimizes {@code function} starting from {@code initial} using OWL-QN.
     *
     * <p>The iteration stops when any of the following holds:
     * <ul>
     *   <li>{@link #getMaxIters()} iterations have run;</li>
     *   <li>a line search leaves the parameters bit-for-bit unchanged;</li>
     *   <li>the termination criterion drops below {@code tol}.</li>
     * </ul>
     *
     * @param function objective to minimize
     * @param initial starting point (handed to {@code OptimizerState} as-is)
     * @param l1weight L1 regularization weight
     * @param tol convergence tolerance for the termination criterion
     * @param m L-BFGS memory parameter
     * @return the last point produced by the line search ({@code newX} of the optimizer state)
     */
    public double[] minimize(DiffFunction function, double[] initial, double l1weight, double tol, int m) {
        OptimizerState state = new OptimizerState(function, initial, m, l1weight, this.quiet);
        if (!this.quiet) {
            System.err.printf("Optimizing function of %d variables with OWL-QN parameters:\n", state.dim);
            System.err.printf("   l1 regularization weight: %f.\n", l1weight);
            System.err.printf("   L-BFGS memory parameter (m): %d\n", m);
            System.err.printf("   Convergence tolerance: %f\n\n", tol);
            System.err.print("Iter    n:\tnew_value\tdf\t(conv_crit)\tline_search\n");
            System.err.printf("Iter    0:\t%.4e\t\t(***********)\t", state.value);
        }
        StringBuilder message = new StringBuilder();
        // Prime the criterion with the starting objective value; its result is not used.
        this.termCrit.getValue(state, message);
        for (int step = 0; step < this.maxIters; step++) {
            message.setLength(0);
            state.updateDir();
            state.backTrackingLineSearch();
            double convCrit = this.termCrit.getValue(state, message);
            if (!this.quiet) {
                System.err.printf("Iter %4d:\t%.4e\t%d", state.iter, state.value, ArrayMath.countNonZero(state.newX));
                System.err.print("\t" + message.toString());
                if (this.printer != null) {
                    this.printer.printWeights();
                }
            }
            if (arrayEquals(state.x, state.newX)) {
                System.err.println("Warning: Stopping OWL-QN since there was no change in the parameters in the last iteration.  This probably means convergence has been reached.");
                break;
            }
            if (convCrit < tol) {
                break;
            }
            // Only advance the state when continuing; on break we return newX directly.
            state.shift();
        }
        if (!this.quiet) {
            System.err.println();
            System.err.printf("Finished with optimization.  %d/%d non-zero weights.\n", ArrayMath.countNonZero(state.newX), state.newX.length);
        }
        return state.newX;
    }

    /**
     * Element-wise equality using primitive {@code ==}.
     *
     * <p>Unlike {@code Arrays.equals(double[], double[])}, NaN never equals NaN and
     * {@code 0.0 == -0.0}.
     */
    private boolean arrayEquals(double[] dArr, double[] dArr2) {
        if (dArr.length != dArr2.length) {
            return false;
        }
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] != dArr2[i]) {
                return false;
            }
        }
        return true;
    }

    /** Sets the maximum number of outer iterations. */
    public void setMaxIters(int i) {
        this.maxIters = i;
    }

    /** Returns the maximum number of outer iterations. */
    public int getMaxIters() {
        return this.maxIters;
    }

    /** Installs a callback invoked after each non-quiet iteration; {@code null} disables it. */
    public void setWeightsPrinting(WeightsPrinter weightsPrinter) {
        this.printer = weightsPrinter;
    }

    /**
     * Enables projection of all weights ({@code true}) or disables projection ({@code false}).
     * The setting is global (static).
     */
    public static void setConstrained(boolean z) {
        constrained = z;
        numUnconstrainedWeights = z ? 0 : -1;
    }

    /**
     * Leaves the first {@code i} weights free and projects the rest; a negative {@code i}
     * disables projection. The setting is global (static).
     */
    public static void setConstrained(int i) {
        numUnconstrainedWeights = i;
        constrained = i >= 0;
    }

    /** Returns whether simplex projection is currently enabled. */
    public static boolean isConstrained() {
        return constrained;
    }

    /**
     * Projects the constrained suffix of {@code dArr} onto the probability simplex, according to
     * the global constraint setting.
     *
     * <ul>
     *   <li>Projection disabled ({@code numUnconstrainedWeights < 0}): returns {@code dArr}
     *       unchanged.</li>
     *   <li>No free weights: projects {@code dArr} in place and returns it.</li>
     *   <li>Otherwise: returns a new array. The leading free weights are copied, and the
     *       remaining weights are replaced by their projection. A free count larger than the
     *       array is treated as "all weights free".</li>
     * </ul>
     *
     * @param dArr weight vector
     * @return the projected weights
     */
    protected static double[] projectWeights(double[] dArr) {
        int free = numUnconstrainedWeights;
        if (free < 0) {
            // Nothing is constrained, so projection is the identity (avoids negative indexing).
            return dArr;
        }
        if (free == 0) {
            return project(dArr);
        }
        free = Math.min(free, dArr.length);
        int tailLength = dArr.length - free;
        double[] tail = new double[tailLength];
        System.arraycopy(dArr, free, tail, 0, tailLength);
        double[] projectedTail = project(tail);
        double[] result = dArr.clone();
        System.arraycopy(projectedTail, 0, result, free, tailLength);
        return result;
    }

    /**
     * Projects {@code dArr} IN PLACE onto the probability simplex and returns it.
     *
     * <p>The method repeatedly does two steps until no entry is clamped in a pass:
     * <ol>
     *   <li>shift the not-yet-clamped entries uniformly so they sum to 1;</li>
     *   <li>clamp any that went negative to 0.</li>
     * </ol>
     * This is Michelot's iteration. On return all entries are non-negative and sum to 1, up to
     * rounding. Every pass either terminates or clamps at least one new entry, so at most
     * {@code dArr.length} passes run.
     *
     * @param dArr vector to project; overwritten
     * @return {@code dArr}
     */
    public static double[] project(double[] dArr) {
        final int n = dArr.length;
        boolean[] clamped = new boolean[n];
        int numClamped = 0;
        boolean feasible;
        do {
            // Clamped entries are held at 0, so summing everything sums the active entries.
            double sum = 0.0d;
            for (double w : dArr) {
                sum += w;
            }
            double shift = (sum - 1.0d) / (n - numClamped);
            feasible = true;
            for (int i = 0; i < n; i++) {
                if (clamped[i]) {
                    dArr[i] = 0.0d;
                    continue;
                }
                dArr[i] -= shift;
                if (dArr[i] < 0.0d) {
                    feasible = false;
                    clamped[i] = true;
                    numClamped++;
                    dArr[i] = 0.0d;
                }
            }
        } while (!feasible);
        return dArr;
    }
}
