package org.dice_research.topicmodeling.evaluate;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import java.util.stream.IntStream;
import org.dice_research.topicmodeling.algorithms.Model;
import org.dice_research.topicmodeling.algorithms.ProbTopicModelingAlgorithmStateSupplier;
import org.dice_research.topicmodeling.evaluate.result.EvaluationResult;
import org.dice_research.topicmodeling.evaluate.result.ManagedEvaluationResultContainer;
import org.dice_research.topicmodeling.evaluate.result.SingleEvaluationResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/dice_research/topicmodeling/evaluate/ArunModelSelectionEvaluator.class */
/**
 * Evaluator that scores a probabilistic topic model by comparing two vectors
 * derived from its current state:
 * <ol>
 * <li>the singular values of the topic-word count matrix, and</li>
 * <li>the (normalized) product of the document length vector with the
 * document-topic count matrix.</li>
 * </ol>
 * The score is the symmetric Kullback-Leibler divergence (in bits) between
 * those two vectors and is reported under the name {@code "Arun"}.
 *
 * <p>
 * The count matrices can be filled either sequentially or with parallel
 * streams, depending on the flag given to the constructor.
 * </p>
 */
public class ArunModelSelectionEvaluator extends AbstractEvaluator {
    private static final Logger LOGGER = LoggerFactory.getLogger(ArunModelSelectionEvaluator.class);

    /** Natural logarithm of 2, used to convert natural-log values to base 2. */
    public static final double LOG2 = Math.log(2.0d);

    /** Source of the word/topic assignments of the model that is evaluated. */
    private final ProbTopicModelingAlgorithmStateSupplier probAlgState;

    /** If {@code true}, the count matrices are filled using parallel streams. */
    private final boolean parallel;

    /**
     * Creates an evaluator that fills its matrices sequentially.
     *
     * @param probAlgState supplier of the algorithm state that is evaluated
     */
    public ArunModelSelectionEvaluator(ProbTopicModelingAlgorithmStateSupplier probAlgState) {
        this(probAlgState, false);
    }

    /**
     * Creates an evaluator.
     *
     * @param probAlgState supplier of the algorithm state that is evaluated
     * @param parallel     {@code true} if the count matrices should be filled
     *                     using parallel streams
     */
    public ArunModelSelectionEvaluator(ProbTopicModelingAlgorithmStateSupplier probAlgState, boolean parallel) {
        this.probAlgState = probAlgState;
        this.parallel = parallel;
    }

    /**
     * Builds the topic-word and document-topic count matrices from the current
     * algorithm state and returns the symmetric KL divergence between the
     * singular values of the former and the normalized, length-weighted topic
     * vector derived from the latter.
     *
     * @param model   not used by this implementation
     * @param manager not used by this implementation
     * @return a {@link SingleEvaluationResult} named {@code "Arun"} holding the
     *         divergence as a {@link Double}
     */
    @Override // org.dice_research.topicmodeling.evaluate.AbstractEvaluator
    protected EvaluationResult evaluate(Model model, ManagedEvaluationResultContainer manager) {
        int numberOfWords = probAlgState.getNumberOfWords();
        int numberOfTopics = probAlgState.getNumberOfTopics();
        int numberOfDocuments = probAlgState.getNumberOfDocuments();
        DoubleMatrix2D topicWordCounts = DoubleFactory2D.sparse.make(numberOfTopics, numberOfWords);
        DoubleMatrix2D docTopicCounts = DoubleFactory2D.sparse.make(numberOfDocuments, numberOfTopics);
        double[] docLengths = new double[numberOfDocuments];

        long startTime = System.currentTimeMillis();
        if (parallel) {
            fillMatricesInParallel(topicWordCounts, docTopicCounts, docLengths);
        } else {
            fillMatrices(topicWordCounts, docTopicCounts, docLengths);
        }
        LOGGER.info("Generating matrices took {}ms. (parallelization={})", System.currentTimeMillis() - startTime,
                parallel);

        // The decomposition is given a matrix with rows >= columns; transpose
        // the topic-word matrix if it is wider than it is high.
        DoubleMatrix2D svdInput = topicWordCounts.rows() >= topicWordCounts.columns() ? topicWordCounts
                : Algebra.DEFAULT.transpose(topicWordCounts);
        double[] singularValues = new SingularValueDecomposition(svdInput).getSingularValues();

        // (1 x #documents) * (#documents x #topics) = (1 x #topics)
        DoubleMatrix2D lengthsAsRow = DoubleFactory2D.dense.make(new double[][] { docLengths });
        DoubleMatrix1D weightedTopics = Algebra.DEFAULT.mult(lengthsAsRow, docTopicCounts).viewRow(0);
        normalize(weightedTopics);

        return new SingleEvaluationResult("Arun",
                Double.valueOf(SymmetricKLDivergence(singularValues, weightedTopics.toArray())));
    }

    /**
     * Fills the given structures sequentially by iterating over all documents
     * and all of their word tokens.
     *
     * @param doubleMatrix2D  topic-word matrix; cell (topic, word) is
     *                        incremented for every token of that word assigned
     *                        to that topic
     * @param doubleMatrix2D2 document-topic matrix; cell (document, topic) is
     *                        incremented for every token of that document
     *                        assigned to that topic
     * @param dArr            receives the number of word tokens of every
     *                        document; its length determines the number of
     *                        documents that are processed
     */
    protected void fillMatrices(DoubleMatrix2D doubleMatrix2D, DoubleMatrix2D doubleMatrix2D2, double[] dArr) {
        for (int docId = 0; docId < dArr.length; docId++) {
            int[] words = probAlgState.getWordsOfDocument(docId);
            int[] topics = probAlgState.getWordTopicAssignmentForDocument(docId);
            for (int pos = 0; pos < words.length; pos++) {
                int topic = topics[pos];
                int word = words[pos];
                doubleMatrix2D.set(topic, word, doubleMatrix2D.getQuick(topic, word) + 1.0d);
                doubleMatrix2D2.set(docId, topic, doubleMatrix2D2.getQuick(docId, topic) + 1.0d);
            }
            dArr[docId] = words.length;
        }
    }

    /**
     * Fills the given structures using parallel streams: first one task per
     * document (document-topic matrix and document lengths), then one task per
     * topic (topic-word matrix).
     *
     * @param doubleMatrix2D  topic-word matrix to fill
     * @param doubleMatrix2D2 document-topic matrix to fill
     * @param dArr            receives the number of word tokens of every
     *                        document; its length determines the number of
     *                        documents that are processed
     */
    protected void fillMatricesInParallel(DoubleMatrix2D doubleMatrix2D, DoubleMatrix2D doubleMatrix2D2,
            double[] dArr) {
        IntStream.range(0, dArr.length).parallel()
                .forEach(docId -> handlePerDocumentData(docId, doubleMatrix2D2, dArr));
        IntStream.range(0, probAlgState.getNumberOfTopics()).parallel()
                .forEach(topicId -> handlePerTopicData(topicId, doubleMatrix2D));
    }

    /**
     * Counts the topic assignments of a single document and writes them into
     * the document's row of the document-topic matrix.
     *
     * @param docId          id of the document
     * @param docTopicCounts document-topic matrix (shared between tasks)
     * @param docLengths     array receiving the document's token count
     */
    private void handlePerDocumentData(int docId, DoubleMatrix2D docTopicCounts, double[] docLengths) {
        int[] topics = probAlgState.getWordTopicAssignmentForDocument(docId);
        docLengths[docId] = topics.length;
        // Count locally first so that the shared matrix is locked only once.
        IntIntOpenHashMap counts = new IntIntOpenHashMap(probAlgState.getNumberOfTopics());
        for (int topic : topics) {
            counts.putOrAdd(topic, 1, 1);
        }
        writeCountsToRow(docTopicCounts, docId, counts);
    }

    /**
     * Counts how often every word has been assigned to the given topic over
     * all documents and writes the counts into the topic's row of the
     * topic-word matrix.
     *
     * @param topicId         id of the topic
     * @param topicWordCounts topic-word matrix (shared between tasks)
     */
    private void handlePerTopicData(int topicId, DoubleMatrix2D topicWordCounts) {
        IntIntOpenHashMap counts = new IntIntOpenHashMap();
        int numberOfDocuments = probAlgState.getNumberOfDocuments();
        for (int docId = 0; docId < numberOfDocuments; docId++) {
            int[] words = probAlgState.getWordsOfDocument(docId);
            int[] topics = probAlgState.getWordTopicAssignmentForDocument(docId);
            for (int pos = 0; pos < topics.length; pos++) {
                if (topics[pos] == topicId) {
                    counts.putOrAdd(words[pos], 1, 1);
                }
            }
        }
        writeCountsToRow(topicWordCounts, topicId, counts);
    }

    /**
     * Copies all (column, count) pairs of the given map into one row of the
     * given matrix while holding the matrix' monitor, since the matrix is
     * written by several stream tasks.
     */
    private static void writeCountsToRow(DoubleMatrix2D matrix, int row, IntIntOpenHashMap counts) {
        synchronized (matrix) {
            for (int slot = 0; slot < counts.allocated.length; slot++) {
                if (counts.allocated[slot]) {
                    matrix.setQuick(row, counts.keys[slot], counts.values[slot]);
                }
            }
        }
    }

    /**
     * Returns the sum of the KL divergences of the two vectors in both
     * directions.
     *
     * @param dArr  first vector
     * @param dArr2 second vector
     * @return {@code KL(dArr || dArr2) + KL(dArr2 || dArr)} in bits
     */
    protected double SymmetricKLDivergence(double[] dArr, double[] dArr2) {
        return calcKLDivergence(dArr, dArr2) + calcKLDivergence(dArr2, dArr);
    }

    /**
     * Calculates {@code sum_i dArr[i] * log2(dArr[i] / dArr2[i])}. Positions
     * at which one of the two values is not strictly positive are skipped. If
     * the arrays differ in length, only the positions present in both arrays
     * are taken into account.
     *
     * @param dArr  first vector
     * @param dArr2 second vector
     * @return the divergence in bits
     */
    protected double calcKLDivergence(double[] dArr, double[] dArr2) {
        // The singular value vector can be shorter than the topic vector (if
        // there are fewer words than topics), so never index past the shorter
        // of the two arrays.
        int length = Math.min(dArr.length, dArr2.length);
        double sum = 0.0d;
        for (int i = 0; i < length; i++) {
            if (dArr[i] > 0.0d && dArr2[i] > 0.0d) {
                sum += dArr[i] * Math.log(dArr[i] / dArr2[i]);
            }
        }
        return sum / LOG2;
    }

    /**
     * Divides every entry of the given vector, in place, by the value returned
     * by {@code Algebra.DEFAULT.norm2} for that vector.
     *
     * <p>
     * NOTE(review): some Colt versions document {@code norm2(DoubleMatrix1D)}
     * as equivalent to {@code mult(x, x)}, i.e. the squared Euclidean norm -
     * confirm that this is the intended normalization. A vector with a norm of
     * 0 ends up containing NaN values.
     * </p>
     *
     * @param doubleMatrix1D the vector that is normalized in place
     */
    protected void normalize(DoubleMatrix1D doubleMatrix1D) {
        double norm = Algebra.DEFAULT.norm2(doubleMatrix1D);
        for (int i = 0; i < doubleMatrix1D.size(); i++) {
            doubleMatrix1D.setQuick(i, doubleMatrix1D.getQuick(i) / norm);
        }
    }

    /**
     * Does nothing; the given flag is ignored by this implementation.
     *
     * @param z ignored
     */
    @Override // org.dice_research.topicmodeling.evaluate.Evaluator
    public void setReportProvisionalResults(boolean z) {
    }
}
