package org.dllearner.utilities.semkernel;

import com.google.common.collect.Sets;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.jena.sparql.sse.Tags;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.semkernel.SemKernel;
import org.dllearner.core.ComponentAnn;
import org.dllearner.core.ComponentInitException;
import org.obolibrary.obo2owl.Obo2OWLConstants;
import org.semanticweb.elk.owlapi.ElkReasonerFactory;
import org.semanticweb.owlapi.apibinding.OWLManager;
import org.semanticweb.owlapi.model.IRI;
import org.semanticweb.owlapi.model.OWLClass;
import org.semanticweb.owlapi.model.OWLDataFactory;
import org.semanticweb.owlapi.model.OWLOntology;
import org.semanticweb.owlapi.model.OWLOntologyCreationException;
import org.semanticweb.owlapi.reasoner.ConsoleProgressMonitor;
import org.semanticweb.owlapi.reasoner.InferenceType;
import org.semanticweb.owlapi.reasoner.OWLReasoner;
import org.semanticweb.owlapi.reasoner.SimpleConfiguration;
import uk.ac.manchester.cs.owl.owlapi.OWLClassImpl;

/**
 * SemKernel workflow for the Mammalian Phenotype (MP) ontology.
 *
 * <p>For every MP class URI listed in the training (resp. prediction) URI file, one example file
 * is written to the training (resp. prediction) input directory. Each line of such a file
 * describes one MGI gene that has GO annotations:
 * <ul>
 *   <li>It starts with a label, {@code 1} or {@code 0}. The label is {@code 1} if at least one of
 *       the gene's MP annotations is the MP class itself or one of its inferred subclasses.</li>
 *   <li>The gene's GO class IRIs follow.</li>
 *   <li>All fields are tab-separated.</li>
 * </ul>
 * Negative lines are randomly subsampled according to {@link #setPosNegExampleRatio(double)}.
 *
 * <p>The {@link SemKernel} created in {@link #init()} is configured with these directories.
 * Depending on {@link #setDoTraining(boolean)} / {@link #setDoPrediction(boolean)}, it is then
 * asked to train and/or predict.
 */
@ComponentAnn(name = "Mammalian Phenotype SemKernel Workflow", shortName = "mpskw", version = 0.1d)
public class MPSemKernelWorkflow extends SemKernelWorkflow {
    private static final Logger logger = Logger.getLogger(MPSemKernelWorkflow.class);
    /** Namespace used to turn OBO IDs such as {@code GO:0001234} into OWL class IRIs. */
    private static final String OBO_PREFIX = Obo2OWLConstants.DEFAULT_IRI_PREFIX;

    /** File with the MP class URIs to build training data for (one per line, optionally in {@code <...>}). */
    private String trainURIsFilePath;
    /** OWL file of the MP ontology; classified with ELK to find subclasses. */
    private String mpKBFilePath;
    /** Directory the training example files are written to; also the kernel's training dir. */
    private String trainingInputDirectoryPath;
    /** Ontology file handed to the kernel (presumably the GO ontology, judging by the name). */
    private String goKBFilePath;
    /** Directory handed to the kernel as its model directory. */
    private String trainingOutputDirectoryPath;
    /** File with the MP class URIs to build prediction data for (same format as the training file). */
    private String predictionURIsFilePath;
    /** Directory the prediction example files are written to; also the kernel's prediction data dir. */
    private String predictionInputDirectoryPath;
    /** Directory handed to the kernel as its results directory. */
    private String predictionOutputDirectoryPath;
    /** Tab-separated file: column 0 = MGI ID, column 1 = MP ID. */
    private String mgi2mpMappingsFilePath;
    /** GAF-style tab-separated file mapping MGI IDs (column 1) to GO IDs (column 4). */
    private String mgi2goMappingsFilePath;

    private SemKernel kernel;
    private OWLDataFactory dataFactory;
    private OWLOntology mpKB;
    private OWLReasoner mpKBReasoner;
    /** MGI ID -> MP class IRIs; only genes that also occur in {@link #mgi2go}. */
    private Map<String, Set<String>> mgi2mp;
    /** MGI ID -> GO class IRIs. */
    private Map<String, Set<String>> mgi2go;

    private SemKernel.SvmType svmType = SemKernel.SvmType.C_SVC;
    private boolean doProbabilityEstimates = true;
    private int crossValidationFolds = 10;
    private float cost = 5.0f;
    private boolean predictProbability = true;
    /** Maximum number of negative lines kept per positive line in an example file. */
    private double posNegExampleRatio = 1.0d;
    private boolean doTraining = true;
    private boolean doPrediction = true;

    /**
     * Loads and classifies the MP ontology and reads the MGI-to-GO and MGI-to-MP mappings.
     * It then configures and initializes the underlying {@link SemKernel}.
     *
     * @throws ComponentInitException if the MP ontology or one of the mapping files cannot be read
     */
    @Override // org.dllearner.utilities.semkernel.SemKernelWorkflow, org.dllearner.core.Component
    public void init() throws ComponentInitException {
        logger.info("Initializing workflow...");
        dataFactory = OWLManager.getOWLDataFactory();
        try {
            mpKB = OWLManager.createOWLOntologyManager()
                    .loadOntologyFromOntologyDocument(new File(mpKBFilePath));
        } catch (OWLOntologyCreationException e) {
            throw new ComponentInitException("Could not load MP ontology from " + mpKBFilePath, e);
        }
        mpKBReasoner = new ElkReasonerFactory().createReasoner(
                mpKB, new SimpleConfiguration(new ConsoleProgressMonitor()));
        mpKBReasoner.precomputeInferences(InferenceType.CLASS_HIERARCHY);

        try {
            // Order matters: readMGI2MPMapping only keeps genes already present in mgi2go.
            mgi2go = readMGI2GOMapping(mgi2goMappingsFilePath);
            mgi2mp = readMGI2MPMapping(mgi2mpMappingsFilePath);
        } catch (IOException e) {
            throw new ComponentInitException("Could not read MGI mapping files", e);
        }

        kernel = new SemKernel();
        kernel.setSvmType(svmType);
        kernel.setDoProbabilityEstimates(doProbabilityEstimates);
        kernel.setCrossValidationFolds(crossValidationFolds);
        kernel.setCost(cost);
        kernel.setOntologyFilePath(goKBFilePath);
        kernel.setTrainingDirPath(trainingInputDirectoryPath);
        kernel.setModelDirPath(trainingOutputDirectoryPath);
        kernel.setPredictionDataDirPath(predictionInputDirectoryPath);
        kernel.setResultsDirPath(predictionOutputDirectoryPath);
        // Fixed gamma of 0.0. The previous code obtained this value through the unrelated
        // constant CMAESOptimizer.DEFAULT_STOPFITNESS (== 0). What 0 means to SemKernel is
        // not visible here -- TODO confirm.
        kernel.setGamma(0.0d);
        kernel.setPredictProbability(predictProbability);
        kernel.init();
        logger.info("Finished workflow initialization.");
    }

    /**
     * Runs the enabled phases. The training phase prepares the training data and then trains
     * the kernel. The prediction phase prepares the prediction data and then runs the
     * kernel's prediction.
     *
     * @throws UncheckedIOException if example files cannot be read or written
     */
    @Override // org.dllearner.utilities.semkernel.SemKernelWorkflow
    public void start() {
        if (doTraining) {
            logger.info("Preparing training data...");
            try {
                prepareMPSampleTrainingData();
            } catch (IOException e) {
                throw new UncheckedIOException("Could not prepare training data", e);
            }
            logger.info("Finished training data preparation.");
            logger.info("Training...");
            kernel.train();
            logger.info("Finished training.");
        }
        if (doPrediction) {
            logger.info("Preparing prediction data...");
            try {
                prepareMPPredictionData();
            } catch (IOException e) {
                throw new UncheckedIOException("Could not prepare prediction data", e);
            }
            logger.info("Finished prediction data preparation.");
            logger.info("Doing predictions...");
            kernel.predict();
            logger.info("Finished prediction.");
        }
    }

    /**
     * Reads the MGI-to-MP mapping. It keeps only genes that already have GO annotations in
     * {@link #mgi2go}, because other genes cannot be encoded as examples.
     *
     * @param filePath tab-separated file with the MGI ID in column 0 and the MP ID in column 1
     * @return map from MGI ID to the sorted set of MP class IRIs
     * @throws IOException if the file cannot be read
     */
    private Map<String, Set<String>> readMGI2MPMapping(String filePath) throws IOException {
        Map<String, Set<String>> mapping = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filePath)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] columns = line.split("\t");
                if (columns.length < 2) {
                    continue;
                }
                String mgiId = columns[0];
                String mpId = columns[1].trim();
                if (mpId.isEmpty() || !mgi2go.containsKey(mgiId)) {
                    continue;
                }
                mapping.computeIfAbsent(mgiId, k -> new TreeSet<>()).add(oboIdToIri(mpId));
            }
        }
        return mapping;
    }

    /**
     * Writes one training example file per MP class listed in the training URI file.
     *
     * @throws IOException if the URI file cannot be read or an example file cannot be written
     */
    public void prepareMPSampleTrainingData() throws IOException {
        writeExampleFiles(trainURIsFilePath, trainingInputDirectoryPath);
    }

    /**
     * Writes one prediction example file per MP class listed in the prediction URI file.
     * The file content is built exactly like the training files.
     *
     * @throws IOException if the URI file cannot be read or an example file cannot be written
     */
    private void prepareMPPredictionData() throws IOException {
        writeExampleFiles(predictionURIsFilePath, predictionInputDirectoryPath);
    }

    /** Writes an example file into {@code outputDirectoryPath} for each URI read from {@code uriFilePath}. */
    private void writeExampleFiles(String uriFilePath, String outputDirectoryPath) throws IOException {
        String outputDir = withTrailingSeparator(outputDirectoryPath);
        for (String mpClassUri : readTrainURIs(uriFilePath)) {
            writeExampleFile(mpClassUri, outputDir + getLocalPart(mpClassUri));
        }
    }

    /**
     * Builds and writes the labelled example lines for one MP class.
     * Positive lines are written first, followed by the subsampled negative lines.
     */
    private void writeExampleFile(String mpClassUri, String outputFilePath) throws IOException {
        Set<String> targetClasses = getClassAndSubClassIris(mpClassUri);
        List<String> positives = new ArrayList<>();
        List<String> negatives = new ArrayList<>();
        for (Map.Entry<String, Set<String>> entry : mgi2mp.entrySet()) {
            Set<String> goTerms = mgi2go.get(entry.getKey());
            if (goTerms == null) {
                continue; // no GO features -> the gene cannot be encoded
            }
            boolean positive = !Collections.disjoint(targetClasses, entry.getValue());
            String line = toExampleLine(positive, goTerms);
            if (positive) {
                positives.add(line);
            } else {
                negatives.add(line);
            }
        }
        List<String> keptNegatives = subsampleNegatives(positives.size(), negatives);
        // Open the file only once all data is computed, and always close it.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFilePath))) {
            writeLines(writer, positives);
            writeLines(writer, keptNegatives);
        }
    }

    /** Returns the IRI of the given class plus the IRIs of all its (inferred, indirect) subclasses. */
    private Set<String> getClassAndSubClassIris(String classUri) {
        OWLClass cls = dataFactory.getOWLClass(IRI.create(classUri));
        Set<String> iris = new TreeSet<>();
        iris.add(classUri);
        for (OWLClass subClass : mpKBReasoner.getSubClasses(cls, false).getFlattened()) {
            iris.add(subClass.getIRI().toString());
        }
        return iris;
    }

    /**
     * Shuffles {@code negatives} in place. It then returns at most
     * {@code round(posNegExampleRatio * numPositives)} of them.
     */
    private List<String> subsampleNegatives(int numPositives, List<String> negatives) {
        Collections.shuffle(negatives);
        // Clamp so that a (mis-configured) negative ratio cannot make subList throw.
        int limit = Math.max(0, (int) Math.round(posNegExampleRatio * numPositives));
        return limit < negatives.size() ? negatives.subList(0, limit) : negatives;
    }

    /** Formats one example line: label ("1"/"0") followed by tab-separated GO class IRIs. */
    private static String toExampleLine(boolean positive, Set<String> goTerms) {
        StringBuilder line = new StringBuilder(positive ? "1" : "0");
        for (String goTerm : goTerms) {
            line.append('\t').append(goTerm);
        }
        return line.toString();
    }

    /** Writes each entry of {@code lines} followed by a platform line separator. */
    private static void writeLines(BufferedWriter writer, List<String> lines) throws IOException {
        for (String line : lines) {
            writer.write(line);
            writer.newLine();
        }
    }

    /** Returns {@code dirPath} with a trailing {@link File#separator}, without mutating any field. */
    private static String withTrailingSeparator(String dirPath) {
        return dirPath.endsWith(File.separator) ? dirPath : dirPath + File.separator;
    }

    /**
     * Reads one URI per line. Surrounding whitespace and one enclosing pair of angle brackets
     * are stripped. Blank lines are ignored.
     *
     * @param filePath the URI file
     * @return the distinct URIs (unordered)
     * @throws IOException if the file cannot be read
     */
    private static Set<String> readTrainURIs(String filePath) throws IOException {
        Set<String> uris = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filePath)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String uri = line.trim();
                if (uri.startsWith("<") && uri.endsWith(">") && uri.length() >= 2) {
                    uri = uri.substring(1, uri.length() - 1);
                }
                if (!uri.isEmpty()) {
                    uris.add(uri);
                }
            }
        }
        return uris;
    }

    /**
     * Reads the MGI-to-GO mapping from a GAF-style file.
     * <ul>
     *   <li>Lines starting with {@code !} are skipped, as are lines with fewer than 7 columns.</li>
     *   <li>Annotations are dropped if their GO ID is blank, their evidence code (column 6) is
     *       {@code ND}, or their qualifier (column 3) contains {@code NOT}.</li>
     * </ul>
     *
     * @param filePath the mapping file; column 1 = MGI ID, column 4 = GO ID
     * @return map from MGI ID to the sorted set of GO class IRIs
     * @throws IOException if the file cannot be read
     */
    private static Map<String, Set<String>> readMGI2GOMapping(String filePath) throws IOException {
        Map<String, Set<String>> mapping = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filePath)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("!")) {
                    continue; // header / comment line
                }
                String[] columns = line.split("\t");
                if (columns.length < 7) {
                    continue; // blank or malformed line; previously crashed with AIOOBE
                }
                String mgiId = columns[1];
                String qualifier = columns[3];
                String goId = columns[4].trim();
                String evidenceCode = columns[6];
                if (goId.isEmpty() || "ND".equals(evidenceCode) || qualifier.contains("NOT")) {
                    continue;
                }
                mapping.computeIfAbsent(mgiId, k -> new TreeSet<>()).add(oboIdToIri(goId));
            }
        }
        return mapping;
    }

    /** Turns an OBO ID such as {@code GO:0001234} into {@code <OBO_PREFIX>GO_0001234}. */
    private static String oboIdToIri(String oboId) {
        return OBO_PREFIX + oboId.replace(":", "_");
    }

    /**
     * Returns the part of {@code uri} after its last {@code '/'} or {@code '#'}, without the
     * delimiter. If the URI contains neither, the whole URI is returned.
     */
    private static String getLocalPart(String uri) {
        int delimiterIdx = Math.max(uri.lastIndexOf('/'), uri.lastIndexOf('#'));
        return delimiterIdx > -1 ? uri.substring(delimiterIdx + 1) : uri;
    }

    // ---- Bean-style configuration accessors ----

    public String getTrainURIsFilePath() {
        return trainURIsFilePath;
    }

    public void setTrainURIsFilePath(String trainURIsFilePath) {
        this.trainURIsFilePath = trainURIsFilePath;
    }

    public String getMpKBFilePath() {
        return mpKBFilePath;
    }

    public void setMpKBFilePath(String mpKBFilePath) {
        this.mpKBFilePath = mpKBFilePath;
    }

    public String getTrainingInputDirectoryPath() {
        return trainingInputDirectoryPath;
    }

    public void setTrainingInputDirectoryPath(String trainingInputDirectoryPath) {
        this.trainingInputDirectoryPath = trainingInputDirectoryPath;
    }

    public String getGoKBFilePath() {
        return goKBFilePath;
    }

    public void setGoKBFilePath(String goKBFilePath) {
        this.goKBFilePath = goKBFilePath;
    }

    public String getTrainingOutputDirectoryPath() {
        return trainingOutputDirectoryPath;
    }

    public void setTrainingOutputDirectoryPath(String trainingOutputDirectoryPath) {
        this.trainingOutputDirectoryPath = trainingOutputDirectoryPath;
    }

    public String getPredictionURIsFilePath() {
        return predictionURIsFilePath;
    }

    public void setPredictionURIsFilePath(String predictionURIsFilePath) {
        this.predictionURIsFilePath = predictionURIsFilePath;
    }

    public String getPredictionInputDirectoryPath() {
        return predictionInputDirectoryPath;
    }

    public void setPredictionInputDirectoryPath(String predictionInputDirectoryPath) {
        this.predictionInputDirectoryPath = predictionInputDirectoryPath;
    }

    public String getPredictionOutputDirectoryPath() {
        return predictionOutputDirectoryPath;
    }

    public void setPredictionOutputDirectoryPath(String predictionOutputDirectoryPath) {
        this.predictionOutputDirectoryPath = predictionOutputDirectoryPath;
    }

    public String getMgi2mpMappingsFilePath() {
        return mgi2mpMappingsFilePath;
    }

    public void setMgi2mpMappingsFilePath(String mgi2mpMappingsFilePath) {
        this.mgi2mpMappingsFilePath = mgi2mpMappingsFilePath;
    }

    public String getMgi2goMappingsFilePath() {
        return mgi2goMappingsFilePath;
    }

    public void setMgi2goMappingsFilePath(String mgi2goMappingsFilePath) {
        this.mgi2goMappingsFilePath = mgi2goMappingsFilePath;
    }

    public SemKernel.SvmType getSvmType() {
        return svmType;
    }

    public void setSvmType(SemKernel.SvmType svmType) {
        this.svmType = svmType;
    }

    public boolean isDoProbabilityEstimates() {
        return doProbabilityEstimates;
    }

    public void setDoProbabilityEstimates(boolean doProbabilityEstimates) {
        this.doProbabilityEstimates = doProbabilityEstimates;
    }

    public int getCrossValidationFolds() {
        return crossValidationFolds;
    }

    public void setCrossValidationFolds(int crossValidationFolds) {
        this.crossValidationFolds = crossValidationFolds;
    }

    public float getCost() {
        return cost;
    }

    public void setCost(float cost) {
        this.cost = cost;
    }

    public boolean isPredictProbability() {
        return predictProbability;
    }

    public void setPredictProbability(boolean predictProbability) {
        this.predictProbability = predictProbability;
    }

    public double getPosNegExampleRatio() {
        return posNegExampleRatio;
    }

    public void setPosNegExampleRatio(double posNegExampleRatio) {
        this.posNegExampleRatio = posNegExampleRatio;
    }

    public boolean isDoTraining() {
        return doTraining;
    }

    public void setDoTraining(boolean doTraining) {
        this.doTraining = doTraining;
    }

    public boolean isDoPrediction() {
        return doPrediction;
    }

    public void setDoPrediction(boolean doPrediction) {
        this.doPrediction = doPrediction;
    }
}
