package edu.berkeley.compbio.jlibsvm.binary;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.SVM;
import edu.berkeley.compbio.jlibsvm.SvmException;
import java.lang.Comparable;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/binary/BinaryClassificationSVM.class */
/**
 * Base class for SVMs that solve a two-class {@link BinaryClassificationProblem}.
 *
 * <p>{@link #train} dispatches on the kind of parameters it is given:
 * <ul>
 *   <li>an {@link ImmutableSvmParameterGrid}: each grid point is cross-validated (via
 *       {@code Parallel.forEach}), the point with the highest class-normalized sensitivity is
 *       chosen, and a final model is trained with it;</li>
 *   <li>a parameter point with {@code probability} set: the point is cross-validated (a failing
 *       cross-validation is logged at debug level and ignored), then a model is trained and the
 *       CV results, if any, are attached to it;</li>
 *   <li>any other parameter point: a model is trained directly.</li>
 * </ul>
 * Subclasses supply the actual optimizer through {@link #trainOne}, which receives separate C
 * values for the true and the false label.
 *
 * @param <L> label type
 * @param <P> example type of the problem (presumably the feature-vector type — TODO confirm)
 */
public abstract class BinaryClassificationSVM<L extends Comparable, P> extends SVM<L, P, BinaryClassificationProblem<L, P>> {
    private static final Logger logger = Logger.getLogger(BinaryClassificationSVM.class);

    /**
     * Accumulates the best grid point seen so far while grid points are cross-validated
     * concurrently. All mutation goes through the {@code synchronized} {@link #update} method,
     * so readers must also hold this object's monitor.
     *
     * <p>The decompiler reports this class was originally private; it is left public here so the
     * visible API of this file does not change.
     */
    public class GridTrainingResult {
        ImmutableSvmParameterPoint<L, P> bestParam;
        SvmBinaryCrossValidationResults<L, P> bestCrossValidationResults;
        float bestSensitivity;

        private GridTrainingResult() {
            this.bestParam = null;
            this.bestCrossValidationResults = null;
            // Sentinel below any real sensitivity so the first comparable point always wins.
            this.bestSensitivity = -1.0f;
        }

        /**
         * Records {@code point} as the new best if its class-normalized sensitivity is strictly
         * greater than the best seen so far. A NaN sensitivity never compares greater, so such a
         * point is never selected; on ties the first point to report wins.
         *
         * @param point   the grid point that was cross-validated
         * @param results the cross-validation results for {@code point}
         */
        synchronized void update(ImmutableSvmParameterPoint<L, P> point, SvmBinaryCrossValidationResults<L, P> results) {
            float sensitivity = results.classNormalizedSensitivity();
            if (sensitivity > this.bestSensitivity) {
                this.bestParam = point;
                this.bestSensitivity = sensitivity;
                this.bestCrossValidationResults = results;
            }
        }
    }

    /**
     * Validates {@code immutableSvmParameter} and trains a binary model, dispatching on whether it
     * is a parameter grid, a probability-enabled point, or a plain point (see class docs).
     *
     * @param binaryClassificationProblem the training data
     * @param immutableSvmParameter       either an {@link ImmutableSvmParameterGrid} or an
     *                                    {@link ImmutableSvmParameterPoint}; any other subtype
     *                                    causes a {@link ClassCastException}
     * @return the trained model
     * @throws IllegalStateException if a grid search selects no parameter point
     */
    @Override // edu.berkeley.compbio.jlibsvm.SVM
    public BinaryModel<L, P> train(@NotNull BinaryClassificationProblem<L, P> binaryClassificationProblem, @NotNull ImmutableSvmParameter<L, P> immutableSvmParameter) {
        validateParam(immutableSvmParameter);
        if (immutableSvmParameter instanceof ImmutableSvmParameterGrid) {
            return trainGrid(binaryClassificationProblem, (ImmutableSvmParameterGrid<L, P>) immutableSvmParameter);
        }
        boolean wantsProbability = immutableSvmParameter.probability;
        ImmutableSvmParameterPoint<L, P> point = (ImmutableSvmParameterPoint<L, P>) immutableSvmParameter;
        if (wantsProbability) {
            return trainScaledWithCV(binaryClassificationProblem, point);
        }
        return trainScaled(binaryClassificationProblem, point);
    }

    /**
     * Cross-validates every point of {@code grid}, picks the one with the highest class-normalized
     * sensitivity, and trains the final model with it. The winning point's CV results are attached
     * to the returned model.
     *
     * @throws IllegalStateException if no point was selected (empty grid, or no point produced a
     *                               sensitivity greater than -1, e.g. all NaN)
     */
    private BinaryModel<L, P> trainGrid(@NotNull final BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterGrid<L, P> grid) {
        final GridTrainingResult gridResult = new GridTrainingResult();
        Parallel.forEach(grid.getGridParams(), new Function<ImmutableSvmParameterPoint<L, P>, Void>() {
            @Override // com.davidsoergel.conja.Function
            public Void apply(ImmutableSvmParameterPoint<L, P> point) {
                SvmBinaryCrossValidationResults<L, P> cvResults = BinaryClassificationSVM.this.performCrossValidation(problem, point);
                BinaryClassificationSVM.logger.info("CV results for grid point " + point + ": " + cvResults);
                gridResult.update(point, cvResults);
                return null;
            }
        });

        // Read both fields under the same monitor that update() writes them with.
        final ImmutableSvmParameterPoint<L, P> bestParam;
        final SvmBinaryCrossValidationResults<L, P> bestCvResults;
        synchronized (gridResult) {
            bestParam = gridResult.bestParam;
            bestCvResults = gridResult.bestCrossValidationResults;
        }
        if (bestParam == null) {
            // Previously this surfaced as an opaque NullPointerException inside trainScaled().
            throw new IllegalStateException("Grid search did not select any parameter point "
                    + "(empty grid, or no point had a class-normalized sensitivity above -1)");
        }

        logger.info("Chose grid point: " + bestParam);
        BinaryModel<L, P> model = trainScaled(problem, bestParam);
        model.crossValidationResults = bestCvResults;
        return model;
    }

    /**
     * Cross-validates {@code param} on {@code problem}, then trains on the full problem and
     * attaches the CV results to the model. If cross-validation throws {@link SvmException}, the
     * failure is logged at debug level and the model is returned without CV results.
     */
    private BinaryModel<L, P> trainScaledWithCV(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        SvmBinaryCrossValidationResults<L, P> cvResults = null;
        try {
            cvResults = performCrossValidation(problem, param);
        } catch (SvmException e) {
            // Deliberate best effort: a model without CV results is still returned.
            logger.debug("Could not perform cross-validation", e);
        }
        BinaryModel<L, P> model = trainScaled(problem, param);
        model.crossValidationResults = cvResults;
        // NOTE(review): trainScaled() already calls printSolutionInfo() on the (possibly scaled)
        // training problem; this second call repeats it against the unscaled problem. Kept to
        // preserve existing output — confirm whether the duplicate is intended.
        model.printSolutionInfo(problem);
        return model;
    }

    /**
     * Cross-validates {@code immutableSvmParameter} on {@code binaryClassificationProblem}.
     * The folds are run with {@code immutableSvmParameter.noProbabilityCopy()} (presumably the same
     * parameters with probability estimation turned off); the original {@code probability} flag is
     * passed on to the results object.
     *
     * @throws ClassCastException if {@code noProbabilityCopy()} does not yield an
     *                            {@link ImmutableSvmParameterPoint} (e.g. for a grid)
     */
    @Override // edu.berkeley.compbio.jlibsvm.SVM
    public SvmBinaryCrossValidationResults<L, P> performCrossValidation(@NotNull BinaryClassificationProblem<L, P> binaryClassificationProblem, @NotNull ImmutableSvmParameter<L, P> immutableSvmParameter) {
        ImmutableSvmParameterPoint<L, P> foldParam = (ImmutableSvmParameterPoint<L, P>) immutableSvmParameter.noProbabilityCopy();
        return new SvmBinaryCrossValidationResults<>(binaryClassificationProblem,
                continuousCrossValidation(binaryClassificationProblem, foldParam),
                immutableSvmParameter.probability);
    }

    /**
     * Trains a single binary machine.
     *
     * @param binaryClassificationProblem the (possibly rescaled) training data
     * @param f                           C to use for the true label
     * @param f2                          C to use for the false label
     * @param immutableSvmParameterPoint  the remaining training parameters
     * @return the trained model
     */
    protected abstract BinaryModel<L, P> trainOne(@NotNull BinaryClassificationProblem<L, P> binaryClassificationProblem, float f, float f2, @NotNull ImmutableSvmParameterPoint<L, P> immutableSvmParameterPoint);

    /**
     * Trains on {@code problem}, first replacing it with a scaled copy when the parameters carry a
     * scaling model learner and request that binary machines be scaled independently. Prints the
     * model's solution info for the problem actually trained on.
     */
    private BinaryModel<L, P> trainScaled(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        BinaryClassificationProblem<L, P> trainingProblem = problem;
        if (param.scalingModelLearner != null && param.scaleBinaryMachinesIndependently) {
            trainingProblem = problem.getScaledCopy(param.scalingModelLearner);
        }
        BinaryModel<L, P> model = trainWeighted(trainingProblem, param);
        model.printSolutionInfo(trainingProblem);
        return model;
    }

    /**
     * Derives per-label C values and delegates to {@link #trainOne}. Both start at {@code param.C};
     * when {@code redistributeUnbalancedC} is set, each is multiplied by the weight configured for
     * its label, if one is configured (a missing weight leaves C unchanged).
     */
    private BinaryModel<L, P> trainWeighted(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        float cTrue = param.C;
        float cFalse = param.C;
        if (param.redistributeUnbalancedC) {
            Float trueWeight = param.getWeight(problem.getTrueLabel());
            if (trueWeight != null) {
                cTrue *= trueWeight.floatValue();
            }
            Float falseWeight = param.getWeight(problem.getFalseLabel());
            if (falseWeight != null) {
                cFalse *= falseWeight.floatValue();
            }
        }
        return trainOne(problem, cTrue, cFalse, param);
    }
}
