package edu.berkeley.compbio.ml.cluster.hierarchical;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.dsutils.collections.IndexedSymmetric2dBiMapWithDefault;
import com.davidsoergel.dsutils.collections.InsertionTrackingSet;
import com.davidsoergel.stats.DissimilarityMeasure;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.Clusterable;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.PointClusterFilter;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.log4j.Logger;

/* loaded from: input_file:lib/ml-0.921.jar:edu/berkeley/compbio/ml/cluster/hierarchical/OnlineAgglomerativeClustering.class */
/**
 * Agglomerative clustering that merges nodes while samples are still streaming in.
 *
 * <p>Each incoming sample becomes a singleton leaf cluster whose distances to the currently
 * active nodes are stored in {@link #theActiveNodeDistanceMatrix}. After every insertion, the
 * closest pair of active nodes is joined repeatedly for as long as the smallest stored distance
 * does not exceed {@link #threshold}. Once the input is exhausted, whatever nodes are still
 * active are handed to a {@code BatchAgglomerativeClusteringMethod}, which produces the final
 * tree returned by {@link #getTree()}.
 *
 * @param <T> the type of the samples being clustered
 */
public class OnlineAgglomerativeClustering<T extends Clusterable<T>> extends OnlineHierarchicalClusteringMethod<T> {
    private static final Logger logger = Logger.getLogger(OnlineAgglomerativeClustering.class);

    /**
     * Value handed to the distance matrix constructor as its default
     * (presumably what the matrix reports for pairs with no stored distance -- confirm against
     * {@code IndexedSymmetric2dBiMapWithDefault}).
     */
    public static final Float LONG_DISTANCE = Float.valueOf(Float.MAX_VALUE);

    /** Pairwise distances between the nodes that have not yet been merged into a parent. */
    protected final IndexedSymmetric2dBiMapWithDefault<HierarchicalCentroidCluster<T>, Float> theActiveNodeDistanceMatrix;

    /** Root of the finished tree; assigned at the end of training. */
    private HierarchicalCentroidCluster<T> theRoot;

    /**
     * Lookup consulted by {@link #bestClusterMove}.
     * NOTE(review): nothing in this class ever inserts into this map -- verify whether it is
     * meant to be populated during training.
     */
    private final Map<T, HierarchicalCentroidCluster<T>> sampleToLeafClusterMap;

    /** Source of unique ids for newly created leaf and composite nodes. */
    private final AtomicInteger idCount;

    // Not referenced inside this class; kept because it is package-visible.
    HierarchicalCentroidCluster<T> saveNode;

    /** Largest pairwise distance at which two active nodes are still merged online. */
    double threshold;

    /** Strategy that builds the composite node for a pair of nodes being merged. */
    Agglomerator<T> agglomerator;

    /**
     * Creates a clustering method with an empty distance matrix.
     *
     * @param dissimilarityMeasure measure used to compute distances between centroids
     * @param set                  forwarded unchanged to the superclass constructor
     * @param map                  forwarded unchanged to the superclass constructor
     * @param prohibitionModel     forwarded to the superclass; may be consulted by
     *                             {@link #bestClusterMove} to filter candidate clusters
     * @param set2                 forwarded unchanged to the superclass constructor
     * @param d                    online merge threshold; must not be {@code null}
     * @param agglomerator         strategy used to join two nodes
     */
    public OnlineAgglomerativeClustering(DissimilarityMeasure<T> dissimilarityMeasure, Set<String> set, Map<String, Set<String>> map, ProhibitionModel<T> prohibitionModel, Set<String> set2, Double d, Agglomerator agglomerator) {
        super(dissimilarityMeasure, set, map, prohibitionModel, set2);
        this.theActiveNodeDistanceMatrix = new IndexedSymmetric2dBiMapWithDefault<>(LONG_DISTANCE);
        this.sampleToLeafClusterMap = new HashMap<>();
        this.idCount = new AtomicInteger(0);
        this.threshold = d.doubleValue();
        this.agglomerator = agglomerator;
    }

    /**
     * @return the number of keys (active nodes) currently held by the distance matrix
     */
    public int numActiveClusters() {
        return this.theActiveNodeDistanceMatrix.numKeys();
    }

    /**
     * @return the number of key pairs currently held by the distance matrix
     */
    public int numDistances() {
        return this.theActiveNodeDistanceMatrix.numPairs();
    }

    /**
     * Finds the cluster closest to the given point.
     *
     * <p>If the point is found in the sample-to-leaf map, that leaf is returned with distance
     * zero. Otherwise every cluster from {@code getClusters()} that the prohibition filter
     * (if any) does not reject is compared by centroid distance and the nearest one wins.
     *
     * @param t the point to place
     * @return the move describing the best cluster and its distance
     * @throws NoGoodClusterException if no admissible cluster was found
     */
    @Override // edu.berkeley.compbio.ml.cluster.AbstractClusteringMethod
    public ClusterMove<T, HierarchicalCentroidCluster<T>> bestClusterMove(T t) throws NoGoodClusterException {
        ClusterMove<T, HierarchicalCentroidCluster<T>> clusterMove = new ClusterMove<>();
        clusterMove.bestDistance = Double.POSITIVE_INFINITY;
        HierarchicalCentroidCluster<T> knownLeaf = this.sampleToLeafClusterMap.get(t);
        if (knownLeaf != null) {
            clusterMove.bestCluster = knownLeaf;
            // An already-known sample coincides with its own leaf.
            clusterMove.bestDistance = 0.0;
            return clusterMove;
        }
        PointClusterFilter filter = this.prohibitionModel == null ? null : this.prohibitionModel.getFilter(t);
        for (HierarchicalCentroidCluster<T> candidate : getClusters()) {
            if (filter == null || !filter.isProhibited(candidate)) {
                double distanceFromTo = this.measure.distanceFromTo(t, candidate.getCentroid());
                if (distanceFromTo < clusterMove.bestDistance) {
                    clusterMove.bestCluster = candidate;
                    clusterMove.bestDistance = distanceFromTo;
                }
            }
        }
        if (clusterMove.bestCluster == null) {
            throw new NoGoodClusterException("No cluster found for point: " + t);
        }
        return clusterMove;
    }

    /**
     * Builds the cluster tree from the supplied samples.
     *
     * <p>The first two samples seed the distance matrix; every further sample is inserted via
     * {@code Parallel.forEach}, merging pairs within the threshold as it goes. The remaining
     * active nodes are then clustered in batch to obtain the root.
     *
     * @param clusterableIterator source of training samples; two elements are taken from it
     *                            unconditionally, so it must supply at least two
     *                            (otherwise whatever {@code next()} throws propagates)
     */
    @Override // edu.berkeley.compbio.ml.cluster.AbstractSupervisedOnlineClusteringMethod
    protected synchronized void trainWithKnownTrainingLabels(ClusterableIterator<T> clusterableIterator) {
        T firstSample = clusterableIterator.next();
        T secondSample = clusterableIterator.next();
        HierarchicalCentroidCluster<T> firstLeaf = newSingletonLeaf(firstSample);
        HierarchicalCentroidCluster<T> secondLeaf = newSingletonLeaf(secondSample);
        this.theActiveNodeDistanceMatrix.put(firstLeaf, secondLeaf, Float.valueOf((float) this.measure.distanceFromTo(firstLeaf.getPayload().getCentroid(), secondLeaf.getPayload().getCentroid())));

        Parallel.forEach(clusterableIterator, new Function<T, Void>() {
            @Override // com.davidsoergel.conja.Function
            public Void apply(T t) {
                addLeafAndMergeWithinThreshold(t);
                return null;
            }
        });

        logger.info("Batch clustering remaining " + this.theActiveNodeDistanceMatrix.numKeys() + " nodes");
        BatchAgglomerativeClusteringMethod batchAgglomerativeClusteringMethod = new BatchAgglomerativeClusteringMethod(this.measure, this.potentialTrainingBins, this.predictLabelSets, this.prohibitionModel, this.testLabels, this.theClusters, this.assignments, this.n, this.agglomerator, this.theActiveNodeDistanceMatrix);
        batchAgglomerativeClusteringMethod.train();
        this.theRoot = batchAgglomerativeClusteringMethod.getTree();
        normalizeClusterLabelProbabilities();
        doneLabellingClusters();
    }

    /**
     * @return the root assigned at the end of training, or {@code null} before that
     */
    @Override // edu.berkeley.compbio.ml.cluster.hierarchical.HierarchicalClusteringMethod
    public HierarchicalCentroidCluster<T> getTree() {
        return this.theRoot;
    }

    /** Wraps one sample in a fresh leaf node with an item count of one and closed labelling. */
    private HierarchicalCentroidCluster<T> newSingletonLeaf(T sample) {
        HierarchicalCentroidCluster<T> leaf = new HierarchicalCentroidCluster<T>(this.idCount.getAndIncrement(), sample);
        leaf.getMutableWeightedLabels().incrementItemCount(1);
        leaf.doneLabelling();
        return leaf;
    }

    /**
     * Inserts one sample as a new leaf into the distance matrix and then merges every pair
     * whose distance is within the threshold. Invoked once per sample by the parallel loop.
     */
    private void addLeafAndMergeWithinThreshold(T sample) {
        HierarchicalCentroidCluster<T> newLeaf = newSingletonLeaf(sample);

        // Distances are computed before taking the lock so the expensive part can overlap
        // between workers.
        // NOTE(review): this iterates the key set without holding the matrix monitor while
        // other workers may be mutating it -- verify that InsertionTrackingSet tolerates this.
        InsertionTrackingSet<HierarchicalCentroidCluster<T>> keysBeforeLock = this.theActiveNodeDistanceMatrix.getKeys();
        int newLeafHash = newLeaf.hashCode();
        Map<HierarchicalCentroidCluster<T>, Float> precomputed = new HashMap<>(getClusters().size());
        for (HierarchicalCentroidCluster<T> other : keysBeforeLock) {
            if (newLeafHash <= other.hashCode() && !newLeaf.equals(other)) {
                precomputed.put(other, Float.valueOf((float) this.measure.distanceFromTo(newLeaf.getPayload().getCentroid(), other.getPayload().getCentroid())));
            }
        }

        synchronized (this.theActiveNodeDistanceMatrix) {
            InsertionTrackingSet<HierarchicalCentroidCluster<T>> currentKeys = this.theActiveNodeDistanceMatrix.getKeys();
            if (currentKeys.size() > 0) {
                Set<HierarchicalCentroidCluster<T>> stillMissing = new HashSet<>(currentKeys);
                // Drop precomputed distances to nodes that were merged away in the meantime.
                precomputed.keySet().retainAll(stillMissing);
                for (Map.Entry<HierarchicalCentroidCluster<T>, Float> entry : precomputed.entrySet()) {
                    HierarchicalCentroidCluster<T> other = entry.getKey();
                    this.theActiveNodeDistanceMatrix.put(newLeaf, other, entry.getValue());
                    stillMissing.remove(other);
                }
                // Fill in the active nodes that had no precomputed distance.
                for (HierarchicalCentroidCluster<T> other : stillMissing) {
                    this.theActiveNodeDistanceMatrix.put(newLeaf, other, Float.valueOf((float) this.measure.distanceFromTo(newLeaf.getCentroid(), other.getCentroid())));
                }
                mergePairsWithinThreshold();
            }
        }
    }

    /**
     * Joins the closest pair of active nodes repeatedly while the smallest stored distance is
     * at most the threshold. The caller must hold the monitor of the distance matrix.
     */
    private void mergePairsWithinThreshold() {
        while (this.theActiveNodeDistanceMatrix.getSmallestValue().floatValue() <= this.threshold) {
            try {
                HierarchicalCentroidCluster<T> a = this.theActiveNodeDistanceMatrix.getKey1WithSmallestValue();
                HierarchicalCentroidCluster<T> b = this.theActiveNodeDistanceMatrix.getKey2WithSmallestValue();
                assert this.theActiveNodeDistanceMatrix.getKeys().contains(a);
                assert this.theActiveNodeDistanceMatrix.getKeys().contains(b);

                HierarchicalCentroidCluster<T> joined = this.agglomerator.joinNodes(this.idCount.getAndIncrement(), a, b, this.theActiveNodeDistanceMatrix);
                this.theActiveNodeDistanceMatrix.remove(a);
                this.theActiveNodeDistanceMatrix.remove(b);
                // The merged children must no longer be active once they have been removed.
                assert !this.theActiveNodeDistanceMatrix.getKeys().contains(a);
                assert !this.theActiveNodeDistanceMatrix.getKeys().contains(b);

                addCluster(joined);
            } catch (NoSuchElementException e) {
                // Retrying could spin forever if the matrix state did not change, so stop
                // here; the batch pass after the online phase handles the remaining nodes.
                logger.warn("Online merge stopped early; remaining nodes are left for batch clustering", e);
                break;
            }
        }
    }
}
