/*
 * Decompiled with CFR 0.152.
 */
package org.aksw.limes.core.controller;

import java.util.List;
import java.util.NoSuchElementException;
import java.util.Scanner;
import org.aksw.limes.core.evaluation.evaluator.EvaluatorFactory;
import org.aksw.limes.core.evaluation.evaluator.EvaluatorType;
import org.aksw.limes.core.evaluation.qualititativeMeasures.PseudoFMeasure;
import org.aksw.limes.core.exceptions.UnsupportedMLImplementationException;
import org.aksw.limes.core.io.cache.ACache;
import org.aksw.limes.core.io.config.Configuration;
import org.aksw.limes.core.io.mapping.AMapping;
import org.aksw.limes.core.io.mapping.MappingFactory;
import org.aksw.limes.core.io.mapping.reader.AMappingReader;
import org.aksw.limes.core.io.mapping.reader.CSVMappingReader;
import org.aksw.limes.core.io.mapping.reader.RDFMappingReader;
import org.aksw.limes.core.ml.algorithm.ACoreMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.ActiveMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.LearningParameter;
import org.aksw.limes.core.ml.algorithm.MLAlgorithmFactory;
import org.aksw.limes.core.ml.algorithm.MLImplementationType;
import org.aksw.limes.core.ml.algorithm.MLResults;
import org.aksw.limes.core.ml.algorithm.SupervisedMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.UnsupervisedMLAlgorithm;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry point for running a machine-learning algorithm over a source and a
 * target cache and returning the mapping predicted with the learned result.
 *
 * <p>Three implementation types are handled: supervised batch learning from a
 * training data file, supervised active learning driven interactively through
 * standard input, and unsupervised learning.
 */
public class MLPipeline {
    public static final Logger logger = LoggerFactory.getLogger(MLPipeline.class);

    /**
     * Learns with the named algorithm and returns the mapping it predicts
     * between {@code source} and {@code target}.
     *
     * @param source               source cache handed to the algorithm's {@code init} and {@code predict}
     * @param target               target cache handed to the algorithm's {@code init} and {@code predict}
     * @param configuration        configuration set on the underlying algorithm after initialisation
     * @param mlAlgorithmName      name resolved to an algorithm class via {@link MLAlgorithmFactory#getAlgorithmType}
     * @param mlImplementationType which of the three learning modes to run; must not be {@code null}
     * @param learningParameters   parameters handed to the algorithm's {@code init}
     * @param trainingDataFile     training data location; only read for {@code SUPERVISED_BATCH}.
     *                             A name ending in {@code ".csv"} is read with {@link CSVMappingReader},
     *                             anything else with {@link RDFMappingReader}
     * @param pfmType              only used for {@code UNSUPERVISED}: when non-null, the evaluator created
     *                             for this type is passed to {@code learn} as the pseudo F-measure,
     *                             otherwise {@code null} is passed
     * @param maxIt                only used for {@code SUPERVISED_ACTIVE}: value passed to
     *                             {@code getNextExamples} in every round
     *                             (presumably the number of examples to request; confirm against ActiveMLAlgorithm)
     * @return the mapping returned by the algorithm's {@code predict} for the learned result
     * @throws UnsupportedMLImplementationException if {@code mlImplementationType} is none of the handled
     *                                              types, or if one of the invoked algorithm methods raises it
     */
    public static AMapping execute(ACache source, ACache target, Configuration configuration, String mlAlgorithmName, MLImplementationType mlImplementationType, List<LearningParameter> learningParameters, String trainingDataFile, EvaluatorType pfmType, int maxIt) throws UnsupportedMLImplementationException {
        Class<? extends ACoreMLAlgorithm> clazz = MLAlgorithmFactory.getAlgorithmType(mlAlgorithmName);
        switch (mlImplementationType) {
            case SUPERVISED_BATCH: {
                // The training data is only needed (and only read) in batch mode.
                AMappingReader mappingReader = trainingDataFile.endsWith(".csv") ? new CSVMappingReader(trainingDataFile) : new RDFMappingReader(trainingDataFile);
                AMapping trainingDataMap = mappingReader.read();
                SupervisedMLAlgorithm mls = new SupervisedMLAlgorithm(clazz);
                mls.init(learningParameters, source, target);
                mls.getMl().setConfiguration(configuration);
                MLResults mlm = mls.learn(trainingDataMap);
                logLearned(mlm);
                return mls.predict(source, target, mlm);
            }
            case SUPERVISED_ACTIVE: {
                ActiveMLAlgorithm mla = new ActiveMLAlgorithm(clazz);
                mla.init(learningParameters, source, target);
                mla.getMl().setConfiguration(configuration);
                MLResults mlm = learnInteractively(mla, maxIt);
                logLearned(mlm);
                return mla.predict(source, target, mlm);
            }
            case UNSUPERVISED: {
                UnsupervisedMLAlgorithm mlu = new UnsupervisedMLAlgorithm(clazz);
                mlu.init(learningParameters, source, target);
                mlu.getMl().setConfiguration(configuration);
                PseudoFMeasure pfm = null;
                if (pfmType != null) {
                    pfm = (PseudoFMeasure) EvaluatorFactory.create(pfmType);
                }
                MLResults mlm = mlu.learn(pfm);
                logLearned(mlm);
                return mlu.predict(source, target, mlm);
            }
        }
        throw new UnsupportedMLImplementationException(clazz.getName());
    }

    /**
     * Runs the interactive active-learning loop on standard input.
     *
     * <p>Each round prompts the user; the reply {@code q} ends the loop, any other
     * token starts a rating round. If standard input ends, the loop stops as if the
     * user had quit, so the result learned so far is still returned instead of
     * failing with a {@code NoSuchElementException}.
     *
     * @param mla   initialised active learner
     * @param maxIt value passed to {@code getNextExamples} in every round
     * @return the result of the last completed learning step
     * @throws UnsupportedMLImplementationException if raised by the learner
     */
    private static MLResults learnInteractively(ActiveMLAlgorithm mla, int maxIt) throws UnsupportedMLImplementationException {
        MLResults mlm = mla.activeLearn();
        // Intentionally not closed: closing the Scanner would also close System.in.
        Scanner scan = new Scanner(System.in);
        int round = 0;
        while (true) {
            logger.info("To rate the {}. set of examples, write 'r' and press enter.\nTo quit learning at this point and write out the mapping, write 'q' and press enter.\nFor rating examples, use numbers in [-1,+1].\n\t(-1 := strong negative example, +1 := strong positive example)", ++round);
            if (!scan.hasNext()) {
                logger.warn("Input ended, quitting learning at this point.");
                break;
            }
            String reply = scan.next();
            if (reply.trim().equals("q")) {
                break;
            }
            AMapping nextExamples = mla.getNextExamples(maxIt);
            if (!rateExamples(nextExamples, scan, round)) {
                // An incompletely rated batch is discarded rather than fed to the learner.
                logger.warn("Input ended before all examples were rated, quitting learning at this point.");
                break;
            }
            mlm = mla.activeLearn(nextExamples);
        }
        return mlm;
    }

    /**
     * Asks the user for a rating in [-1,+1] for every (source, target) pair in
     * {@code examples} and stores each accepted rating in the mapping in place.
     * Tokens that are not numbers, and numbers outside the range, are rejected
     * and the same pair is asked again.
     *
     * @param examples mapping whose pairs are to be rated; its values are overwritten
     * @param scan     scanner to read the ratings from
     * @param round    number of the current rating round, used in the prompt only
     * @return {@code true} if every pair was rated, {@code false} if the input
     *         ended before that
     */
    private static boolean rateExamples(AMapping examples, Scanner scan, int round) {
        int exemplar = 0;
        for (String s : examples.getMap().keySet()) {
            for (String t : examples.getMap().get(s).keySet()) {
                ++exemplar;
                boolean rated = false;
                while (!rated) {
                    logger.info("Exemplar #{}.{}: ({}, {})", round, exemplar, s, t);
                    if (scan.hasNextDouble()) {
                        double rating = scan.nextDouble();
                        if (rating >= -1.0 && rating <= 1.0) {
                            // Overwrites the value of an existing key, so the key set being iterated is unchanged.
                            examples.getMap().get(s).put(t, rating);
                            rated = true;
                        } else {
                            logger.error("Input number out of range [-1,+1], please try again...");
                        }
                    } else if (scan.hasNext()) {
                        logger.error("Input did not match floating point number, please try again...");
                        // Skip the offending token so the next attempt reads fresh input.
                        scan.next();
                    } else {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /**
     * Logs the full expression and threshold of the learned link specification.
     *
     * @param mlm learning result to report
     */
    private static void logLearned(MLResults mlm) {
        logger.info("Learned: {} with threshold: {}", mlm.getLinkSpecification().getFullExpression(), mlm.getLinkSpecification().getThreshold());
    }
}

