package org.dice_research.topicmodeling.evaluate;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import org.apache.commons.io.IOUtils;
import org.dice_research.topicmodeling.algorithms.ClassificationModel;
import org.dice_research.topicmodeling.algorithms.Model;
import org.dice_research.topicmodeling.algorithms.ProbTopicModelingAlgorithmStateSupplier;
import org.dice_research.topicmodeling.evaluate.result.EvaluationResult;
import org.dice_research.topicmodeling.evaluate.result.ManagedEvaluationResultContainer;
import org.dice_research.topicmodeling.utils.corpus.Corpus;
import org.dice_research.topicmodeling.utils.doc.Document;
import org.dice_research.topicmodeling.utils.doc.DocumentClassificationResult;
import org.dice_research.topicmodeling.utils.doc.DocumentName;
import org.dice_research.topicmodeling.utils.doc.DocumentURI;
import org.dice_research.topicmodeling.utils.doc.ProbabilisticClassificationResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/dice_research/topicmodeling/evaluate/Tapioca_LodStatsGold_Evaluator.class */
/**
 * Evaluator that writes a CSV matrix of cosine similarities between the topic
 * distributions of the documents of a test corpus (rows) and the documents of
 * a training corpus (columns).
 *
 * <p>
 * The file is named {@link #outputFileName} and is created inside
 * {@link #outputFolder} if that folder is not {@code null}, otherwise relative
 * to the current working directory. The first line contains an empty cell
 * followed by the quoted labels of the training documents. Every following
 * line starts with the quoted label of a test document followed by its
 * similarity to each training document.
 * </p>
 */
public class Tapioca_LodStatsGold_Evaluator extends AbstractEvaluatorWithClassifiedTestCorpus {
    private static final Logger LOGGER = LoggerFactory.getLogger(Tapioca_LodStatsGold_Evaluator.class);

    /** Default name of the generated similarity file. */
    public static final String SIMILARITY_OUTPUT_FILE_NAME = "dataset_similarities.csv";

    /** State supplier handed over at construction time; it is only stored by this class. */
    protected ProbTopicModelingAlgorithmStateSupplier probAlgState;

    /** Not used inside this class; kept because it is visible to subclasses. */
    protected static final int NUMNER_OF_TOPIC_INFERENCES = 1;

    /** Folder the similarity file is written to, or {@code null} to use the working directory. */
    protected String outputFolder;

    /** Corpus whose documents form the columns of the similarity matrix. */
    protected Corpus trainCorpus;

    /** Name of the similarity file; defaults to {@link #SIMILARITY_OUTPUT_FILE_NAME}. */
    protected String outputFileName = SIMILARITY_OUTPUT_FILE_NAME;

    /**
     * Creates an evaluator that writes the similarity file into the given folder.
     *
     * @param corpus
     *            the training corpus (columns of the similarity matrix)
     * @param corpus2
     *            the corpus handed to the super class; its documents form the
     *            rows of the similarity matrix
     * @param probTopicModelingAlgorithmStateSupplier
     *            the state supplier of the topic modeling algorithm
     * @param str
     *            the output folder, or {@code null} to write relative to the
     *            current working directory
     */
    public Tapioca_LodStatsGold_Evaluator(Corpus corpus, Corpus corpus2, ProbTopicModelingAlgorithmStateSupplier probTopicModelingAlgorithmStateSupplier, String str) {
        super(corpus2);
        this.probAlgState = probTopicModelingAlgorithmStateSupplier;
        this.outputFolder = str;
        this.trainCorpus = corpus;
    }

    /**
     * Creates an evaluator that writes the similarity file relative to the
     * current working directory.
     *
     * @param corpus
     *            the training corpus (columns of the similarity matrix)
     * @param corpus2
     *            the corpus handed to the super class; its documents form the
     *            rows of the similarity matrix
     * @param probTopicModelingAlgorithmStateSupplier
     *            the state supplier of the topic modeling algorithm
     */
    public Tapioca_LodStatsGold_Evaluator(Corpus corpus, Corpus corpus2, ProbTopicModelingAlgorithmStateSupplier probTopicModelingAlgorithmStateSupplier) {
        this(corpus, corpus2, probTopicModelingAlgorithmStateSupplier, null);
    }

    /**
     * Classifies the documents using the given model and writes the similarity
     * file.
     *
     * @param model
     *            the model to evaluate; has to be a {@link ClassificationModel}
     * @param managedEvaluationResultContainer
     *            passed through to
     *            {@link #evaluateModelWithClassifiedCorpus(ClassificationModel, ManagedEvaluationResultContainer)}
     * @return the result of
     *         {@link #evaluateModelWithClassifiedCorpus(ClassificationModel, ManagedEvaluationResultContainer)},
     *         which is always {@code null}
     * @throws IllegalArgumentException
     *             if the given model is not a {@link ClassificationModel}
     */
    @Override
    public EvaluationResult evaluate(Model model, ManagedEvaluationResultContainer managedEvaluationResultContainer) {
        if (!(model instanceof ClassificationModel)) {
            throw new IllegalArgumentException("Got a " + model.getClass().getCanonicalName() + " as Model while expecting a " + ClassificationModel.class.getCanonicalName());
        }
        ClassificationModel classificationModel = (ClassificationModel) model;
        classifyDocuments(classificationModel);
        return evaluateModelWithClassifiedCorpus(classificationModel, managedEvaluationResultContainer);
    }

    /**
     * Classifies the training corpus with the given model and writes the
     * similarities between all test and training documents to the output file.
     * If the file cannot be created, the error is logged and nothing is
     * written.
     *
     * @param classificationModel
     *            the model used to classify the training corpus
     * @param managedEvaluationResultContainer
     *            not used by this implementation
     * @return always {@code null}; the only product of this evaluator is the
     *         written file
     */
    @Override
    public EvaluationResult evaluateModelWithClassifiedCorpus(ClassificationModel classificationModel, ManagedEvaluationResultContainer managedEvaluationResultContainer) {
        DocumentClassificationResult[] trainClassifications = classifyDocuments(classificationModel, this.trainCorpus);
        // One topic probability vector per training document plus its
        // Euclidean length, computed once so it is not recalculated per row.
        double[][] trainTopicProbabilities = new double[trainClassifications.length][];
        double[] trainVectorLengths = new double[trainClassifications.length];
        for (int i = 0; i < trainClassifications.length; i++) {
            trainTopicProbabilities[i] = ((ProbabilisticClassificationResult) trainClassifications[i]).getTopicProbabilities();
            trainVectorLengths[i] = getLength(trainTopicProbabilities[i]);
        }

        String outputPath = this.outputFolder != null ? this.outputFolder + File.separator + this.outputFileName : this.outputFileName;
        // try-with-resources makes sure the file is closed even if writing
        // fails with a runtime exception.
        try (PrintStream printStream = new PrintStream(outputPath)) {
            // Header line: an empty first cell followed by the training document labels.
            for (int i = 0; i < this.trainCorpus.getNumberOfDocuments(); i++) {
                printStream.print(",\"");
                printStream.print(getDocumentLabel(this.trainCorpus.getDocument(i), i));
                printStream.print('\"');
            }
            printStream.println();

            // One line per test document.
            for (int i = 0; i < this.testCorpus.getNumberOfDocuments(); i++) {
                printStream.print('\"');
                printStream.print(getDocumentLabel(this.testCorpus.getDocument(i), i));
                printStream.print('\"');
                double[] topicProbabilities = this.classifications[i].getTopicProbabilities();
                double length = getLength(topicProbabilities);
                for (int j = 0; j < trainTopicProbabilities.length; j++) {
                    printStream.print(',');
                    printStream.print(calculateSimilarity(topicProbabilities, length, trainTopicProbabilities[j], trainVectorLengths[j]));
                }
                printStream.println();
            }
            // PrintStream swallows IOExceptions, so ask explicitly whether
            // something went wrong while writing.
            if (printStream.checkError()) {
                LOGGER.error("An I/O error occurred while writing the similarity file \"{}\".", outputPath);
            }
        } catch (FileNotFoundException e) {
            LOGGER.error("Couldn't create similarity file. Aborting.", e);
        }
        return null;
    }

    /**
     * Returns the label used for the given document inside the similarity
     * file: its {@link DocumentURI} if present, otherwise its
     * {@link DocumentName}, otherwise {@code "Document #" + index}.
     *
     * @param document
     *            the document that should be labelled
     * @param index
     *            the position of the document inside its corpus
     * @return the label of the document
     */
    private static String getDocumentLabel(Document document, int index) {
        DocumentURI uri = document.getProperty(DocumentURI.class);
        if (uri != null) {
            return (String) uri.get();
        }
        DocumentName name = document.getProperty(DocumentName.class);
        if (name != null) {
            return (String) name.get();
        }
        return "Document #" + index;
    }

    /**
     * Calculates the cosine similarity of the two given vectors using their
     * precomputed lengths.
     *
     * @param dArr
     *            the first vector
     * @param d
     *            the length of the first vector
     * @param dArr2
     *            the second vector; must have at least as many elements as the
     *            first one
     * @param d2
     *            the length of the second vector
     * @return the dot product of both vectors divided by the product of their
     *         lengths; {@code 1.0} if both lengths are {@code 0} and
     *         {@code 0.0} if exactly one of them is {@code 0}
     */
    protected double calculateSimilarity(double[] dArr, double d, double[] dArr2, double d2) {
        if (d == 0.0d || d2 == 0.0d) {
            // The quotient is undefined for zero vectors: two zero vectors
            // are treated as identical, a single one as completely dissimilar.
            return (d == 0.0d && d2 == 0.0d) ? 1.0d : 0.0d;
        }
        double dotProduct = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            dotProduct += dArr[i] * dArr2[i];
        }
        return dotProduct / (d * d2);
    }

    /**
     * Calculates the Euclidean length of the given vector.
     *
     * @param dArr
     *            the vector
     * @return the square root of the sum of the squared elements
     */
    protected double getLength(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            sumOfSquares += Math.pow(dArr[i], 2.0d);
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Does nothing; the given flag is ignored by this evaluator.
     *
     * @param z
     *            ignored
     */
    @Override
    public void setReportProvisionalResults(boolean z) {
    }
}
