/*
 * Decompiled with CFR 0.152.
 */
package org.aksw.limes.core.evaluation.evaluator;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.aksw.limes.core.datastrutures.EvaluationRun;
import org.aksw.limes.core.datastrutures.GoldStandard;
import org.aksw.limes.core.datastrutures.TaskAlgorithm;
import org.aksw.limes.core.datastrutures.TaskData;
import org.aksw.limes.core.evaluation.evaluationDataLoader.EvaluationData;
import org.aksw.limes.core.evaluation.evaluator.EvaluatorType;
import org.aksw.limes.core.evaluation.evaluator.FoldData;
import org.aksw.limes.core.evaluation.evaluator.Summary;
import org.aksw.limes.core.evaluation.qualititativeMeasures.McNemarsTest;
import org.aksw.limes.core.evaluation.qualititativeMeasures.QualitativeMeasuresEvaluator;
import org.aksw.limes.core.evaluation.quantitativeMeasures.IQuantitativeMeasure;
import org.aksw.limes.core.evaluation.quantitativeMeasures.RunRecord;
import org.aksw.limes.core.exceptions.UnsupportedMLImplementationException;
import org.aksw.limes.core.io.cache.ACache;
import org.aksw.limes.core.io.cache.HybridCache;
import org.aksw.limes.core.io.cache.Instance;
import org.aksw.limes.core.io.config.Configuration;
import org.aksw.limes.core.io.mapping.AMapping;
import org.aksw.limes.core.io.mapping.MappingFactory;
import org.aksw.limes.core.measures.mapper.MappingOperations;
import org.aksw.limes.core.ml.algorithm.AMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.ActiveMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.LearningParameter;
import org.aksw.limes.core.ml.algorithm.MLImplementationType;
import org.aksw.limes.core.ml.algorithm.MLResults;
import org.aksw.limes.core.ml.algorithm.SupervisedMLAlgorithm;
import org.aksw.limes.core.ml.algorithm.UnsupervisedMLAlgorithm;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates LIMES machine-learning algorithms on {@link TaskData} datasets, either on the
 * dataset's own training/gold-standard split ({@link #evaluate}) or with k-fold
 * cross-validation ({@link #crossValidate}, {@link #crossValidateWithTuningAndMcNemarsTest}).
 *
 * <p>Every run is appended to {@link #runsList}, which is never cleared, so results accumulate
 * across calls on the same instance. All state lives in unsynchronized collections, so an
 * instance should not be shared between threads.
 */
public class Evaluator {
    static Logger logger = LoggerFactory.getLogger(Evaluator.class);

    /** All runs recorded by this evaluator so far, in execution order. */
    public List<EvaluationRun> runsList = new ArrayList<EvaluationRun>();

    /**
     * McNemar contingency counts: algorithm a -> algorithm b -> {successes, failures}, where the
     * entries accumulate {@code McNemarsTest.getSuccesses(a, b, gold)} and
     * {@code McNemarsTest.getSuccesses(b, a, gold)} over folds. Each unordered pair of algorithms
     * is stored once, under whichever order was recorded first. Reset by
     * {@link #crossValidateWithTuningAndMcNemarsTest}.
     */
    public Map<String, Map<String, int[]>> successesAndFailures = new HashMap<String, Map<String, int[]>>();

    /** McNemar p-values: dataset name -> algorithm a -> algorithm b -> p-value. */
    public Map<String, Map<String, Map<String, Double>>> statisticalTestResults = new HashMap<String, Map<String, Map<String, Double>>>();

    private QualitativeMeasuresEvaluator eval = new QualitativeMeasuresEvaluator();

    /**
     * Learns a model for every (algorithm, dataset) pair on the dataset's own training data and
     * evaluates its predictions against the dataset's gold standard.
     *
     * <p>An {@link UnsupportedMLImplementationException} is logged and aborts the remaining runs;
     * the runs completed so far are kept.
     *
     * @param TaskAlgorithms algorithms to run; each one's ML type selects the learning mode
     * @param datasets datasets providing caches, training data, pseudo-F-measure and gold standard
     * @param QlMeasures qualitative measures computed for every run
     * @param QnMeasures quantitative measures (not used by this method)
     * @return {@link #runsList}, with one run appended per evaluated (algorithm, dataset) pair
     */
    public List<EvaluationRun> evaluate(List<TaskAlgorithm> TaskAlgorithms, List<TaskData> datasets, Set<EvaluatorType> QlMeasures, Set<IQuantitativeMeasure> QnMeasures) {
        try {
            for (TaskAlgorithm tAlgorithm : TaskAlgorithms) {
                AMLAlgorithm algorithm = tAlgorithm.getMlAlgorithm();
                logger.info("Running algorithm: {}", algorithm.getName());
                for (TaskData dataset : datasets) {
                    logger.info("Used dataset: {}", dataset.dataName);
                    algorithm.init(null, dataset.source, dataset.target);
                    MLResults mlModel = this.learnModel(tAlgorithm.getMlType(), algorithm, dataset);
                    AMapping predictions = algorithm.predict(dataset.source, dataset.target, mlModel);
                    logger.info("Start the evaluation of the results");
                    Map<EvaluatorType, Double> evaluationResults = this.eval.evaluate(predictions, dataset.goldStandard, QlMeasures);
                    EvaluationRun er = new EvaluationRun(
                            stripWhitespace(algorithm.getName()),
                            stripWhitespace(tAlgorithm.getMlType().name()),
                            stripWhitespace(dataset.dataName),
                            evaluationResults);
                    this.runsList.add(er);
                }
            }
        } catch (UnsupportedMLImplementationException e) {
            logger.error("Evaluation aborted: unsupported ML implementation", e);
        }
        return this.runsList;
    }

    /**
     * Learns a model on {@code dataset} in the mode selected by {@code type}.
     *
     * @return the learned model, or {@code null} when {@code type} is none of the supported modes
     */
    private MLResults learnModel(MLImplementationType type, AMLAlgorithm algorithm, TaskData dataset) throws UnsupportedMLImplementationException {
        if (type == MLImplementationType.SUPERVISED_BATCH) {
            logger.info("Implementation type: {}", MLImplementationType.SUPERVISED_BATCH);
            return ((SupervisedMLAlgorithm) algorithm).learn(dataset.training);
        }
        if (type == MLImplementationType.SUPERVISED_ACTIVE) {
            logger.info("Implementation type: {}", MLImplementationType.SUPERVISED_ACTIVE);
            ActiveMLAlgorithm active = (ActiveMLAlgorithm) algorithm;
            active.getMl().setConfiguration(dataset.evalData.getConfigReader().getConfiguration());
            active.activeLearn();
            // Ask for as many examples as half the training mapping; the training mapping acts as the oracle.
            AMapping nextExamples = active.getNextExamples((int) Math.round(0.5 * (double) dataset.training.size()));
            AMapping oracleFeedback = this.oracleFeedback(nextExamples, dataset.training);
            return active.activeLearn(oracleFeedback);
        }
        if (type == MLImplementationType.UNSUPERVISED) {
            logger.info("Implementation type: {}", MLImplementationType.UNSUPERVISED);
            return ((UnsupervisedMLAlgorithm) algorithm).learn(dataset.pseudoFMeasure);
        }
        return null;
    }

    /** Removes every whitespace run from {@code s}, for use in run identifiers. */
    private static String stripWhitespace(String s) {
        return s.replaceAll("\\s+", "");
    }

    /**
     * Runs k-fold cross-validation of every algorithm on every dataset, using each algorithm's
     * fixed learning parameters.
     *
     * @param algorithms algorithms to evaluate
     * @param datasets datasets whose reference mappings are split into folds (no negative examples)
     * @param foldNumber number of folds k
     * @param qlMeasures qualitative measures computed for every run
     * @return {@link #runsList}, with one run appended per (dataset, fold, algorithm)
     */
    public List<EvaluationRun> crossValidate(List<TaskAlgorithm> algorithms, List<TaskData> datasets, int foldNumber, Set<EvaluatorType> qlMeasures) {
        for (TaskData dataset : datasets) {
            List<FoldData> folds = this.generateFolds(dataset.evalData, foldNumber, false);
            for (int k = 0; k < foldNumber; ++k) {
                FoldData testData = folds.get(k);
                FoldData trainData = this.fixCachesIfNecessary(this.getTrainingFold(folds, k, foldNumber), dataset);
                GoldStandard goldStandard = this.toGoldStandard(testData);
                for (TaskAlgorithm tAlgo : algorithms) {
                    this.runFold(tAlgo, tAlgo.getMlParameter(), dataset, k, trainData, testData, goldStandard, qlMeasures);
                }
            }
        }
        return this.runsList;
    }

    /**
     * Runs k-fold cross-validation on one dataset. Within each fold, algorithms that declare
     * parameter value ranges are first tuned by grid search on a split of the training data.
     * Afterwards, McNemar's test is run for every pair of algorithms.
     *
     * @param TaskAlgorithms algorithms to evaluate
     * @param dataset dataset whose reference mapping is split into folds (no negative examples)
     * @param qlMeasures qualitative measures computed for every run
     * @param foldNumber number of folds k
     * @return a summary over {@link #runsList} carrying {@link #statisticalTestResults}
     */
    public Summary crossValidateWithTuningAndMcNemarsTest(List<TaskAlgorithm> TaskAlgorithms, TaskData dataset, Set<EvaluatorType> qlMeasures, int foldNumber) {
        this.successesAndFailures = new HashMap<String, Map<String, int[]>>();
        List<FoldData> folds = this.generateFolds(dataset.evalData, foldNumber, false);
        for (int k = 0; k < foldNumber; ++k) {
            Map<String, AMapping> algoMappings = new HashMap<String, AMapping>();
            FoldData testData = folds.get(k);
            FoldData trainData = this.fixCachesIfNecessary(this.getTrainingFold(folds, k, foldNumber), dataset);
            GoldStandard goldStandard = this.toGoldStandard(testData);
            List<FoldData> tuneFolds = this.createTuneFolds(trainData, 5.0);
            GoldStandard tuneGold = this.toGoldStandard(tuneFolds.get(1));
            for (TaskAlgorithm tAlgo : TaskAlgorithms) {
                List<LearningParameter> params = tAlgo.getMlParameterValues() != null
                        ? this.tuneParameters(tAlgo, tuneFolds, tuneGold, dataset)
                        : tAlgo.getMlParameter();
                AMapping prediction = this.runFold(tAlgo, params, dataset, k, trainData, testData, goldStandard, qlMeasures);
                if (prediction != null) {
                    algoMappings.put(tAlgo.getName(), prediction);
                }
            }
            this.updateSuccessesAndFailures(algoMappings, testData);
        }
        for (Map.Entry<String, Map<String, int[]>> row : this.successesAndFailures.entrySet()) {
            for (Map.Entry<String, int[]> cell : row.getValue().entrySet()) {
                double pValue = McNemarsTest.calculate(cell.getValue());
                this.addToMapMapMap(this.statisticalTestResults, dataset.dataName, row.getKey(), cell.getKey(), pValue);
            }
        }
        logger.info("McNemar's test results: {}", this.statisticalTestResults);
        Summary summary = new Summary(this.runsList, foldNumber);
        summary.setStatisticalTestResults(this.statisticalTestResults);
        return summary;
    }

    /**
     * Grid-searches the parameter combinations of {@code tAlgo}. Each candidate is trained on
     * tune fold 0 and scored by F-measure on tune fold 1.
     *
     * @return the first combination with the strictly highest positive F-measure, or {@code null}
     *     when none scores above 0
     */
    private List<LearningParameter> tuneParameters(TaskAlgorithm tAlgo, List<FoldData> tuneFolds, GoldStandard tuneGold, TaskData dataset) {
        AMLAlgorithm algorithm = tAlgo.getMlAlgorithm();
        FoldData tuneTrain = tuneFolds.get(0);
        FoldData tuneTest = tuneFolds.get(1);
        List<LearningParameter> best = null;
        double bestFM = 0.0;
        for (List<LearningParameter> candidate : this.createParameterGrid(tAlgo.getMlParameterValues())) {
            MLResults tuneModel = this.trainModel(algorithm, candidate, tuneTrain.map, dataset.evalData.getConfigReader().read(), tuneTrain.sourceCache, tuneTrain.targetCache);
            if (tuneModel == null) {
                continue; // training failed for this combination; it cannot be the best one
            }
            AMapping tunePrediction = algorithm.predict(tuneTest.sourceCache, tuneTest.targetCache, tuneModel);
            Set<EvaluatorType> fMeasureOnly = ImmutableSet.of(EvaluatorType.F_MEASURE);
            double current = this.eval.evaluate(tunePrediction, tuneGold, fMeasureOnly).get(EvaluatorType.F_MEASURE);
            if (current > bestFM) {
                bestFM = current;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Trains {@code tAlgo} on {@code trainData} and predicts on {@code testData}. The resulting
     * run, with wall-clock time for training plus prediction, is recorded and displayed.
     *
     * @return the test-fold prediction, or {@code null} when no model could be learned (the run
     *     is then logged and skipped)
     */
    private AMapping runFold(TaskAlgorithm tAlgo, List<LearningParameter> params, TaskData dataset, int fold,
            FoldData trainData, FoldData testData, GoldStandard goldStandard, Set<EvaluatorType> qlMeasures) {
        AMLAlgorithm algorithm = tAlgo.getMlAlgorithm();
        long begin = System.currentTimeMillis();
        MLResults model = this.trainModel(algorithm, params, trainData.map, dataset.evalData.getConfigReader().read(), trainData.sourceCache, trainData.targetCache);
        if (model == null) {
            logger.error("No model learned for {} in fold {}; skipping this run", tAlgo.getName(), fold);
            return null;
        }
        AMapping prediction = algorithm.predict(testData.sourceCache, testData.targetCache, model);
        double runTime = (double) (System.currentTimeMillis() - begin) / 1000.0;
        EvaluationRun er = new EvaluationRun(tAlgo.getName(), tAlgo.getMlType().toString(), dataset.dataName, this.eval.evaluate(prediction, goldStandard, qlMeasures), fold, model.getLinkSpecification());
        er.setQuanititativeRecord(new RunRecord(fold, runTime, 0.0, model.getLinkSpecification().size()));
        er.display();
        this.runsList.add(er);
        return prediction;
    }

    /** Builds a gold standard from a fold's mapping and the URIs of its source and target caches. */
    private GoldStandard toGoldStandard(FoldData fold) {
        return new GoldStandard(fold.map, fold.sourceCache.getAllUris(), fold.targetCache.getAllUris());
    }

    /**
     * Adds this fold's McNemar counts for every unordered pair of algorithms in
     * {@code algoMappings} to {@link #successesAndFailures}. A pair is stored under (a, b) the
     * first time it is seen, and later folds accumulate into that same entry.
     */
    private void updateSuccessesAndFailures(Map<String, AMapping> algoMappings, FoldData testData) {
        for (String a : algoMappings.keySet()) {
            for (String b : algoMappings.keySet()) {
                if (a.equals(b) || this.getPairCounts(b, a) != null) {
                    continue; // same algorithm, or the pair is already tracked in the other order
                }
                int successes = McNemarsTest.getSuccesses(algoMappings.get(a), algoMappings.get(b), testData.map);
                int failures = McNemarsTest.getSuccesses(algoMappings.get(b), algoMappings.get(a), testData.map);
                int[] counts = this.getPairCounts(a, b);
                if (counts == null) {
                    this.addToMapMap(this.successesAndFailures, a, b, new int[]{successes, failures});
                } else {
                    counts[0] += successes;
                    counts[1] += failures;
                }
            }
        }
    }

    /** Returns the recorded {successes, failures} for the ordered pair (a, b), or {@code null}. */
    private int[] getPairCounts(String a, String b) {
        Map<String, int[]> row = this.successesAndFailures.get(a);
        return row == null ? null : row.get(b);
    }

    /** Puts {@code item} at {@code mapmap[key][subMapKey]}, creating the inner map if absent. */
    private <T, S, U> void addToMapMap(Map<T, Map<S, U>> mapmap, T key, S subMapKey, U item) {
        Map<S, U> subMap = mapmap.get(key);
        if (subMap == null) {
            subMap = new HashMap<S, U>();
            mapmap.put(key, subMap);
        }
        subMap.put(subMapKey, item);
    }

    /** Puts {@code item} at {@code mapmapmap[key][subMapKey][subsubMapKey]}, creating inner maps as needed. */
    private <V, T, S, U> void addToMapMapMap(Map<V, Map<T, Map<S, U>>> mapmapmap, V key, T subMapKey, S subsubMapKey, U item) {
        Map<T, Map<S, U>> subMap = mapmapmap.get(key);
        if (subMap == null) {
            subMap = new HashMap<T, Map<S, U>>();
            mapmapmap.put(key, subMap);
        }
        this.addToMapMap(subMap, subMapKey, subsubMapKey, item);
    }

    /**
     * Initializes {@code algorithm} with {@code params} and the training caches, sets the
     * configuration, and learns from {@code trainingData}.
     *
     * @return the learned model, or {@code null} when the algorithm is neither supervised nor
     *     active, or when learning throws {@link UnsupportedMLImplementationException} (logged)
     */
    private MLResults trainModel(AMLAlgorithm algorithm, List<LearningParameter> params, AMapping trainingData, Configuration config, ACache trainSourceCache, ACache trainTargetCache) {
        algorithm.init(params, trainSourceCache, trainTargetCache);
        algorithm.getMl().setConfiguration(config);
        try {
            if (algorithm instanceof SupervisedMLAlgorithm) {
                return algorithm.asSupervised().learn(trainingData);
            }
            if (algorithm instanceof ActiveMLAlgorithm) {
                return algorithm.asActive().activeLearn(trainingData);
            }
            logger.warn("Cannot train {} from labelled data: it is neither supervised nor active", algorithm.getName());
        } catch (UnsupportedMLImplementationException e) {
            logger.error("Training " + algorithm.getName() + " failed", e);
        }
        return null;
    }

    /**
     * Splits a training fold into two tuning folds. Fold 0 receives source entries until its
     * size reaches {@code ceil(trainData.map.size() / factor)}; fold 1 receives the rest.
     *
     * @param factor divisor controlling the size of tuning fold 0; must be positive
     * @throws IllegalArgumentException if {@code factor} is not positive
     */
    private List<FoldData> createTuneFolds(FoldData trainData, double factor) {
        if (!(factor > 0.0)) {
            throw new IllegalArgumentException("factor must be positive, was " + factor);
        }
        AMapping tuneTraining = MappingFactory.createDefaultMapping();
        AMapping tuneTest = MappingFactory.createDefaultMapping();
        int tuneTrainingSize = (int) Math.ceil((double) trainData.map.size() / factor);
        for (Map.Entry<String, HashMap<String, Double>> entry : trainData.map.getMap().entrySet()) {
            AMapping destination = tuneTraining.size() < tuneTrainingSize ? tuneTraining : tuneTest;
            destination.add(entry.getKey(), entry.getValue());
        }
        List<AMapping> mappings = new ArrayList<AMapping>();
        mappings.add(tuneTraining);
        mappings.add(tuneTest);
        return this.createFoldDataFromCaches(mappings, trainData.sourceCache, trainData.targetCache);
    }

    /**
     * Expands per-parameter candidate values into every combination (the cartesian product).
     *
     * @param parameters each learning parameter mapped to the values to try for it
     * @return one list of concrete {@link LearningParameter}s per combination
     */
    public Set<List<LearningParameter>> createParameterGrid(Map<LearningParameter, List<Object>> parameters) {
        List<Set<LearningParameter>> grid = new ArrayList<Set<LearningParameter>>();
        for (Map.Entry<LearningParameter, List<Object>> entry : parameters.entrySet()) {
            Set<LearningParameter> possibilities = new HashSet<LearningParameter>();
            for (Object value : entry.getValue()) {
                possibilities.add(new LearningParameter(entry.getKey().getName(), value));
            }
            grid.add(possibilities);
        }
        return Sets.cartesianProduct(grid);
    }

    /**
     * Merges all folds except fold {@code k} into one training fold (mapping and caches).
     * NOTE(review): assumes {@code new FoldData()} starts with a mapping and caches that
     * {@code MappingOperations.union} and {@link #cacheUnion} accept; confirm in FoldData.
     */
    private FoldData getTrainingFold(List<FoldData> folds, int k, int foldNumber) {
        FoldData trainData = new FoldData();
        for (int i = 0; i < foldNumber; ++i) {
            if (i == k) {
                continue;
            }
            FoldData fold = folds.get(i);
            trainData.map = MappingOperations.union(trainData.map, fold.map);
            trainData.sourceCache = this.cacheUnion(trainData.sourceCache, fold.sourceCache);
            trainData.targetCache = this.cacheUnion(trainData.targetCache, fold.targetCache);
        }
        return trainData;
    }

    /**
     * Ensures every source and target URI linked in the training mapping has its instance in the
     * training caches. Missing instances are copied from the dataset's full caches, and URIs the
     * dataset cannot resolve are skipped. Mutates and returns {@code trainData}.
     */
    private FoldData fixCachesIfNecessary(FoldData trainData, TaskData dataset) {
        for (Map.Entry<String, HashMap<String, Double>> links : trainData.map.getMap().entrySet()) {
            for (String t : links.getValue().keySet()) {
                addIfMissing(trainData.targetCache, dataset.target, t);
            }
            addIfMissing(trainData.sourceCache, dataset.source, links.getKey());
        }
        return trainData;
    }

    /** Copies the instance for {@code uri} from {@code origin} into {@code cache} if it is absent there. */
    private static void addIfMissing(ACache cache, ACache origin, String uri) {
        if (cache.containsUri(uri)) {
            return;
        }
        Instance instance = origin.getInstance(uri);
        if (instance == null) {
            logger.warn("No instance found for {}; it cannot be added to the training cache", uri);
            return;
        }
        cache.addInstance(instance);
    }

    /** Returns a new cache holding the instances of {@code a} followed by those of {@code b}. */
    public ACache cacheUnion(ACache a, ACache b) {
        HybridCache result = new HybridCache();
        for (Instance i : a.getAllInstances()) {
            result.addInstance(i);
        }
        for (Instance i : b.getAllInstances()) {
            result.addInstance(i);
        }
        return result;
    }

    /**
     * Splits the reference mapping of {@code data} into {@code foldNumber} folds. Each fold is
     * paired with source/target caches restricted to the URIs it links. Reference links whose
     * source or target instance is missing from the caches are dropped first.
     */
    public List<FoldData> generateFolds(EvaluationData data, int foldNumber, boolean withNegativeExamples) {
        ACache source = data.getSourceCache();
        ACache target = data.getTargetCache();
        AMapping refMap = this.removeLinksWithNoInstances(data.getReferenceMapping(), source, target);
        List<AMapping> foldMaps = this.generateMappingFolds(refMap, source, target, foldNumber, withNegativeExamples);
        return this.createFoldDataFromCaches(foldMaps, source, target);
    }

    /**
     * Wraps each fold mapping in a {@link FoldData} whose caches hold only the instances linked by
     * that fold. Targets of a source URI absent from {@code source} are not copied either.
     */
    private List<FoldData> createFoldDataFromCaches(List<AMapping> foldMaps, ACache source, ACache target) {
        List<FoldData> folds = new ArrayList<FoldData>();
        for (AMapping foldMap : foldMaps) {
            HybridCache sourceFoldCache = new HybridCache();
            HybridCache targetFoldCache = new HybridCache();
            for (Map.Entry<String, HashMap<String, Double>> links : foldMap.getMap().entrySet()) {
                if (!source.containsUri(links.getKey())) {
                    continue;
                }
                sourceFoldCache.addInstance(source.getInstance(links.getKey()));
                for (String t : links.getValue().keySet()) {
                    if (target.containsUri(t)) {
                        targetFoldCache.addInstance(target.getInstance(t));
                    }
                }
            }
            folds.add(new FoldData(foldMap, sourceFoldCache, targetFoldCache));
        }
        return folds;
    }

    /**
     * Randomly partitions {@code refMap} into {@code foldNumber} fold mappings.
     *
     * <p>First pass: each fold draws {@code refMap.size() / foldNumber} random positions. In every
     * fold but the last, only even positions are taken as positive links (confidence 1.0); odd
     * positions become one random non-matching link (confidence 0.0) when
     * {@code withNegativeExamples} is set and are otherwise left for the second pass. The last fold
     * takes all drawn positions as positives. Second pass: the remaining entries are dealt
     * round-robin over the folds. Odd-numbered entries bound for non-last folds become negatives
     * only when {@code withNegativeExamples} is set; all others are added as positives.
     *
     * @throws IllegalArgumentException if {@code foldNumber} is not positive
     */
    public List<AMapping> generateMappingFolds(AMapping refMap, ACache source, ACache target, int foldNumber, boolean withNegativeExamples) {
        if (foldNumber <= 0) {
            throw new IllegalArgumentException("foldNumber must be positive, was " + foldNumber);
        }
        Random rand = new Random();
        List<AMapping> foldMaps = new ArrayList<AMapping>();
        int mapSize = refMap.getMap().keySet().size();
        int foldSize = mapSize / foldNumber;
        // Pool of all reference target URIs; negative examples are sampled from it.
        List<String> values = new ArrayList<String>();
        for (HashMap<String, Double> targets : refMap.getMap().values()) {
            values.addAll(targets.keySet());
        }
        for (int foldIndex = 0; foldIndex < foldNumber; ++foldIndex) {
            boolean lastFold = foldIndex == foldNumber - 1;
            Set<Integer> index = new HashSet<Integer>();
            while (index.size() < foldSize) {
                index.add(rand.nextInt(mapSize));
            }
            AMapping foldMap = MappingFactory.createDefaultMapping();
            int count = 0;
            for (String key : refMap.getMap().keySet()) {
                if (index.contains(count)) {
                    if (lastFold || count % 2 == 0) {
                        foldMap.getMap().put(key, positiveLinks(refMap, key));
                    } else if (withNegativeExamples) {
                        foldMap.getMap().put(key, negativeLink(source, target, values, rand, refMap, key));
                    }
                }
                ++count;
            }
            foldMaps.add(foldMap);
            // Drop the links just assigned so that later folds draw from what is left.
            refMap = this.removeSubMap(refMap, foldMap);
        }
        int i = 0;
        int position = 0;
        for (String key : refMap.getMap().keySet()) {
            boolean lastFold = i == foldNumber - 1;
            if (withNegativeExamples && !lastFold && position % 2 != 0) {
                foldMaps.get(i).add(key, negativeLink(source, target, values, rand, refMap, key));
            } else {
                foldMaps.get(i).add(key, positiveLinks(refMap, key));
            }
            ++position;
            i = (i + 1) % foldNumber;
        }
        return foldMaps;
    }

    /** Returns every target of {@code key} in {@code refMap}, each with confidence 1.0. */
    private static HashMap<String, Double> positiveLinks(AMapping refMap, String key) {
        HashMap<String, Double> links = new HashMap<String, Double>();
        for (String t : refMap.getMap().get(key).keySet()) {
            links.put(t, 1.0);
        }
        return links;
    }

    /** Returns one random target not linked to {@code key}, with confidence 0.0. */
    private static HashMap<String, Double> negativeLink(ACache source, ACache target, List<String> values, Random rand, AMapping refMap, String key) {
        HashMap<String, Double> link = new HashMap<String, Double>();
        link.put(getRandomTargetInstance(source, target, values, rand, refMap.getMap(), key, -1), 0.0);
        return link;
    }

    /**
     * Draws random entries of {@code values}, without repetition, until it finds one that is not
     * linked to {@code sourceInstance} in {@code refMap} and has an instance in {@code target}.
     *
     * @param source unused; kept for signature compatibility
     * @param previousRandom index into {@code values} to exclude (e.g. an already rejected draw),
     *     or a negative number to exclude nothing
     * @return the chosen target URI
     * @throws IllegalArgumentException if {@code values} is empty
     * @throws IllegalStateException if no entry of {@code values} qualifies
     */
    public static String getRandomTargetInstance(ACache source, ACache target, List<String> values, Random random, HashMap<String, HashMap<String, Double>> refMap, String sourceInstance, int previousRandom) {
        int candidates = values.size();
        if (candidates == 0) {
            throw new IllegalArgumentException("No candidate target URIs to sample a negative example for " + sourceInstance);
        }
        HashMap<String, Double> linkedTargets = refMap.get(sourceInstance);
        Set<Integer> rejected = new HashSet<Integer>();
        if (previousRandom >= 0 && previousRandom < candidates) {
            rejected.add(previousRandom);
        }
        // Bounded search: each index is examined at most once, so this always terminates.
        while (rejected.size() < candidates) {
            int randomInt = random.nextInt(candidates);
            if (!rejected.add(randomInt)) {
                continue; // already examined
            }
            String tmpTarget = values.get(randomInt);
            boolean unlinked = linkedTargets == null || linkedTargets.get(tmpTarget) == null;
            if (unlinked && target.getInstance(tmpTarget) != null) {
                return tmpTarget;
            }
        }
        throw new IllegalStateException("No target URI is both unlinked to " + sourceInstance + " and present in the target cache");
    }

    /**
     * Returns a new mapping with the links of {@code mainMap} that are not contained in
     * {@code subMap}, each keeping its original confidence.
     */
    public AMapping removeSubMap(AMapping mainMap, AMapping subMap) {
        AMapping result = MappingFactory.createDefaultMapping();
        for (Map.Entry<String, HashMap<String, Double>> links : mainMap.getMap().entrySet()) {
            String sourceUri = links.getKey();
            for (Map.Entry<String, Double> link : links.getValue().entrySet()) {
                if (!subMap.contains(sourceUri, link.getKey())) {
                    result.add(sourceUri, link.getKey(), link.getValue());
                }
            }
        }
        return result;
    }

    /** Returns a copy of {@code map} keeping only links whose source and target URIs are in the caches. */
    private AMapping removeLinksWithNoInstances(AMapping map, ACache source, ACache target) {
        AMapping result = MappingFactory.createDefaultMapping();
        for (Map.Entry<String, HashMap<String, Double>> links : map.getMap().entrySet()) {
            String s = links.getKey();
            if (!source.containsUri(s)) {
                continue;
            }
            for (Map.Entry<String, Double> link : links.getValue().entrySet()) {
                if (target.containsUri(link.getKey())) {
                    result.add(s, link.getKey(), link.getValue());
                }
            }
        }
        return result;
    }

    /**
     * Simulates an oracle: labels each link of {@code predictionMapping} with 1.0 if
     * {@code referenceMapping} contains it, and with 0.0 otherwise.
     */
    private AMapping oracleFeedback(AMapping predictionMapping, AMapping referenceMapping) {
        AMapping result = MappingFactory.createDefaultMapping();
        for (Map.Entry<String, HashMap<String, Double>> links : predictionMapping.getMap().entrySet()) {
            String s = links.getKey();
            for (String t : links.getValue().keySet()) {
                result.add(s, t, referenceMapping.contains(s, t) ? 1.0 : 0.0);
            }
        }
        return result;
    }
}

