package edu.berkeley.compbio.jlibsvm.multi;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.dsutils.collections.UnorderedPair;
import com.davidsoergel.dsutils.collections.UnorderedPairIterator;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.SVM;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationProblem;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.binary.BooleanClassificationProblemImpl;
import edu.berkeley.compbio.jlibsvm.labelinverter.LabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.util.SubtractionMap;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:BOOT-INF/lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/multi/MultiClassificationSVM.class */
/**
 * Multi-class SVM that reduces a {@link MultiClassProblem} to a collection of binary problems,
 * each trained with the wrapped {@link BinaryClassificationSVM}.
 *
 * <p>Which binary machines are built is controlled by the parameter's {@code oneVsAllMode}
 * (one model per label, pitting the label against its inverted label) and {@code allVsAllMode}
 * (one model per unordered pair of labels). Binary machines are dispatched through conja's
 * {@link Parallel#forEach}; NOTE(review): assumed to run concurrently, which is why
 * {@link GridTrainingResult#update} is synchronized.
 *
 * @param <L> label type
 * @param <P> example type
 */
public class MultiClassificationSVM<L extends Comparable<L>, P> extends SVM<L, P, MultiClassProblem<L, P>> {
    private static final Logger logger = Logger.getLogger(MultiClassificationSVM.class);

    /** Binary SVM used to train every one-vs-all and one-vs-one sub-model; set once at construction. */
    private final BinaryClassificationSVM<L, P> binarySvm;

    /**
     * Accumulator for a grid search: keeps the cross-validation result with the highest
     * class-normalized sensitivity seen so far. {@link #update} is synchronized so it may be
     * called from several worker threads.
     */
    public class GridTrainingResult {
        /** Best result so far, or {@code null} if no result has beaten the initial sentinel. */
        SvmMultiClassCrossValidationResults<L, P> bestCrossValidationResults;
        /** Sensitivity of {@link #bestCrossValidationResults}; starts at the sentinel {@code -1.0f}. */
        float bestSensitivity;

        private GridTrainingResult() {
            this.bestCrossValidationResults = null;
            this.bestSensitivity = -1.0f;
        }

        /**
         * Records {@code results} if its class-normalized sensitivity strictly exceeds the current best.
         *
         * @param results cross-validation results for one grid point
         */
        synchronized void update(SvmMultiClassCrossValidationResults<L, P> results) {
            float sensitivity = results.classNormalizedSensitivity();
            // NaN never compares greater, so a point with an undefined sensitivity is never selected.
            if (sensitivity > this.bestSensitivity) {
                this.bestSensitivity = sensitivity;
                this.bestCrossValidationResults = results;
            }
        }
    }

    /**
     * @param binaryClassificationSVM the binary SVM used to train every underlying binary machine
     */
    public MultiClassificationSVM(BinaryClassificationSVM<L, P> binaryClassificationSVM) {
        this.binarySvm = binaryClassificationSVM;
    }

    /** @return {@code "multiclass "} followed by the wrapped binary SVM's type string */
    @Override
    public String getSvmType() {
        return "multiclass " + this.binarySvm.getSvmType();
    }

    /**
     * Cross-validates a single parameter setting on {@code multiClassProblem}.
     *
     * @param multiClassProblem     the problem to cross-validate on
     * @param immutableSvmParameter the parameter setting under evaluation; stored on the result's
     *                              {@code param} field
     * @return the results built from {@code discreteCrossValidation}
     */
    @Override
    public SvmMultiClassCrossValidationResults<L, P> performCrossValidation(@NotNull MultiClassProblem<L, P> multiClassProblem, @NotNull ImmutableSvmParameter<L, P> immutableSvmParameter) {
        SvmMultiClassCrossValidationResults<L, P> results = new SvmMultiClassCrossValidationResults<>(multiClassProblem, discreteCrossValidation(multiClassProblem, immutableSvmParameter));
        results.param = immutableSvmParameter;
        return results;
    }

    /**
     * Trains a multi-class model.
     *
     * <p>If {@code immutableSvmParameter} is an {@link ImmutableSvmParameterGrid} and
     * {@code gridsearchBinaryMachinesIndependently} is false, one grid point is chosen for the whole
     * multi-class model via {@link #trainGrid}. Otherwise the parameter is passed unchanged to
     * {@link #trainScaled}. NOTE(review): in the independent case the grid presumably gets resolved
     * per binary machine by the binary SVM; confirm.
     *
     * @param multiClassProblem     training data
     * @param immutableSvmParameter parameters, validated first via {@code validateParam}
     * @return the trained model
     */
    @Override
    public MultiClassModel<L, P> train(@NotNull MultiClassProblem<L, P> multiClassProblem, @NotNull ImmutableSvmParameter<L, P> immutableSvmParameter) {
        validateParam(immutableSvmParameter);
        boolean jointGridSearch = immutableSvmParameter instanceof ImmutableSvmParameterGrid
                && !immutableSvmParameter.gridsearchBinaryMachinesIndependently;
        return jointGridSearch
                ? trainGrid(multiClassProblem, (ImmutableSvmParameterGrid) immutableSvmParameter)
                : trainScaled(multiClassProblem, immutableSvmParameter);
    }

    /**
     * Cross-validates every grid point and keeps the one with the highest class-normalized
     * sensitivity. It then trains a final model on the full problem with that point and attaches
     * the winning cross-validation results to the model.
     *
     * @param multiClassProblem         training data
     * @param immutableSvmParameterGrid the parameter grid to search
     * @return the model trained with the selected grid point
     * @throws IllegalStateException if no grid point was selected, because the grid is empty or
     *                               no point scored above the {@code -1.0f} sentinel (e.g. all NaN)
     */
    public MultiClassModel<L, P> trainGrid(@NotNull final MultiClassProblem<L, P> multiClassProblem, @NotNull ImmutableSvmParameterGrid<L, P> immutableSvmParameterGrid) {
        final GridTrainingResult gridTrainingResult = new GridTrainingResult();
        Parallel.forEach(immutableSvmParameterGrid.getGridParams(), new Function<ImmutableSvmParameterPoint<L, P>, Void>() {
            @Override
            public Void apply(ImmutableSvmParameterPoint<L, P> immutableSvmParameterPoint) {
                gridTrainingResult.update(MultiClassificationSVM.this.performCrossValidation((MultiClassProblem) multiClassProblem, (ImmutableSvmParameter) immutableSvmParameterPoint));
                return null;
            }
        });
        SvmMultiClassCrossValidationResults<L, P> best = gridTrainingResult.bestCrossValidationResults;
        if (best == null) {
            // Without this check the dereference below fails with an uninformative NullPointerException.
            throw new IllegalStateException("Grid search selected no parameter point: the grid is empty or no point had a class-normalized sensitivity above -1 (e.g. all NaN)");
        }
        logger.info("Chose grid point: " + best.param);
        MultiClassModel<L, P> model = trainScaled(multiClassProblem, best.param);
        model.crossValidationResults = best;
        return model;
    }

    /**
     * Optionally scales the whole problem once, then trains.
     *
     * <p>If {@code scalingModelLearner} is set and {@code scaleBinaryMachinesIndependently} is false,
     * the problem is replaced by {@code getScaledCopy(scalingModelLearner)} before training.
     * Otherwise the problem is trained as given. NOTE(review): in the independent case scaling is
     * presumably handled per binary machine elsewhere; confirm.
     *
     * @param multiClassProblem     training data
     * @param immutableSvmParameter training parameters
     * @return the trained model
     */
    public MultiClassModel<L, P> trainScaled(@NotNull MultiClassProblem<L, P> multiClassProblem, @NotNull ImmutableSvmParameter<L, P> immutableSvmParameter) {
        boolean scaleWholeProblem = immutableSvmParameter.scalingModelLearner != null
                && !immutableSvmParameter.scaleBinaryMachinesIndependently;
        MultiClassProblem<L, P> problemToTrain = scaleWholeProblem
                ? multiClassProblem.getScaledCopy(immutableSvmParameter.scalingModelLearner)
                : multiClassProblem;
        return trainWithoutScaling(problemToTrain, immutableSvmParameter);
    }

    /**
     * Builds a {@link MultiClassModel} from the binary machines selected by the parameter's
     * one-vs-all and all-vs-all modes. The problem's scaling model is copied onto the result.
     * {@code prepareModelSvMaps()} is called on the model before it is returned.
     */
    private MultiClassModel<L, P> trainWithoutScaling(@NotNull final MultiClassProblem<L, P> multiClassProblem, @NotNull final ImmutableSvmParameter<L, P> immutableSvmParameter) {
        int size = multiClassProblem.getLabels().size();
        MultiClassModel<L, P> multiClassModel = new MultiClassModel<>(immutableSvmParameter, size);
        multiClassModel.setScalingModel(multiClassProblem.getScalingModel());
        Map<L, Set<P>> examplesByLabel = multiClassProblem.getExamplesByLabel();
        if (immutableSvmParameter.oneVsAllMode != MultiClassModel.OneVsAllMode.None) {
            trainOneVsAllModels(multiClassProblem, immutableSvmParameter, examplesByLabel, multiClassModel, size);
        }
        if (immutableSvmParameter.allVsAllMode != MultiClassModel.AllVsAllMode.None) {
            trainOneVsOneModels(multiClassProblem, immutableSvmParameter, examplesByLabel, multiClassModel, size);
        }
        multiClassModel.prepareModelSvMaps();
        return multiClassModel;
    }

    /**
     * Trains one binary machine per label (the label vs. its inverted label) and stores each via
     * {@code putOneVsAllModel}. Training uses {@code withProbabilityCopy()} of the parameters.
     */
    private void trainOneVsAllModels(final MultiClassProblem<L, P> multiClassProblem, final ImmutableSvmParameter<L, P> immutableSvmParameter, final Map<L, Set<P>> examplesByLabel, final MultiClassModel<L, P> multiClassModel, int labelCount) {
        final ImmutableSvmParameter<L, P> withProbabilityCopy = immutableSvmParameter.withProbabilityCopy();
        logger.info("Training one-vs-all classifiers for " + labelCount + " labels");
        final LabelInverter<L> labelInverter = multiClassProblem.getLabelInverter();
        Parallel.forEach(multiClassProblem.getLabels(), new Function<L, Void>() {
            @Override
            public Void apply(L label) {
                Comparable invertedLabel = (Comparable) labelInverter.invert(label);
                Set trueExamples = (Set) examplesByLabel.get(label);
                Collection candidates = multiClassProblem.getExamples().entrySet();
                int falseLimit = immutableSvmParameter.falseClassSVlimit;
                if (falseLimit != Integer.MAX_VALUE) {
                    // Random subsample of at most falseLimit + |trueExamples| entries.
                    ArrayList shuffled = new ArrayList(candidates);
                    Collections.shuffle(shuffled);
                    // Sum in long: falseLimit + size overflows int when falseLimit is near Integer.MAX_VALUE.
                    int keep = (int) Math.min((long) falseLimit + trueExamples.size(), (long) shuffled.size());
                    candidates = shuffled.subList(0, keep);
                }
                BinaryClassificationProblem binaryProblem = new BooleanClassificationProblemImpl(multiClassProblem.getLabelClass(), label, trueExamples, invertedLabel, new SubtractionMap(candidates, trueExamples, falseLimit).keySet(), multiClassProblem.getExampleIds());
                multiClassModel.putOneVsAllModel(label, MultiClassificationSVM.this.binarySvm.train(binaryProblem, withProbabilityCopy));
                return null;
            }
        });
    }

    /**
     * Trains one binary machine per unordered pair of labels, using each label's examples, and
     * stores each via {@code putOneVsOneModel}.
     */
    private void trainOneVsOneModels(final MultiClassProblem<L, P> multiClassProblem, final ImmutableSvmParameter<L, P> immutableSvmParameter, final Map<L, Set<P>> examplesByLabel, final MultiClassModel<L, P> multiClassModel, int labelCount) {
        // long arithmetic so the pair count cannot overflow for very large label sets.
        long pairCount = ((long) labelCount * (labelCount - 1)) / 2;
        logger.info("Training " + pairCount + " one-vs-one classifiers for " + labelCount + " labels");
        Parallel.forEach(new UnorderedPairIterator(multiClassProblem.getLabels(), multiClassProblem.getLabels()), new Function<UnorderedPair<L>, Void>() {
            @Override
            public Void apply(UnorderedPair<L> pair) {
                L first = pair.getKey1();
                L second = pair.getKey2();
                BinaryClassificationProblem binaryProblem = new BooleanClassificationProblemImpl(multiClassProblem.getLabelClass(), first, (Set) examplesByLabel.get(first), second, (Set) examplesByLabel.get(second), multiClassProblem.getExampleIds());
                multiClassModel.putOneVsOneModel(first, second, MultiClassificationSVM.this.binarySvm.train(binaryProblem, immutableSvmParameter));
                return null;
            }
        });
    }
}
