package org.dllearner.algorithms.isle;

import java.io.IOException;
import java.nio.file.DirectoryIteratorException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.lucene.analysis.core.SimpleAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BytesRef;

/* loaded from: input_file:org/dllearner/algorithms/isle/VSMCosineDocumentSimilarity.class */
/**
 * Computes the cosine similarity of two text documents in a vector space model.
 *
 * <p>Both documents are written to a temporary Lucene index (analyzed with
 * {@link SimpleAnalyzer}), their term vectors are read back, and each document is
 * turned into a vector over the union of the terms of both documents. The
 * components are either raw term frequencies ({@link TermWeighting#TF}) or
 * term frequencies scaled by {@code 1 + ln(numDocs / docFreq)}
 * ({@link TermWeighting#TF_IDF}).
 *
 * <p>The temporary index is closed and removed (best effort) before the
 * constructor returns; only the two vectors are kept.
 */
public class VSMCosineDocumentSimilarity {
    /** Name of the single Lucene field that holds each document's text. */
    public static final String CONTENT = "Content";

    /** Field type used for {@link #CONTENT}; configured and frozen in the static initializer. */
    public static final FieldType TYPE_STORED = new FieldType();

    /** Union of the terms seen in both documents; its iteration order defines the vector layout. */
    private final Set<String> terms;
    private final RealVector v1;
    private final RealVector v2;

    /** Selects how the components of the document vectors are computed. */
    enum TermWeighting {
        TF,
        TF_IDF
    }

    /**
     * Indexes both documents and builds their term vectors.
     *
     * @param str           text of the first document
     * @param str2          text of the second document
     * @param termWeighting weighting scheme for the vector components
     * @throws IOException          if the temporary index cannot be created, written or read
     * @throws NullPointerException if {@code termWeighting} is {@code null}
     */
    public VSMCosineDocumentSimilarity(String str, String str2, TermWeighting termWeighting) throws IOException {
        // Fail before any I/O happens; previously a null weighting built the index,
        // leaked the reader and left both vectors null.
        if (termWeighting == null) {
            throw new NullPointerException("termWeighting must not be null");
        }
        this.terms = new HashSet<>();

        RealVector first;
        RealVector second;
        Path indexPath = Files.createTempDirectory("Lucene");
        try (Directory directory = new MMapDirectory(indexPath)) {
            writeIndex(directory, str, str2);
            try (IndexReader reader = DirectoryReader.open(directory)) {
                // Both term maps must be read before any vector is built, because
                // reading them is what fills the shared term set.
                if (termWeighting == TermWeighting.TF) {
                    Map<String, Integer> frequencies1 = getTermFrequencies(reader, 0);
                    Map<String, Integer> frequencies2 = getTermFrequencies(reader, 1);
                    first = getTermVector(frequencies1);
                    second = getTermVector(frequencies2);
                } else {
                    Map<String, Double> weights1 = getTermWeights(reader, 0);
                    Map<String, Double> weights2 = getTermWeights(reader, 1);
                    first = getTermVector(weights1);
                    second = getTermVector(weights2);
                }
            }
        } finally {
            // Runs after reader and directory are closed by the try-with-resources above.
            deleteQuietly(indexPath);
        }
        this.v1 = first;
        this.v2 = second;
    }

    /**
     * Indexes both documents using {@link TermWeighting#TF_IDF}.
     *
     * @param str  text of the first document
     * @param str2 text of the second document
     * @throws IOException if the temporary index cannot be created, written or read
     */
    public VSMCosineDocumentSimilarity(String str, String str2) throws IOException {
        this(str, str2, TermWeighting.TF_IDF);
    }

    /**
     * Returns the TF-IDF weighted cosine similarity of the two texts.
     *
     * @param str  text of the first document
     * @param str2 text of the second document
     * @return the cosine similarity, or {@code 0.0} if either document vector is all zeros
     * @throws IOException if the temporary index cannot be created, written or read
     */
    public static double getCosineSimilarity(String str, String str2) throws IOException {
        return new VSMCosineDocumentSimilarity(str, str2).getCosineSimilarity();
    }

    /**
     * Returns the cosine similarity of the two texts under the given weighting.
     *
     * @param str           text of the first document
     * @param str2          text of the second document
     * @param termWeighting weighting scheme for the vector components
     * @return the cosine similarity, or {@code 0.0} if either document vector is all zeros
     * @throws IOException          if the temporary index cannot be created, written or read
     * @throws NullPointerException if {@code termWeighting} is {@code null}
     */
    public static double getCosineSimilarity(String str, String str2, TermWeighting termWeighting) throws IOException {
        return new VSMCosineDocumentSimilarity(str, str2, termWeighting).getCosineSimilarity();
    }

    /** Writes the two texts, in order, as two documents into {@code directory}. */
    private void writeIndex(Directory directory, String firstText, String secondText) throws IOException {
        try (SimpleAnalyzer analyzer = new SimpleAnalyzer();
             IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(analyzer))) {
            addDocument(writer, firstText);
            addDocument(writer, secondText);
        }
    }

    /** Adds one document whose {@link #CONTENT} field holds {@code text}. */
    private void addDocument(IndexWriter indexWriter, String text) throws IOException {
        Document document = new Document();
        document.add(new Field(CONTENT, text, TYPE_STORED));
        indexWriter.addDocument(document);
    }

    /**
     * Best-effort removal of the temporary index directory and the files directly inside it.
     * A failure here is deliberately ignored so that it cannot mask the computed result
     * or the original exception.
     */
    private static void deleteQuietly(Path directory) {
        try {
            try (DirectoryStream<Path> entries = Files.newDirectoryStream(directory)) {
                for (Path entry : entries) {
                    Files.deleteIfExists(entry);
                }
            }
            Files.deleteIfExists(directory);
        } catch (IOException | DirectoryIteratorException ignored) {
            // Cleanup only; leftover files stay in the temp directory.
        }
    }

    /**
     * Reads the term frequencies of one document and records its terms in the shared term set.
     *
     * @return term to frequency; empty if the document has no term vector
     */
    private Map<String, Integer> getTermFrequencies(IndexReader indexReader, int docId) throws IOException {
        Map<String, Integer> frequencies = new HashMap<>();
        Terms termVector = indexReader.getTermVector(docId, CONTENT);
        if (termVector == null) {
            return frequencies;
        }
        TermsEnum termsEnum = termVector.iterator();
        BytesRef term;
        while ((term = termsEnum.next()) != null) {
            String text = term.utf8ToString();
            frequencies.put(text, (int) termsEnum.totalTermFreq());
            this.terms.add(text);
        }
        return frequencies;
    }

    /**
     * Reads the TF-IDF weights of one document and records its terms in the shared term set.
     *
     * @return term to {@code tf * idf}; empty if the document has no term vector
     */
    private Map<String, Double> getTermWeights(IndexReader indexReader, int docId) throws IOException {
        Map<String, Double> weights = new HashMap<>();
        Terms termVector = indexReader.getTermVector(docId, CONTENT);
        if (termVector == null) {
            return weights;
        }
        int numDocs = indexReader.numDocs();
        TermsEnum termsEnum = termVector.iterator();
        BytesRef term;
        while ((term = termsEnum.next()) != null) {
            String text = term.utf8ToString();
            int termFrequency = (int) termsEnum.totalTermFreq();
            int documentFrequency = indexReader.docFreq(new Term(CONTENT, term));
            weights.put(text, termFrequency * getIDF(numDocs, documentFrequency));
            this.terms.add(text);
        }
        return weights;
    }

    /**
     * Inverse document frequency: {@code 1 + ln(numDocs / docFreq)}.
     * The ratio is computed in floating point so it is not truncated before the logarithm.
     */
    private double getIDF(int numDocs, int docFreq) {
        return 1.0d + Math.log((double) numDocs / docFreq);
    }

    /**
     * Cosine of the angle between the two document vectors.
     *
     * @return the cosine similarity, or {@code 0.0} when either vector has zero length
     *         (the quotient would otherwise be NaN)
     */
    private double getCosineSimilarity() {
        double normProduct = this.v1.getNorm() * this.v2.getNorm();
        if (normProduct == 0.0d) {
            return 0.0d;
        }
        return this.v1.dotProduct(this.v2) / normProduct;
    }

    /**
     * Lays out {@code weights} over the shared term set (0 for absent terms) and
     * scales the result by its L1 norm. An all-zero vector is returned unscaled
     * to avoid dividing by zero.
     */
    private RealVector getTermVector(Map<String, ? extends Number> weights) {
        RealVector vector = new ArrayRealVector(this.terms.size());
        int index = 0;
        for (String term : this.terms) {
            Number weight = weights.get(term);
            vector.setEntry(index++, weight == null ? 0.0d : weight.doubleValue());
        }
        double l1Norm = vector.getL1Norm();
        return l1Norm == 0.0d ? vector : vector.mapDivide(l1Norm);
    }

    /** Small demo: prints the TF-IDF cosine similarity of two fixed sentences. */
    public static void main(String[] strArr) throws Exception {
        System.out.println(getCosineSimilarity("The king is here", "The salad is cold"));
    }

    static {
        TYPE_STORED.setIndexOptions(IndexOptions.DOCS_AND_FREQS);
        TYPE_STORED.setTokenized(true);
        TYPE_STORED.setStored(true);
        TYPE_STORED.setStoreTermVectors(true);
        TYPE_STORED.setStoreTermVectorPositions(true);
        TYPE_STORED.freeze();
    }
}
