package edu.berkeley.compbio.jlibsvm.binary;

import edu.berkeley.compbio.jlibsvm.ContinuousModel;
import edu.berkeley.compbio.jlibsvm.DiscreteModel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.LabelParser;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.Comparable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Properties;
import java.util.StringTokenizer;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.jena.atlas.json.io.JSWriter;
import org.apache.log4j.Logger;
import org.apache.lucene.util.packed.PackedInts;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/binary/BinaryModel.class */
/**
 * Two-class SVM model defined by support vectors, their alpha coefficients and an offset {@code rho}.
 *
 * <p>The decision value for an input {@code x} is {@code sum_i alphas[i] * K(x, SVs[i]) - rho}.
 * A strictly positive decision value predicts {@link #getTrueLabel()}; anything else predicts
 * {@link #getFalseLabel()}. Probabilities come from passing the decision value through the sigmoid
 * stored in {@link #crossValidationResults}, when one is available.
 */
public class BinaryModel<L extends Comparable, P> extends AlphaModel<L, P> implements DiscreteModel<L, P>, ContinuousModel<P> {
    /** Parameter point providing the kernel and the label set of this model. */
    public ImmutableSvmParameterPoint<L, P> param;
    private static final Logger logger = Logger.getLogger(BinaryModel.class);
    /** Presumably the solver's objective value; in this class it is only logged by {@link #printSolutionInfo}. */
    public float obj;
    /** Alpha magnitude at or above which a true-class support vector is counted as bounded. */
    public float upperBoundPositive;
    /** Alpha magnitude at or above which a false-class support vector is counted as bounded. */
    public float upperBoundNegative;
    /** Scaling applied to every input before kernel evaluation; a no-op by default. */
    public ScalingModel<P> scalingModel = new NoopScalingModel();
    // NOTE(review): not read anywhere in this class; meaning is set by whoever populates it — confirm.
    public float r;
    /** Cross-validation results; their sigmoid (if present) maps decision values to probabilities. */
    public SvmBinaryCrossValidationResults<L, P> crossValidationResults;
    L trueLabel;
    L falseLabel;

    @Override
    public SvmBinaryCrossValidationResults<L, P> getCrossValidationResults() {
        return this.crossValidationResults;
    }

    /** Returns the labels known to {@link #param}. */
    @Override
    public Collection<L> getLabels() {
        return this.param.getLabels();
    }

    /** Returns the {@code toString()} of the kernel held by {@link #param}. */
    @Override
    public String getKernelName() {
        return this.param.kernel.toString();
    }

    /** Creates an empty model; fields are expected to be populated by the caller. */
    public BinaryModel() {
    }

    /**
     * Reconstructs a model from a properties map.
     *
     * <p>Reads these properties:
     * <ul>
     *   <li>{@code kernel_type}: the name of a {@link KernelFunction} class, instantiated through its
     *       {@code (Properties)} constructor.</li>
     *   <li>{@code label}: whitespace-separated labels, each parsed with {@code labelParser}.</li>
     *   <li>{@code rho}.</li>
     *   <li>{@code total_sv}.</li>
     * </ul>
     *
     * @param properties  source of the model header values
     * @param labelParser converts each {@code label} token into an {@code L}
     * @throws SvmException wrapping any failure while reading, parsing or instantiating the above
     */
    @SuppressWarnings("unchecked") // raw Builder type and the (L) casts below
    public BinaryModel(Properties properties, LabelParser<L> labelParser) {
        ImmutableSvmParameterPoint.Builder builder = new ImmutableSvmParameterPoint.Builder();
        try {
            builder.kernel = (KernelFunction) Class.forName(properties.getProperty("kernel_type"))
                    .getConstructor(Properties.class)
                    .newInstance(properties);
            StringTokenizer labelTokens = new StringTokenizer(properties.getProperty("label"));
            while (labelTokens.hasMoreTokens()) {
                builder.putWeight(labelParser.parse(labelTokens.nextToken()), null);
            }
            this.rho = Float.parseFloat(properties.getProperty("rho"));
            this.numSVs = Integer.parseInt(properties.getProperty("total_sv"));
            // L is only known to be Comparable, so String literals need an explicit unchecked cast.
            // NOTE(review): these are Strings whatever L is; the parsed "label" tokens may be the
            // intended true/false labels — confirm.
            this.trueLabel = (L) "true";
            this.falseLabel = (L) "false";
            this.param = builder.build();
        } catch (Throwable th) {
            // Reflection, parsing and missing-property failures all surface as SvmException.
            throw new SvmException(th);
        }
    }

    /** Creates a model bound to the given parameter point; other fields are populated later. */
    public BinaryModel(ImmutableSvmParameterPoint<L, P> immutableSvmParameterPoint) {
        this.param = immutableSvmParameterPoint;
    }

    /** Returns the label predicted for non-positive decision values. */
    public L getFalseLabel() {
        return this.falseLabel;
    }

    /** Returns the scaling applied to inputs before kernel evaluation. */
    @NotNull
    public ScalingModel<P> getScalingModel() {
        return this.scalingModel;
    }

    /** Replaces the scaling applied to inputs before kernel evaluation. */
    public void setScalingModel(@NotNull ScalingModel<P> scalingModel) {
        this.scalingModel = scalingModel;
    }

    /** Returns the label predicted for strictly positive decision values. */
    public L getTrueLabel() {
        return this.trueLabel;
    }

    /** Predicts the true label when the decision value for {@code p} is {@code > 0}, else the false label. */
    @Override
    public L predictLabel(P p) {
        return predictValue(p).floatValue() > 0.0f ? this.trueLabel : this.falseLabel;
    }

    /** Returns the sum of the values of {@code supportVectors}, accumulated as a float. */
    public float getSumAlpha() {
        float sum = 0.0f;
        for (Iterator<Double> it = this.supportVectors.values().iterator(); it.hasNext(); ) {
            sum = (float) (sum + it.next().doubleValue());
        }
        return sum;
    }

    /**
     * Returns the probability that {@code p} belongs to the true class.
     *
     * <p>When cross-validation results or their sigmoid are missing, this logs an error and returns
     * 1 for a positive decision value and 0 otherwise. This is the same fallback used by the
     * precomputed-kernel overload.
     */
    public float getTrueProbability(P p) {
        return probabilityFromDecisionValue(predictValue(p).floatValue());
    }

    /**
     * Returns the probability that {@code p} has label {@code l}.
     *
     * @throws SvmException if {@code l} is neither the true nor the false label of this model
     */
    public float getProbability(P p, L l) {
        if (l.equals(this.trueLabel)) {
            return getTrueProbability(p);
        }
        if (l.equals(this.falseLabel)) {
            return 1.0f - getTrueProbability(p);
        }
        throw new SvmException("Can't compute probability: " + l + " is not one of the classes in this binary model (" + this.trueLabel + ", " + this.falseLabel + ")");
    }

    /**
     * Computes the decision value {@code sum_i alphas[i] * K(scaled(p), SVs[i]) - rho}.
     * The input is first passed through {@link #scalingModel}.
     */
    @Override
    public Float predictValue(P p) {
        P scaled = this.scalingModel.scaledCopy(p);
        float sum = 0.0f;
        for (int i = 0; i < this.numSVs; i++) {
            sum = (float) (sum + this.alphas[i] * (float) this.param.kernel.evaluate(scaled, this.SVs[i]));
        }
        return Float.valueOf(sum - this.rho);
    }

    /**
     * Returns the true-class probability computed from precomputed kernel values.
     * The fallback behavior is the same as {@link #getTrueProbability(Object)}.
     *
     * @param fArr precomputed kernel values, indexed through {@code iArr}
     * @param iArr {@code iArr[i]} is the position in {@code fArr} paired with {@code alphas[i]}
     */
    public float getTrueProbability(float[] fArr, int[] iArr) {
        return probabilityFromDecisionValue(predictValue(fArr, iArr).floatValue());
    }

    /**
     * Computes the decision value {@code sum_i alphas[i] * fArr[iArr[i]] - rho} from precomputed kernel values.
     *
     * @param fArr precomputed kernel values; presumably {@code K(x, SV)} for each SV of a shared
     *             pool — TODO confirm against caller
     * @param iArr {@code iArr[i]} is the position in {@code fArr} paired with {@code alphas[i]}
     */
    public Float predictValue(float[] fArr, int[] iArr) {
        float sum = 0.0f;
        for (int i = 0; i < this.numSVs; i++) {
            sum = (float) (sum + this.alphas[i] * fArr[iArr[i]]);
        }
        return Float.valueOf(sum - this.rho);
    }

    /** Predicts a label from precomputed kernel values, using the same {@code > 0} threshold as {@link #predictLabel(Object)}. */
    public L predictLabel(float[] fArr, int[] iArr) {
        return predictValue(fArr, iArr).floatValue() > 0.0f ? this.trueLabel : this.falseLabel;
    }

    /**
     * Logs solver statistics at DEBUG level, and does nothing otherwise. The statistics are:
     * <ul>
     *   <li>the objective value and rho;</li>
     *   <li>the support-vector count;</li>
     *   <li>the number of bounded support vectors: non-zero alpha whose magnitude reaches the
     *       upper bound for that vector's class.</li>
     * </ul>
     *
     * @param binaryClassificationProblem used to look up the label of each support vector
     */
    public void printSolutionInfo(BinaryClassificationProblem<L, P> binaryClassificationProblem) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        logger.debug("obj = " + this.obj + ", rho = " + this.rho);
        int boundedCount = 0;
        for (int i = 0; i < this.numSVs; i++) {
            double alphaMagnitude = Math.abs(this.alphas[i]);
            if (alphaMagnitude > 0.0) {
                boolean isTrueClass = binaryClassificationProblem.getTargetValue(this.SVs[i]).equals(this.trueLabel);
                float upperBound = isTrueClass ? this.upperBoundPositive : this.upperBoundNegative;
                if (alphaMagnitude >= upperBound) {
                    boundedCount++;
                }
            }
        }
        // NOTE(review): reports SVs.length while the loop covers numSVs entries — confirm they agree.
        logger.debug("nSV = " + this.SVs.length + ", nBSV = " + boundedCount);
    }

    /**
     * Writes the model in this order: whatever {@code AlphaModel.writeToStream} emits, then
     * {@code nr_class 2}, then the support vectors. The stream is closed afterwards.
     */
    @Override
    public void writeToStream(DataOutputStream dataOutputStream) throws IOException {
        super.writeToStream(dataOutputStream);
        dataOutputStream.writeBytes("nr_class 2\n");
        writeSupportVectors(dataOutputStream);
        // NOTE(review): closing a caller-supplied stream is unusual; kept because callers may rely on it.
        dataOutputStream.close();
    }

    /**
     * Maps a decision value to a true-class probability via the cross-validation sigmoid. When no
     * cross-validation results or sigmoid exist, logs an error and returns 1 for a positive value
     * and 0 otherwise.
     */
    private float probabilityFromDecisionValue(float decisionValue) {
        if (this.crossValidationResults == null) {
            logger.error("Can't compute probability in binary model without crossvalidationresults");
            return decisionValue > 0.0f ? 1.0f : 0.0f;
        }
        if (this.crossValidationResults.sigmoid == null) {
            logger.error("Can't compute probability in binary model without sigmoid");
            return decisionValue > 0.0f ? 1.0f : 0.0f;
        }
        return this.crossValidationResults.sigmoid.predict(decisionValue);
    }
}
