/*
 * Decompiled with CFR 0.152.
 */
package org.aksw.limes.core.ml.algorithm;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import org.aksw.limes.core.datastrutures.LogicOperator;
import org.aksw.limes.core.datastrutures.Tree;
import org.aksw.limes.core.evaluation.qualititativeMeasures.PseudoFMeasure;
import org.aksw.limes.core.exceptions.UnsupportedMLImplementationException;
import org.aksw.limes.core.io.cache.ACache;
import org.aksw.limes.core.io.ls.LinkSpecification;
import org.aksw.limes.core.io.mapping.AMapping;
import org.aksw.limes.core.io.mapping.MappingFactory;
import org.aksw.limes.core.measures.mapper.MappingOperations;
import org.aksw.limes.core.ml.algorithm.LearningParameter;
import org.aksw.limes.core.ml.algorithm.MLImplementationType;
import org.aksw.limes.core.ml.algorithm.MLResults;
import org.aksw.limes.core.ml.algorithm.classifier.ExtendedClassifier;
import org.aksw.limes.core.ml.algorithm.wombat.AWombat;
import org.aksw.limes.core.ml.algorithm.wombat.LinkEntropy;
import org.aksw.limes.core.ml.algorithm.wombat.RefinementNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Simple variant of the WOMBAT link-specification learner.
 *
 * <p>The learner builds a refinement tree. The root is a sentinel node with an empty metric
 * expression. The first level holds the initial classifiers returned by
 * {@code findInitialClassifiers()}. The most promising node is then repeatedly expanded by
 * combining it with every classifier through the {@code AND}, {@code OR} and {@code MINUS}
 * operators. Expansion stops when a fitness threshold, tree-size limit or iteration limit is
 * reached.
 *
 * <p>Supports supervised batch, unsupervised and supervised active learning (see
 * {@link #supports(MLImplementationType)}).
 */
public class WombatSimple
extends AWombat {
    private static final Logger logger = LoggerFactory.getLogger(WombatSimple.class);
    private static final String ALGORITHM_NAME = "Wombat Simple";
    /** Number of best refinement nodes consulted when selecting active-learning examples. */
    private static final int activeLearningRate = 3;
    /**
     * Score of the sentinel root node. {@link #findBestSolution()} also stops expanding when the
     * most promising node carries this score.
     */
    private static final double SENTINEL_SCORE = -Double.MAX_VALUE;
    private RefinementNode bestSolutionNode = null;
    private List<ExtendedClassifier> classifiers = null;
    private Tree<RefinementNode> refinementTreeRoot = null;

    protected WombatSimple() {
    }

    @Override
    protected String getName() {
        return ALGORITHM_NAME;
    }

    /**
     * Initialises the learner and clears all state left over from a previous run.
     *
     * @param lp learning parameters, forwarded to the superclass
     * @param sourceCache cache of source resources
     * @param targetCache cache of target resources
     */
    @Override
    protected void init(List<LearningParameter> lp, ACache sourceCache, ACache targetCache) {
        super.init(lp, sourceCache, targetCache);
        sourceUris = sourceCache.getAllUris();
        targetUris = targetCache.getAllUris();
        this.bestSolutionNode = null;
        this.classifiers = null;
        this.refinementTreeRoot = null;
    }

    /**
     * Supervised batch learning.
     *
     * <p>Restricts the source and target caches to the samples derived from the training data,
     * then learns.
     *
     * @param trainingData labelled example links
     * @return the learned results, or {@code null} if no non-empty specification was found
     */
    @Override
    protected MLResults learn(AMapping trainingData) {
        this.trainingData = trainingData;
        this.fillSampleSourceTargetCaches(trainingData);
        this.sourceCache = this.sourceSample;
        this.targetCache = this.targetSample;
        return this.learn();
    }

    /**
     * Searches for the best solution if none is cached yet, and converts it into results.
     *
     * @return the learned results, or {@code null} if the best node has an empty metric expression
     */
    private MLResults learn() {
        if (this.bestSolutionNode == null) {
            this.bestSolutionNode = this.findBestSolution();
        }
        return this.buildResultsFromNode(this.bestSolutionNode);
    }

    /**
     * Converts a refinement node into {@link MLResults}.
     *
     * <p>The threshold is read from the text after the last {@code '|'} of the node's metric
     * expression.
     *
     * @param node the node to convert
     * @return the results, or {@code null} when the node's metric expression is empty
     */
    private MLResults buildResultsFromNode(RefinementNode node) {
        String metricExpr = node.getMetricExpression();
        if (metricExpr.equals("")) {
            return null;
        }
        double threshold = Double.parseDouble(metricExpr.substring(metricExpr.lastIndexOf("|") + 1));
        AMapping mapping = node.getMapping();
        LinkSpecification ls = new LinkSpecification(metricExpr, threshold);
        return new MLResults(ls, mapping, node.getFMeasure(), null);
    }

    /**
     * Unsupervised learning driven by a pseudo F-measure.
     *
     * @param pfm the pseudo F-measure to use; a default {@link PseudoFMeasure} is created when
     *     {@code null}
     * @return the learned results, or {@code null} if no non-empty specification was found
     */
    @Override
    protected MLResults learn(PseudoFMeasure pfm) {
        this.pseudoFMeasure = pfm != null ? pfm : new PseudoFMeasure();
        this.isUnsupervised = true;
        return this.learn();
    }

    @Override
    protected boolean supports(MLImplementationType mlType) {
        return mlType == MLImplementationType.SUPERVISED_BATCH || mlType == MLImplementationType.UNSUPERVISED || mlType == MLImplementationType.SUPERVISED_ACTIVE;
    }

    /**
     * Selects the links on which the best refinement nodes disagree most, for labelling by an oracle.
     *
     * <p>The top {@link #activeLearningRate} nodes are considered. Candidate links are those
     * produced by at least one of these nodes but not by all of them. Each candidate is scored as
     * {@code (#nodes containing it) * (#nodes not containing it)}. The highest-scoring links are
     * returned first.
     *
     * @param size maximum number of links to return
     * @return a mapping of at most {@code size} links, each carrying its disagreement score
     * @throws IllegalStateException if no refinement tree has been built yet
     */
    @Override
    protected AMapping getNextExamples(int size) throws UnsupportedMLImplementationException {
        if (this.refinementTreeRoot == null) {
            throw new IllegalStateException("No refinement tree available; run a learning step before requesting examples");
        }
        List<RefinementNode> bestNodes = this.getBestKNodes(this.refinementTreeRoot, activeLearningRate);
        AMapping result = MappingFactory.createDefaultMapping();
        if (bestNodes.isEmpty()) {
            return result;
        }
        // Seed the intersection with the first node's mapping; seeding it with an empty mapping
        // would make the intersection (and thus the agreement set) always empty.
        AMapping intersectionMapping = bestNodes.get(0).getMapping();
        AMapping unionMapping = MappingFactory.createDefaultMapping();
        for (RefinementNode sn : bestNodes) {
            intersectionMapping = MappingOperations.intersection(intersectionMapping, sn.getMapping());
            unionMapping = MappingOperations.union(unionMapping, sn.getMapping());
        }
        // Links predicted by at least one of the best nodes but not by all of them.
        AMapping posEntropyMapping = MappingOperations.difference(unionMapping, intersectionMapping);
        TreeSet<LinkEntropy> linkEntropy = new TreeSet<LinkEntropy>();
        for (String s : posEntropyMapping.getMap().keySet()) {
            for (String t : posEntropyMapping.getMap().get(s).keySet()) {
                // Counters are per link; they must not carry over from the previous link.
                int entropyPos = 0;
                int entropyNeg = 0;
                for (RefinementNode sn : bestNodes) {
                    if (sn.getMapping().contains(s, t)) {
                        ++entropyPos;
                    } else {
                        ++entropyNeg;
                    }
                }
                linkEntropy.add(new LinkEntropy(s, t, entropyPos * entropyNeg));
            }
        }
        Iterator<LinkEntropy> itr = linkEntropy.descendingIterator();
        for (int i = 0; itr.hasNext() && i < size; ++i) {
            LinkEntropy l = itr.next();
            result.add(l.getSourceUri(), l.getTargetUri(), l.getEntropy());
        }
        return result;
    }

    /**
     * Starts active learning without labelled data, using a default pseudo F-measure.
     *
     * @return the learned results, or {@code null} if no non-empty specification was found
     */
    @Override
    protected MLResults activeLearn() {
        return this.learn(new PseudoFMeasure());
    }

    /**
     * Continues active learning after the oracle has labelled a batch of links.
     *
     * <p>Adds the labelled links to the training data, rescores any existing tree and searches
     * again.
     *
     * @param oracleMapping the links labelled by the oracle
     * @return the learned results, or {@code null} if the best node has an empty metric expression
     */
    @Override
    protected MLResults activeLearn(AMapping oracleMapping) throws UnsupportedMLImplementationException {
        this.isUnsupervised = false;
        // trainingData may still be null here: learn(PseudoFMeasure), used by activeLearn(),
        // never assigns it.
        this.trainingData = this.trainingData == null ? oracleMapping : MappingOperations.union(this.trainingData, oracleMapping);
        if (this.refinementTreeRoot != null) {
            // NOTE(review): findBestSolution() rebuilds the tree via createRefinementTreeRoot(), so
            // these rescored values appear to be discarded. Confirm whether this call is needed.
            this.updateScores(this.refinementTreeRoot);
        }
        this.bestSolutionNode = this.findBestSolution();
        return this.buildResultsFromNode(this.bestSolutionNode);
    }

    /**
     * Recomputes F-measures in the subtree rooted at {@code r} against the current training data.
     *
     * <p>Leaves are always rescored. An inner node is rescored (once) only if at least one of its
     * children has a non-negative score. Recursion continues into those children only.
     *
     * @param r root of the subtree to rescore
     */
    protected void updateScores(Tree<RefinementNode> r) {
        if (r.getchildren() == null || r.getchildren().size() == 0) {
            r.getValue().setfMeasure(this.fMeasure(r.getValue().getMapping()));
            return;
        }
        boolean rescored = false;
        for (Tree<RefinementNode> child : r.getchildren()) {
            // The negated comparison also skips NaN scores.
            if (!(child.getValue().getFMeasure() >= 0.0)) {
                continue;
            }
            if (!rescored) {
                r.getValue().setfMeasure(this.fMeasure(r.getValue().getMapping()));
                rescored = true;
            }
            this.updateScores(child);
        }
    }

    /**
     * Builds a fresh refinement tree and expands it until a stop criterion is met.
     *
     * <p>Expansion stops when any of the following holds:
     * <ul>
     *   <li>the most promising node reaches the maximum fitness threshold;</li>
     *   <li>the tree exceeds the maximum refinement tree size;</li>
     *   <li>the maximum number of iterations is used up;</li>
     *   <li>the most promising node carries the sentinel score.</li>
     * </ul>
     *
     * @return the best node found in the tree
     */
    public RefinementNode findBestSolution() {
        this.classifiers = this.findInitialClassifiers();
        this.createRefinementTreeRoot();
        Tree<RefinementNode> mostPromisingNode = this.getMostPromisingNode(this.refinementTreeRoot);
        logger.debug("Most promising node: {}", mostPromisingNode.getValue());
        for (int i = 1; mostPromisingNode.getValue().getFMeasure() < this.getMaxFitnessThreshold() && this.refinementTreeRoot.size() <= (long)this.getMaxRefinmentTreeSize() && i <= this.getMaxIterationNumber(); ++i) {
            this.expandNode(mostPromisingNode);
            mostPromisingNode = this.getMostPromisingNode(this.refinementTreeRoot);
            if (mostPromisingNode.getValue().getFMeasure() == SENTINEL_SCORE) {
                break;
            }
            logger.debug("Most promising node: {}", mostPromisingNode.getValue());
        }
        RefinementNode bestSolution = this.getBestNode(this.refinementTreeRoot).getValue();
        logger.debug("Overall Best Solution: {}", bestSolution);
        return bestSolution;
    }

    /**
     * Returns up to {@code k} nodes of the tree rooted at {@code r}, taken from the top of the
     * {@link #getSortedNodes} ordering.
     *
     * @param r root of the tree to search
     * @param k maximum number of nodes to return
     * @return the selected nodes in descending order
     */
    protected List<RefinementNode> getBestKNodes(Tree<RefinementNode> r, int k) {
        TreeSet<RefinementNode> sortedNodes = this.getSortedNodes(r, this.getOverAllPenaltyWeight(), new TreeSet<RefinementNode>());
        ArrayList<RefinementNode> resultList = new ArrayList<RefinementNode>();
        Iterator<RefinementNode> itr = sortedNodes.descendingIterator();
        for (int i = 0; itr.hasNext() && i < k; ++i) {
            resultList.add(itr.next());
        }
        return resultList;
    }

    /**
     * Collects {@code r} and every node reachable through children with a non-negative score into
     * {@code result}.
     *
     * @param r root of the subtree to collect
     * @param penaltyWeight currently unused; kept for interface compatibility
     * @param result the set to add nodes to; it is also returned
     * @return {@code result}, never {@code null}
     */
    protected TreeSet<RefinementNode> getSortedNodes(Tree<RefinementNode> r, double penaltyWeight, TreeSet<RefinementNode> result) {
        result.add(r.getValue());
        if (r.getchildren() == null) {
            return result;
        }
        // Visit every qualifying child, not just the first one.
        for (Tree<RefinementNode> child : r.getchildren()) {
            if (child.getValue().getFMeasure() >= 0.0) {
                this.getSortedNodes(child, penaltyWeight, result);
            }
        }
        return result;
    }

    /**
     * Adds one child to {@code node} for each classifier and each supported operator.
     *
     * <p>Each child combines the node with a classifier via {@code AND} (intersection), {@code OR}
     * (union) or {@code MINUS} (difference). Its metric expression has the form
     * {@code OP(parentExpr,classifierExpr)|0}.
     *
     * @param node the tree node to expand
     */
    private void expandNode(Tree<RefinementNode> node) {
        RefinementNode parent = node.getValue();
        for (ExtendedClassifier c : this.classifiers) {
            // Do not combine a node with the classifier it already consists of.
            if (parent.getMetricExpression().equals(c.getMetricExpression())) {
                continue;
            }
            for (LogicOperator op : LogicOperator.values()) {
                AMapping map;
                switch (op) {
                    case AND:
                        map = MappingOperations.intersection(parent.getMapping(), c.getMapping());
                        break;
                    case OR:
                        map = MappingOperations.union(parent.getMapping(), c.getMapping());
                        break;
                    case MINUS:
                        map = MappingOperations.difference(parent.getMapping(), c.getMapping());
                        break;
                    default:
                        // No mapping operation is implemented here for other operators. Skip them
                        // rather than pairing their expression with an unrelated mapping.
                        continue;
                }
                String metricExpr = op + "(" + parent.getMetricExpression() + "," + c.getMetricExpression() + ")|0";
                RefinementNode child = this.createNode(map, metricExpr);
                node.addChild(new Tree<RefinementNode>(child));
            }
        }
        if (this.isVerbose()) {
            this.refinementTreeRoot.print();
        }
    }

    /**
     * Replaces the refinement tree with a new one.
     *
     * <p>The new tree has a sentinel root (sentinel score, empty mapping, empty expression) and one
     * child per initial classifier.
     */
    protected void createRefinementTreeRoot() {
        RefinementNode initialNode = new RefinementNode(SENTINEL_SCORE, MappingFactory.createMapping(MappingFactory.MappingType.DEFAULT), "");
        this.refinementTreeRoot = new Tree<RefinementNode>(null, initialNode, null);
        for (ExtendedClassifier c : this.classifiers) {
            RefinementNode n = new RefinementNode(c.getfMeasure(), c.getMapping(), c.getMetricExpression());
            this.refinementTreeRoot.addChild(new Tree<RefinementNode>(this.refinementTreeRoot, n, null));
        }
        if (this.isVerbose()) {
            this.refinementTreeRoot.print();
        }
    }
}

