package experimental.ising;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.javatuples.Pair;

/* loaded from: input_file:experimental/ising/IsingFactorGraph.class */
/**
 * Factor graph over binary variables (labels 0 and 1), one variable per tag.
 *
 * <p>Each variable gets its own {@link UnaryFactor}; each edge passed to the constructor adds a
 * {@link BinaryFactor} between two variables. The class offers message-passing inference,
 * exact marginals by brute-force enumeration, a Bethe-style approximation of the partition
 * function, posterior decoding, and log-likelihood gradients.
 *
 * <p>Parameter-vector layout used by {@link #updatePotentials2(double[])} and
 * {@link #unfeaturizedGradient()}: two entries per unary factor (labels 0, 1), followed by four
 * entries per binary factor in the order (0,0), (0,1), (1,0), (1,1).
 */
public class IsingFactorGraph {
    /** Word this graph was built for; handed to every unary factor. */
    private String word;
    /** Number of variables (and unary factors) in the graph. */
    private int numVariables;
    protected List<Variable> variables;
    protected List<UnaryFactor> unaryFactors;
    protected List<BinaryFactor> binaryFactors;
    /** Gold label per variable, indexed by variable index; compared against 0 and 1 below. */
    protected List<Integer> golden;
    /** Length of the unfeaturized parameter vector: 2 per unary factor + 4 per binary factor. */
    protected int numParameters;
    protected UnaryFeatureExtractor ufe;
    /** Iteration count used by {@link #posteriorDecode()} and {@link #featurizedGradient}. */
    private int inferenceIterations;

    /**
     * Builds the graph and wires up message slots between variables and factors.
     *
     * <p>For every neighbor link, each side appends to its message-id list the index of the
     * matching message slot in the other side's message list.
     *
     * @param word                word passed to each unary factor
     * @param featureExtractor    feature extractor passed to each unary factor
     * @param inferenceIterations iterations used by decoding and the featurized gradient
     * @param numVariables        number of variables; {@code tagNames} must have at least this many
     * @param edges               pairs of variable indices, one binary factor per pair
     * @param golden              gold label per variable
     * @param tagNames            tag name per variable, indexed by variable index
     */
    public IsingFactorGraph(String word, UnaryFeatureExtractor featureExtractor, int inferenceIterations,
            int numVariables, List<Pair<Integer, Integer>> edges, List<Integer> golden,
            List<String> tagNames) {
        this.numVariables = numVariables;
        this.variables = new ArrayList<>();
        this.unaryFactors = new ArrayList<>();
        this.binaryFactors = new ArrayList<>();
        this.word = word;
        this.ufe = featureExtractor;
        this.inferenceIterations = inferenceIterations;
        this.golden = golden;

        for (int v = 0; v < this.numVariables; v++) {
            String tag = tagNames.get(v);
            Variable variable = new Variable(2, v, tag);
            UnaryFactor unary = new UnaryFactor(word, tag, 2, v, featureExtractor);

            variable.getNeighbors().add(unary);
            // The unary factor's only message slot is index 0 of its own message list.
            variable.getMessageIds().add(0);
            // Capture the variable-side slot index *before* appending, exactly as the
            // binary-factor wiring below does. (Reading size() after the append pointed one
            // past this slot: onto the first binary factor's slot, or out of range.)
            int unarySlot = variable.getMessages().size();
            variable.getMessages().add(new Message(2));

            unary.getNeighbors().add(variable);
            unary.getMessageIds().add(unarySlot);
            unary.getMessages().add(new Message(2));

            this.variables.add(variable);
            this.unaryFactors.add(unary);
        }

        for (Pair<Integer, Integer> edge : edges) {
            int first = edge.getValue0();
            int second = edge.getValue1();
            BinaryFactor pairwise = new BinaryFactor(2, 2, first, second);
            Variable left = this.variables.get(first);
            Variable right = this.variables.get(second);

            // Factor-side slots: index 0 belongs to the first endpoint, index 1 to the second.
            left.getNeighbors().add(pairwise);
            left.getMessageIds().add(0);
            pairwise.getMessages().add(new Message(2));
            right.getNeighbors().add(pairwise);
            right.getMessageIds().add(1);
            pairwise.getMessages().add(new Message(2));

            // Variable-side slots: record each variable's next free index, then append.
            pairwise.getNeighbors().add(left);
            pairwise.getMessageIds().add(left.getMessages().size());
            left.getMessages().add(new Message(2));
            pairwise.getNeighbors().add(right);
            pairwise.getMessageIds().add(right.getMessages().size());
            right.getMessages().add(new Message(2));

            this.binaryFactors.add(pairwise);
        }
        this.numParameters = 2 * this.unaryFactors.size() + 4 * this.binaryFactors.size();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public IsingFactorGraph(int i, List<Pair<Integer, Integer>> list, List<Integer> list2, List<String> list3) {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes exact per-variable marginals by enumerating all {@code 2^numVariables} joint
     * assignments and weighting each by the product of all unary and binary potentials.
     *
     * <p>Assignments are decoded with bit arithmetic. The previous decimal round-trip through
     * {@code Integer.valueOf(Integer.toBinaryString(..))} overflowed at 11 variables.
     *
     * @return {@code marginals[v][x]}, the normalized weight of variable {@code v} taking label {@code x}
     * @throws IllegalStateException if there are too many variables to index with a {@code long}
     */
    public double[][] inferenceBruteForce() {
        final int n = this.numVariables;
        if (n > Long.SIZE - 2) {
            throw new IllegalStateException("Too many variables for brute-force enumeration: " + n);
        }
        double[][] marginals = new double[n][2];
        int[] assignment = new int[n];
        final long numAssignments = 1L << n;
        for (long code = 0; code < numAssignments; code++) {
            // Variable v reads bit (n - 1 - v): variable 0 is the most significant bit, which
            // keeps the same enumeration order as the original zero-padded binary string.
            for (int v = 0; v < n; v++) {
                assignment[v] = (int) ((code >>> (n - 1 - v)) & 1L);
            }
            double weight = 1.0d;
            for (UnaryFactor unary : this.unaryFactors) {
                weight *= unary.potential[assignment[unary.getI()]];
            }
            for (BinaryFactor pairwise : this.binaryFactors) {
                weight *= pairwise.potential[assignment[pairwise.getI()]][assignment[pairwise.getJ()]];
            }
            for (int v = 0; v < n; v++) {
                marginals[v][assignment[v]] += weight;
            }
        }
        for (double[] marginal : marginals) {
            double total = marginal[0] + marginal[1];
            marginal[0] /= total;
            marginal[1] /= total;
        }
        return marginals;
    }

    /**
     * Recomputes every variable's belief and accumulates, per variable and label {@code x} with
     * belief {@code b(x)}: {@code (deg - 2) * b(x) * log b(x) + b(x) * log unaryPotential(x)},
     * where {@code deg} is the variable's neighbor count (unary factor included).
     *
     * <p>{@link #approximateZ()} exponentiates this value, so it plays the role of log Z.
     * NOTE(review): this appears to be the negated Bethe free energy restricted to variable and
     * unary-factor terms; pairwise (binary factor) belief terms are not included — confirm.
     *
     * <p>Terms with a zero belief contribute 0 (the usual {@code 0 * log 0 = 0} convention);
     * previously they produced NaN.
     */
    public double betheFreeEnergy() {
        double total = 0.0d;
        for (Variable variable : this.variables) {
            variable.computeBelief();
            Belief belief = variable.getBelief();
            UnaryFactor unary = this.unaryFactors.get(variable.getI());
            int degreeCorrection = variable.getNeighbors().size() - 2;
            for (int x = 0; x < variable.getSize(); x++) {
                double b = belief.measure[x];
                if (b == 0.0d) {
                    continue;
                }
                if (degreeCorrection != 0) {
                    total += degreeCorrection * b * Math.log(b);
                }
                total += b * Math.log(unary.potential[x]);
            }
        }
        return total;
    }

    /** Returns {@code exp(betheFreeEnergy())}, an approximation of the partition function. */
    public double approximateZ() {
        return Math.exp(betheFreeEnergy());
    }

    /**
     * Runs {@code i} rounds in which every unary factor passes its message, then recomputes every
     * variable's belief.
     *
     * <p>NOTE(review): only unary factors send messages here; variables and binary factors never
     * do, so pairwise potentials do not influence the resulting beliefs — confirm intended.
     *
     * @param i         number of message-passing rounds
     * @param tolerance currently unused (callers pass 0.01 or 1.0; presumably a convergence threshold)
     */
    public void inference(int i, double tolerance) {
        for (int round = 0; round < i; round++) {
            for (UnaryFactor unary : this.unaryFactors) {
                unary.passMessage();
            }
        }
        for (Variable variable : this.variables) {
            variable.computeBelief();
        }
    }

    /** Not implemented; always returns {@code null}. */
    public List<String> viterbiDecode() {
        return null;
    }

    /**
     * Runs inference and returns, in variable order, the tag names of variables whose belief in
     * label 1 strictly exceeds their belief in label 0.
     */
    public List<String> posteriorDecode() {
        inference(this.inferenceIterations, 0.01d);
        List<String> selected = new ArrayList<>();
        for (Variable variable : this.variables) {
            Belief belief = variable.getBelief();
            if (belief.measure[1] > belief.measure[0]) {
                selected.add(variable.getTagName());
            }
        }
        return selected;
    }

    /**
     * Returns the sum of log unary potentials at the gold labels minus {@link #betheFreeEnergy()}.
     *
     * <p>Log potentials are summed directly instead of taking the log of their product, which
     * could underflow or overflow on long tag sets. NOTE(review): binary-factor potentials of
     * the gold assignment are not included in the score — confirm.
     */
    public double logLikelihood() {
        double logZ = betheFreeEnergy();
        double goldScore = 0.0d;
        for (UnaryFactor unary : this.unaryFactors) {
            goldScore += Math.log(unary.potential[this.golden.get(unary.getI())]);
        }
        return goldScore - logZ;
    }

    /**
     * Estimates the gradient of {@link #logLikelihood()} with respect to each parameter by central
     * differences, using {@link #updatePotentials(double[])} and 10 inference rounds per probe.
     *
     * <p>{@code dArr} is modified during the computation but every entry is restored exactly
     * afterwards. The graph's potentials and beliefs are then recomputed for the unperturbed
     * parameters, so the graph is not left in a perturbed state.
     *
     * @param dArr parameter vector (temporarily perturbed, then restored)
     * @param d    perturbation size
     * @return gradient estimate, one entry per parameter
     */
    public double[] finiteDifference(double[] dArr, double d) {
        double[] gradient = new double[dArr.length];
        for (int k = 0; k < dArr.length; k++) {
            double original = dArr[k];

            dArr[k] = original + d;
            updatePotentials(dArr);
            inference(10, 1.0d);
            double upper = logLikelihood();

            dArr[k] = original - d;
            updatePotentials(dArr);
            inference(10, 1.0d);
            double lower = logLikelihood();

            gradient[k] = (upper - lower) / (2.0d * d);
            // Restore the exact original value rather than undoing the steps arithmetically.
            dArr[k] = original;
        }
        if (dArr.length > 0) {
            updatePotentials(dArr);
            inference(10, 1.0d);
        }
        return gradient;
    }

    /**
     * Sets every potential to {@code exp} of the corresponding entry of {@code dArr}, using the
     * class-level parameter layout (2 per unary factor, then 4 per binary factor).
     */
    public void updatePotentials2(double[] dArr) {
        int k = 0;
        for (UnaryFactor unary : this.unaryFactors) {
            unary.setPotential(0, Math.exp(dArr[k++]));
            unary.setPotential(1, Math.exp(dArr[k++]));
        }
        for (BinaryFactor pairwise : this.binaryFactors) {
            for (int a = 0; a < 2; a++) {
                for (int b = 0; b < 2; b++) {
                    pairwise.setPotential(a, b, Math.exp(dArr[k++]));
                }
            }
        }
    }

    /** Passes the featurized weight vector {@code dArr} to every unary factor's {@code updatePotential}. */
    public void updatePotentials(double[] dArr) {
        for (UnaryFactor unary : this.unaryFactors) {
            unary.updatePotential(dArr);
        }
    }

    /**
     * Runs inference, then accumulates into {@code dArr} the observed-minus-expected feature
     * counts for each unary factor. Positive features add 1 when the gold label is 1 and subtract
     * the belief in label 1. Negative features add 1 when the gold label is 0 and subtract the
     * belief in label 0.
     *
     * @param dArr gradient accumulator, indexed by feature id (updated in place)
     * @param i    currently unused
     */
    public void featurizedGradient(double[] dArr, int i) {
        inference(this.inferenceIterations, 0.01d);
        for (UnaryFactor unary : this.unaryFactors) {
            int v = unary.getI();
            int gold = this.golden.get(v);
            Belief belief = this.variables.get(v).getBelief();
            accumulateFeatureGradient(dArr, unary.getFeaturesPositive(), gold == 1, belief.measure[1]);
            accumulateFeatureGradient(dArr, unary.getFeaturesNegative(), gold == 0, belief.measure[0]);
        }
    }

    /** Adds 1 per feature occurrence if {@code observed}, then subtracts {@code expected} per occurrence. */
    private static void accumulateFeatureGradient(double[] gradient, Iterable<Integer> features,
            boolean observed, double expected) {
        if (observed) {
            for (int feature : features) {
                gradient[feature] += 1.0d;
            }
        }
        for (int feature : features) {
            gradient[feature] -= expected;
        }
    }

    /**
     * Runs 10 inference rounds and returns the observed-minus-expected gradient over the
     * class-level parameter layout.
     *
     * <p>For unary factors, the expected value is the variable's belief. For binary factors, it
     * is the product of the two endpoint beliefs. NOTE(review): that product is a
     * factorized approximation, not a pairwise belief.
     */
    public double[] unfeaturizedGradient() {
        inference(10, 0.01d);
        double[] gradient = new double[this.numParameters];
        int k = 0;
        for (UnaryFactor unary : this.unaryFactors) {
            int v = unary.getI();
            int gold = this.golden.get(v);
            Belief belief = this.variables.get(v).getBelief();
            for (int x = 0; x < 2; x++, k++) {
                if (gold == x) {
                    gradient[k] += 1.0d;
                }
                gradient[k] -= belief.measure[x];
            }
        }
        for (BinaryFactor pairwise : this.binaryFactors) {
            int goldFirst = this.golden.get(pairwise.getI());
            int goldSecond = this.golden.get(pairwise.getJ());
            Belief first = this.variables.get(pairwise.getI()).getBelief();
            Belief second = this.variables.get(pairwise.getJ()).getBelief();
            for (int a = 0; a < 2; a++) {
                for (int b = 0; b < 2; b++, k++) {
                    if (goldFirst == a && goldSecond == b) {
                        gradient[k] += 1.0d;
                    }
                    gradient[k] -= first.measure[a] * second.measure[b];
                }
            }
        }
        return gradient;
    }

    public String getWord() {
        return this.word;
    }

    public void setWord(String str) {
        this.word = str;
    }

    /** Returns the live (mutable) list of variables. */
    public List<Variable> getVariables() {
        return this.variables;
    }

    /** Returns the live (mutable) list of unary factors. */
    public List<UnaryFactor> getUnaryFactor() {
        return this.unaryFactors;
    }
}
