package cmu.arktweetnlp.impl;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import java.util.LinkedList;

/* JADX INFO: Access modifiers changed from: package-private */
/* compiled from: OWLQN.java */
/* loaded from: input_file:cmu/arktweetnlp/impl/OptimizerState.class */
public class OptimizerState {
    /** Armijo sufficient-decrease constant (c1) used by the backtracking line search. */
    private static final double SUFFICIENT_DECREASE = 1.0E-4d;
    /** Step size below which the line search gives up and keeps the previous objective value. */
    private static final double MIN_STEP = 1.0E-30d;

    /**
     * Current iterate. Initially this is the very array passed to the constructor (it is not
     * copied); after the first {@link #shift()} the buffers are swapped, so that array becomes
     * scratch space for trial points.
     */
    double[] x;
    /** Gradient of the smooth part of the objective (as returned by {@link #func}) at {@link #x}. */
    double[] grad;
    /** Trial point produced by the line search. */
    double[] newX;
    /** Gradient of the smooth part of the objective at {@link #newX}. */
    double[] newGrad;
    /** Current search direction. */
    double[] dir;
    /** Copy of the pseudo-steepest-descent direction, consulted by {@link #fixDirSigns()}. */
    double[] steepestDescDir;
    /** Scratch space for the two-loop recursion, one slot per stored correction pair. */
    double[] alphas;
    /** Most recently accepted objective value: smooth part plus L1 penalty on non-bias coordinates. */
    double value;
    /** Maximum number of stored correction pairs; may shrink if {@link #shift()} runs out of memory. */
    int m;
    /** Dimensionality of the problem. */
    int dim;
    /** The smooth part of the objective. */
    DiffFunction func;
    /** Weight of the L1 penalty; {@code 0} disables the orthant-wise handling. */
    double l1weight;
    /** When {@code true}, suppresses line-search progress output on stderr. */
    boolean quiet;
    /** Stored differences {@code s_k = x_{k+1} - x_k}, oldest first. */
    LinkedList<double[]> sList = new LinkedList<>();
    /** Stored gradient differences {@code y_k = g_{k+1} - g_k}, oldest first. */
    LinkedList<double[]> yList = new LinkedList<>();
    /** Stored inner products {@code s_k . y_k}, parallel to {@link #sList}/{@link #yList}. */
    LinkedList<Double> roList = new LinkedList<>();
    /** 1-based iteration counter, incremented by {@link #shift()}. */
    int iter = 1;

    /** Formats {@code values} as a tab-separated line (debugging helper). */
    private String arrayToString(double[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append('\t');
            }
            sb.append(values[i]);
        }
        return sb.toString();
    }

    /** Dumps the stored L-BFGS history to stderr (debugging helper). */
    private void printStateValues() {
        System.err.println("\nSLIST:");
        for (double[] s : this.sList) {
            System.err.println(arrayToString(s));
        }
        System.err.println("YLIST:");
        for (double[] y : this.yList) {
            System.err.println(arrayToString(y));
        }
        System.err.println("ROLIST:");
        for (Double ro : this.roList) {
            System.err.println(ro);
        }
        System.err.println();
    }

    /**
     * Multiplies {@link #dir} in place by the L-BFGS approximation of the inverse Hessian, using
     * the standard two-loop recursion over the stored (s, y, s.y) pairs. Does nothing when no
     * pairs have been stored yet.
     */
    void mapDirByInverseHessian() {
        final int count = this.sList.size();
        if (count == 0) {
            return;
        }
        // Backward pass: newest pair to oldest.
        for (int i = count - 1; i >= 0; i--) {
            this.alphas[i] = -ArrayMath.innerProduct(this.sList.get(i), this.dir) / this.roList.get(i);
            ArrayMath.addMultInPlace(this.dir, this.yList.get(i), this.alphas[i]);
        }
        // Initial Hessian scaling (s.y) / (y.y) taken from the newest pair.
        final double[] lastY = this.yList.get(count - 1);
        ArrayMath.multiplyInPlace(this.dir, this.roList.get(count - 1) / ArrayMath.innerProduct(lastY, lastY));
        // Forward pass: oldest pair to newest.
        for (int i = 0; i < count; i++) {
            final double beta = ArrayMath.innerProduct(this.yList.get(i), this.dir) / this.roList.get(i);
            ArrayMath.addMultInPlace(this.dir, this.sList.get(i), -this.alphas[i] - beta);
        }
    }

    /**
     * Sets {@link #dir} to the negative pseudo-gradient of the L1-regularized objective at
     * {@link #x}, and keeps a copy in {@link #steepestDescDir}. Coordinates listed in
     * {@code OWLQN.biasParameters} are not penalized and use the plain negative gradient.
     */
    void makeSteepestDescDir() {
        if (this.l1weight == 0.0d) {
            ArrayMath.multiplyInto(this.dir, this.grad, -1.0d);
        } else {
            for (int i = 0; i < this.dim; i++) {
                if (OWLQN.biasParameters.contains(Integer.valueOf(i))) {
                    this.dir[i] = -this.grad[i];
                } else if (this.x[i] < 0.0d) {
                    this.dir[i] = -this.grad[i] + this.l1weight;
                } else if (this.x[i] > 0.0d) {
                    this.dir[i] = -this.grad[i] - this.l1weight;
                } else if (this.grad[i] < -this.l1weight) {
                    // At zero, moving positive decreases the objective.
                    this.dir[i] = -this.grad[i] - this.l1weight;
                } else if (this.grad[i] > this.l1weight) {
                    // At zero, moving negative decreases the objective.
                    this.dir[i] = -this.grad[i] + this.l1weight;
                } else {
                    // The subgradient contains zero: stay put in this coordinate.
                    this.dir[i] = 0.0d;
                }
            }
        }
        this.steepestDescDir = this.dir.clone();
    }

    /**
     * With an L1 penalty, zeroes every non-bias component of {@link #dir} whose sign does not
     * strictly agree with {@link #steepestDescDir} (their product is {@code <= 0}).
     */
    void fixDirSigns() {
        if (this.l1weight > 0.0d) {
            for (int i = 0; i < this.dim; i++) {
                if (!OWLQN.biasParameters.contains(Integer.valueOf(i))
                        && this.dir[i] * this.steepestDescDir[i] <= 0.0d) {
                    this.dir[i] = 0.0d;
                }
            }
        }
    }

    /** Computes a new search direction: pseudo-steepest descent, L-BFGS scaling, then sign fix. */
    public void updateDir() {
        makeSteepestDescDir();
        mapDirByInverseHessian();
        fixDirSigns();
    }

    /**
     * Returns the directional derivative of the L1-regularized objective at {@link #x} along
     * {@link #dir}. For a coordinate at zero, the sign of the penalty term follows the sign of
     * the direction.
     */
    double dirDeriv() {
        if (this.l1weight == 0.0d) {
            return ArrayMath.innerProduct(this.dir, this.grad);
        }
        double result = 0.0d;
        for (int i = 0; i < this.dim; i++) {
            if (OWLQN.biasParameters.contains(Integer.valueOf(i))) {
                result += this.dir[i] * this.grad[i];
            } else if (this.dir[i] != 0.0d) {
                if (this.x[i] < 0.0d) {
                    result += this.dir[i] * (this.grad[i] - this.l1weight);
                } else if (this.x[i] > 0.0d) {
                    result += this.dir[i] * (this.grad[i] + this.l1weight);
                } else if (this.dir[i] < 0.0d) {
                    result += this.dir[i] * (this.grad[i] - this.l1weight);
                } else if (this.dir[i] > 0.0d) {
                    result += this.dir[i] * (this.grad[i] + this.l1weight);
                }
            }
        }
        return result;
    }

    /**
     * Writes the trial point {@code x + alpha * dir} into {@link #newX}. With an L1 penalty, any
     * non-bias coordinate whose sign would flip relative to {@link #x} is clamped to zero
     * (orthant projection).
     *
     * @param alpha step size
     * @return always {@code true}
     */
    private boolean getNextPoint(double alpha) {
        ArrayMath.addMultInto(this.newX, this.x, this.dir, alpha);
        if (this.l1weight <= 0.0d) {
            return true;
        }
        for (int i = 0; i < this.dim; i++) {
            if (!OWLQN.biasParameters.contains(Integer.valueOf(i)) && this.x[i] * this.newX[i] < 0.0d) {
                this.newX[i] = 0.0d;
            }
        }
        return true;
    }

    /**
     * Evaluates the objective at {@link #newX}, storing the smooth part's gradient in
     * {@link #newGrad}.
     *
     * @return the smooth value plus {@code l1weight * |newX[i]|} summed over non-bias coordinates
     */
    double evalL1() {
        double result = this.func.valueAt(this.newX);
        // Cloned, presumably because DiffFunction implementations may reuse the returned array
        // between calls — confirm against the implementations in use.
        this.newGrad = this.func.derivativeAt(this.newX).clone();
        if (this.l1weight > 0.0d) {
            for (int i = 0; i < this.dim; i++) {
                if (!OWLQN.biasParameters.contains(Integer.valueOf(i))) {
                    result += Math.abs(this.newX[i]) * this.l1weight;
                }
            }
        }
        return result;
    }

    /**
     * Backtracking line search along {@link #dir} until the Armijo condition holds, leaving the
     * accepted trial point in {@link #newX}/{@link #newGrad} and its objective in {@link #value}.
     *
     * @throws RuntimeException if {@link #dir} is not a descent direction
     */
    public void backTrackingLineSearch() {
        final double origDirDeriv = dirDeriv();
        // No step along a non-descent direction can satisfy the Armijo test; this almost always
        // indicates a gradient that is inconsistent with the function value.
        if (origDirDeriv >= 0.0d) {
            throw new RuntimeException("L-BFGS chose a non-descent direction: check your gradient!");
        }
        double alpha = 1.0d;
        double backoff = 0.5d;
        if (this.iter == 1) {
            // No curvature information yet: start with a step of Euclidean length 1 along dir
            // (before orthant projection) and back off more aggressively.
            alpha = 1.0d / Math.sqrt(ArrayMath.innerProduct(this.dir, this.dir));
            backoff = 0.1d;
        }
        final double oldValue = this.value;
        while (true) {
            getNextPoint(alpha);
            this.value = evalL1();
            if (this.value <= oldValue + (SUFFICIENT_DECREASE * origDirDeriv * alpha)) {
                break;
            }
            if (alpha < MIN_STEP) {
                System.err.println("Warning: The line search backed off to alpha < 1e-30, and stayed with the current parameter values.  This probably means converged has been reached.");
                // NOTE(review): only `value` is restored; newX/newGrad still hold the last rejected
                // trial point, so a following shift() records that point. Confirm this is what the
                // OWLQN driver loop expects.
                this.value = oldValue;
                break;
            }
            if (!this.quiet) {
                System.err.print(".");
            }
            alpha *= backoff;
        }
        if (!this.quiet) {
            System.err.println();
        }
    }

    /**
     * Accepts the trial point. Records the correction pair {@code s = newX - x},
     * {@code y = newGrad - grad} (dropping the oldest pair once {@link #m} are stored), then
     * swaps the x/grad buffers with newX/newGrad and advances {@link #iter}.
     *
     * @throws OutOfMemoryError if new history buffers cannot be allocated and there is no stored
     *     pair to recycle
     */
    public void shift() {
        double[] s = null;
        double[] y = null;
        final int size = this.sList.size();
        if (size < this.m) {
            try {
                s = new double[this.dim];
                y = new double[this.dim];
            } catch (OutOfMemoryError e) {
                if (size == 0) {
                    // Nothing to recycle: previously this fell through to an NPE on a null poll().
                    throw e;
                }
                // Cap the history at what already fits and recycle the oldest pair below.
                this.m = size;
                s = null;
            }
        }
        if (s == null) {
            s = this.sList.poll();
            y = this.yList.poll();
            this.roList.poll();
        }
        ArrayMath.addMultInto(s, this.newX, this.x, -1.0d);
        ArrayMath.addMultInto(y, this.newGrad, this.grad, -1.0d);
        final double ro = ArrayMath.innerProduct(s, y);
        assert ro != 0.0d;
        this.sList.offer(s);
        this.yList.offer(y);
        this.roList.offer(ro);
        // Swap buffers so the accepted point becomes current and the old arrays are reused.
        double[] tmp = this.newX;
        this.newX = this.x;
        this.x = tmp;
        tmp = this.newGrad;
        this.newGrad = this.grad;
        this.grad = tmp;
        this.iter++;
    }

    /** Returns the most recently accepted objective value. */
    public double getValue() {
        return this.value;
    }

    /**
     * Creates the optimizer state and evaluates the objective at the initial point.
     *
     * @param func smooth part of the objective
     * @param init initial point; used directly (not copied) as {@link #x}
     * @param m number of correction pairs to keep; must be positive
     * @param l1weight weight of the L1 penalty on non-bias coordinates
     * @param quiet suppress line-search progress output
     * @throws RuntimeException if {@code m <= 0}
     */
    public OptimizerState(DiffFunction func, double[] init, int m, double l1weight, boolean quiet) {
        // Validate before allocating: a negative m would otherwise surface as
        // NegativeArraySizeException from new double[m] instead of this message.
        if (m <= 0) {
            throw new RuntimeException("m must be an integer greater than zero.");
        }
        this.x = init;
        this.grad = new double[init.length];
        this.newX = init.clone();
        this.newGrad = new double[init.length];
        this.dir = new double[init.length];
        this.steepestDescDir = new double[init.length];
        this.alphas = new double[m];
        this.m = m;
        this.dim = init.length;
        this.func = func;
        this.l1weight = l1weight;
        this.quiet = quiet;
        this.value = evalL1();
        this.grad = this.newGrad.clone();
    }
}
