package edu.berkeley.compbio.jlibsvm.multi;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.DistributionException;
import com.davidsoergel.trees.htpn.HierarchicalTypedPropertyNode;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.ml.cluster.AbstractClusteringMethod;
import edu.berkeley.compbio.ml.cluster.BasicBatchCluster;
import edu.berkeley.compbio.ml.cluster.ClusterException;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.Clusterable;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ClusteringTestResults;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.PointClusterFilter;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import edu.berkeley.compbio.ml.cluster.SampleInitializedBatchClusteringMethod;
import edu.berkeley.compbio.ml.cluster.SupervisedClusteringMethod;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/multi/MultiClassificationSVMAdapter.class */
/**
 * Adapts jlibsvm's {@link MultiClassificationSVM} to the clustering framework.
 *
 * <p>Each potential training bin (label) becomes one {@link BasicBatchCluster}. Every training
 * sample is assigned to the cluster of its dominant label. Classifying a new point runs a
 * multiclass SVM prediction and reports the winning label's cluster as the best cluster.
 *
 * <p>Call order implied by the code:
 * <ol>
 *   <li>{@link #createClusters()}, which builds the label-to-cluster map;</li>
 *   <li>{@link #initializeWithSamples(ClusterableIterator)}, which looks samples up in that map;</li>
 *   <li>{@link #setBinarySvm(BinaryClassificationSVM)};</li>
 *   <li>{@link #train()};</li>
 *   <li>prediction via {@link #bestClusterMove} / {@link #test}.</li>
 * </ol>
 *
 * @param <T> the type of the points being classified
 */
public class MultiClassificationSVMAdapter<T extends Clusterable<T>> extends AbstractClusteringMethod<T, BasicBatchCluster<T>> implements SampleInitializedBatchClusteringMethod<T>, SupervisedClusteringMethod<T> {
    private static final Logger logger = Logger.getLogger(MultiClassificationSVMAdapter.class);

    /** SVM parameters handed to the multiclass trainer in {@link #train()}. */
    final ImmutableSvmParameter<BasicBatchCluster<T>, T> param;

    /** Training sample -> cluster of its dominant label. Written only inside synchronized {@link #add}. */
    final Map<T, BasicBatchCluster<T>> examples;

    /** Training sample -> sequential id. Written only inside synchronized {@link #add}. */
    final Map<T, Integer> exampleIds;

    /** Label -> cluster; populated by {@link #createClusters()}. */
    Map<String, BasicBatchCluster<T>> theClusterMap;

    /** Trained model; {@code null} until {@link #train()} has run. */
    private MultiClassModel<BasicBatchCluster<T>, T> model;

    /** Binary SVM used to build the multiclass SVM; must be set before {@link #train()}. */
    private BinaryClassificationSVM<BasicBatchCluster<T>, T> binarySvm;

    /** Number of samples added so far; also the source of the next example id. */
    final AtomicInteger trainingCount;

    /**
     * Creates an untrained adapter.
     *
     * <p>{@code set}, {@code map}, {@code prohibitionModel} and {@code set2} are forwarded unchanged
     * to the {@link AbstractClusteringMethod} constructor, with {@code null} as its first argument.
     * NOTE(review): the roles of {@code set}, {@code map} and {@code set2} are defined by the
     * superclass constructor (presumably training bins / label sets) — confirm there.
     *
     * @param immutableSvmParameter SVM parameters used by {@link #train()}
     */
    public MultiClassificationSVMAdapter(Set<String> set, Map<String, Set<String>> map, ProhibitionModel<T> prohibitionModel, Set<String> set2, @NotNull ImmutableSvmParameter<BasicBatchCluster<T>, T> immutableSvmParameter) {
        super(null, set, map, prohibitionModel, set2);
        this.examples = new HashMap<>();
        this.exampleIds = new HashMap<>();
        this.trainingCount = new AtomicInteger(0);
        this.param = immutableSvmParameter;
    }

    /** Sets the binary SVM that {@link #train()} wraps in a {@link MultiClassificationSVM}. */
    public void setBinarySvm(BinaryClassificationSVM<BasicBatchCluster<T>, T> binaryClassificationSVM) {
        this.binarySvm = binaryClassificationSVM;
    }

    /**
     * Registers every sample from the iterator as a training example via {@link #add}.
     *
     * <p>Samples are handed to {@code Parallel.forEach}. {@link #add} may therefore run on several
     * threads at once, which is why it is synchronized.
     */
    @Override
    public void initializeWithSamples(ClusterableIterator<T> clusterableIterator) {
        Parallel.forEach(clusterableIterator, new Function<T, Void>() {
            @Override
            public Void apply(@Nullable T t) {
                MultiClassificationSVMAdapter.this.add(t);
                return null;
            }
        });
        logger.info("Prepared " + this.trainingCount + " training samples");
    }

    /**
     * Adds one training sample and records it, with a fresh sequential id, in {@link #examples}
     * and {@link #exampleIds}.
     *
     * <p>The sample goes to the cluster of its dominant label, where the dominant label is
     * restricted to {@code potentialTrainingBins}.
     *
     * <p>Synchronized because {@link #initializeWithSamples} invokes this concurrently. The
     * backing maps are plain {@link HashMap}s and the target cluster is shared. In addition,
     * reading the counter and incrementing it in two separate steps let two threads receive the
     * same id; {@code getAndIncrement()} removes that window.
     *
     * @throws IllegalStateException if no cluster exists for the sample's dominant label, for
     *     example because {@link #createClusters()} has not been called yet
     */
    public synchronized void add(T t) {
        String label = (String) t.getImmutableWeightedLabels().getDominantKeyInSet(this.potentialTrainingBins);
        Map<String, BasicBatchCluster<T>> clusters = this.theClusterMap;
        BasicBatchCluster<T> cluster = clusters == null ? null : clusters.get(label);
        if (cluster == null) {
            throw new IllegalStateException("No cluster for training label " + label + "; was createClusters() called?");
        }
        cluster.add(t);
        this.examples.put(t, cluster);
        this.exampleIds.put(t, this.trainingCount.getAndIncrement());
    }

    /**
     * Creates one {@link BasicBatchCluster} per potential training bin and registers each via
     * {@code addCluster}.
     *
     * <p>Clusters are numbered consecutively from 0 in the iteration order of
     * {@code potentialTrainingBins}.
     */
    @Override
    public void createClusters() {
        this.theClusterMap = new HashMap<>(this.potentialTrainingBins.size());
        int nextId = 0;
        for (String label : this.potentialTrainingBins) {
            if (!this.theClusterMap.containsKey(label)) {
                BasicBatchCluster<T> cluster = new BasicBatchCluster<>(nextId++);
                this.theClusterMap.put(label, cluster);
                addCluster(cluster);
            }
        }
    }

    /**
     * Trains the multiclass SVM on the samples registered by {@link #initializeWithSamples}.
     *
     * <p>Afterwards it calls the superclass hooks that remove empty clusters, normalize cluster
     * label probabilities and mark labelling as done.
     *
     * @throws IllegalStateException if {@link #setBinarySvm} has not been called
     */
    // Raw types below mirror the original calls; the generic signatures of MultiClassificationSVM
    // and MultiClassProblemImpl are not visible here, so the unchecked warnings stay suppressed.
    @SuppressWarnings({"unchecked", "rawtypes"})
    @Override
    public void train() {
        if (this.binarySvm == null) {
            throw new IllegalStateException("setBinarySvm() must be called before train()");
        }
        MultiClassificationSVM multiClassificationSVM = new MultiClassificationSVM(this.binarySvm);
        MultiClassProblemImpl multiClassProblemImpl = new MultiClassProblemImpl(BasicBatchCluster.class, new BatchClusterLabelInverter(), this.examples, this.exampleIds, new NoopScalingModel());
        logger.debug("Performing multiclass training");
        this.model = multiClassificationSVM.train((MultiClassProblem) multiClassProblemImpl, (ImmutableSvmParameter) this.param);
        removeEmptyClusters();
        normalizeClusterLabelProbabilities();
        doneLabellingClusters();
    }

    /** Intentionally a no-op: this adapter writes no extra results into the property tree. */
    public void putResults(HierarchicalTypedPropertyNode<String, Serializable, ?> hierarchicalTypedPropertyNode) {
    }

    /**
     * Builds a {@link MultiClassModel} from the trained model plus the set of clusters that the
     * prohibition filter for {@code t} marks as prohibited.
     *
     * <p>If there is no prohibition model, or it returns no filter, the set is empty.
     * NOTE(review): presumably MultiClassModel excludes these labels from voting — confirm
     * against its constructor.
     */
    private MultiClassModel<BasicBatchCluster<T>, T> makeMultiClassModelWithProhibition(@Nullable T t) {
        Set<BasicBatchCluster<T>> prohibited = new HashSet<>();
        PointClusterFilter<T> filter = this.prohibitionModel == null ? null : this.prohibitionModel.getFilter(t);
        if (filter != null) {
            for (BasicBatchCluster<T> cluster : this.model.getLabels()) {
                if (filter.isProhibited(cluster)) {
                    prohibited.add(cluster);
                }
            }
        }
        return new MultiClassModel<>(this.model, prohibited);
    }

    /**
     * Runs the superclass evaluation and attaches the trained model's info to the results.
     * Requires {@link #train()} to have run.
     */
    @Override
    public synchronized ClusteringTestResults test(ClusterableIterator<T> clusterableIterator, DissimilarityMeasure<String> dissimilarityMeasure) throws DistributionException, ClusterException {
        ClusteringTestResults results = super.test(clusterableIterator, dissimilarityMeasure);
        results.setInfo(this.model.getInfo());
        return results;
    }

    /**
     * Classifies {@code t} with the trained multiclass model and packages the vote results as a
     * {@link ClusterMove}.
     *
     * <p>When a prohibition model is configured, the filtered model is used instead. The
     * one-vs-all probabilities are stored in the move's {@code bestDistance} /
     * {@code secondBestDistance} fields.
     *
     * @throws NoGoodClusterException if the model yields no best label
     */
    @Override
    public ClusterMove<T, BasicBatchCluster<T>> bestClusterMove(T t) throws NoGoodClusterException {
        MultiClassModel<BasicBatchCluster<T>, T> activeModel = this.model;
        if (this.prohibitionModel != null) {
            try {
                activeModel = makeMultiClassModelWithProhibition(t);
            } catch (NoSuchElementException ignored) {
                // Deliberate best-effort: fall back to the unrestricted model. NOTE(review):
                // presumably getFilter throws this when no filter applies to t — confirm.
            }
        }
        VotingResult<BasicBatchCluster<T>> vote = activeModel.predictLabelWithQuality(t);
        BasicBatchCluster<T> best = vote.getBestLabel();
        if (best == null) {
            throw new NoGoodClusterException();
        }
        ClusterMove<T, BasicBatchCluster<T>> clusterMove = new ClusterMove<>();
        clusterMove.bestCluster = best;
        clusterMove.voteProportion = vote.getBestVoteProportion();
        clusterMove.secondBestVoteProportion = vote.getSecondBestVoteProportion();
        clusterMove.bestDistance = vote.getBestOneVsAllProbability();
        clusterMove.secondBestDistance = vote.getSecondBestOneVsAllProbability();
        return clusterMove;
    }
}
