package edu.berkeley.compbio.jlibsvm.multi;

import com.google.common.base.Function;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.MapMaker;
import edu.berkeley.compbio.jlibsvm.DiscreteModel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.SolutionModel;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.binary.BinaryModel;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel;
import edu.berkeley.compbio.ml.MultiClassCrossValidationResults;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import org.apache.log4j.Logger;
import org.apache.lucene.util.packed.PackedInts;
import org.jetbrains.annotations.NotNull;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/* loaded from: input_file:BOOT-INF/lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/multi/MultiClassModel.class */
/**
 * A multi-class SVM model assembled from binary sub-models.
 *
 * <p>Two families of {@link BinaryModel}s may be registered:
 * <ul>
 *   <li>one-vs-all models, one per label, added with {@link #putOneVsAllModel};
 *   <li>one-vs-one models, one per unordered pair of labels, added with {@link #putOneVsOneModel}.
 * </ul>
 * {@link OneVsAllMode} and {@link AllVsAllMode} select how the two families are combined at
 * prediction time.
 *
 * <p>{@link #prepareModelSvMaps()} builds two things that {@link #predictLabelWithQuality(Object)}
 * reads: the shared support-vector array and the per-model index maps. It must therefore run after
 * all sub-models have been added and before predicting.
 *
 * <p>NOTE(review): the registration methods are {@code synchronized} but the prediction methods
 * take no lock. Confirm that callers finish building the model before predicting concurrently.
 */
public class MultiClassModel<L extends Comparable, P> extends SolutionModel<L, P> implements DiscreteModel<L, P> {
    private static final Logger logger = Logger.getLogger(MultiClassModel.class);

    /** Pairwise probability estimates are clamped to [this, 1 - this] before coupling. */
    private static final float MIN_PAIRWISE_PROBABILITY = 1.0E-7f;

    /** Applied to samples in {@link #predictLabelWithQuality(Object)}; a no-op unless replaced. */
    private ScalingModel<P> scalingModel = new NoopScalingModel();
    private final OneVsAllMode oneVsAllMode;
    /** One-vs-all probability a label needs to be treated as active (and, in veto modes, to win). */
    private final double oneVsAllThreshold;
    private final AllVsAllMode allVsAllMode;
    /** Minimum share of all pairwise votes that the winning label must receive. */
    private final double minVoteProportion;
    /** For each sub-model, the position in {@link #allSVs} of each of its support vectors. */
    private final Map<BinaryModel<L, P>, int[]> svIndexMaps;
    private final int numberOfClasses;
    private final SymmetricHashMap2d<L, BinaryModel<L, P>> oneVsOneModels;
    private final HashMap<L, BinaryModel<L, P>> oneVsAllModels;
    /** Union of every sub-model's support vectors; built by {@link #prepareModelSvMaps()}. */
    private P[] allSVs;
    /** Not assigned in this class; presumably set by the trainer in this package — TODO confirm. */
    SvmMultiClassCrossValidationResults<L, P> crossValidationResults;

    /** Which one-vs-one (pairwise) models vote in {@link #predictLabelWithQuality(Object)}. */
    public enum AllVsAllMode {
        /**
         * Pairwise models are not reported by {@link #getInfo()}.
         * NOTE(review): prediction still votes with them, requiring both labels to be active,
         * exactly like {@link #FilteredVsFiltered}.
         */
        None,
        /** Every pairwise model votes. */
        AllVsAll,
        /** A pairwise model votes if at least one of its two labels is active. */
        FilteredVsAll,
        /** A pairwise model votes only if both of its labels are active. */
        FilteredVsFiltered
    }

    /** How one-vs-all probabilities take part in {@link #predictLabelWithQuality(Object)}. */
    public enum OneVsAllMode {
        /** One-vs-all models are not consulted; every pairwise label is active. */
        None,
        /** The label with the highest one-vs-all probability wins; pairwise models are unused. */
        Best,
        /**
         * Probabilities at or above the threshold mark labels as active and break vote ties. A
         * winner whose probability is below the threshold is rejected.
         */
        Veto,
        /**
         * Probabilities at or above the threshold mark labels as active.
         * NOTE(review): despite the name, vote ties are not broken by probability in this mode.
         */
        BreakTies,
        /** Same behaviour as {@link #Veto}. */
        VetoAndBreakTies
    }

    /**
     * A map keyed by an unordered pair of keys: {@code get(a, b)} and {@code get(b, a)} address the
     * same entry. Each pair is stored once. The smaller key (by {@code compareTo}) is the outer key
     * and the larger key is the inner key.
     */
    private static final class SymmetricHashMap2d<K extends Comparable, V> {
        private final HashMap<K, Map<K, V>> l1Map;
        private final int sizePerDimension;

        /** Creates an empty map sized for roughly {@code sizePerDimension} distinct keys. */
        SymmetricHashMap2d(int sizePerDimension) {
            this.sizePerDimension = sizePerDimension;
            this.l1Map = new HashMap<K, Map<K, V>>(sizePerDimension);
        }

        /** Copies {@code source}, dropping every pair that involves a key in {@code excluded}. */
        SymmetricHashMap2d(SymmetricHashMap2d<K, V> source, Collection<K> excluded) {
            this(source.sizePerDimension);
            for (Map.Entry<K, Map<K, V>> outer : source.l1Map.entrySet()) {
                K key1 = outer.getKey();
                if (excluded.contains(key1)) {
                    continue;
                }
                Map<K, V> inner = new HashMap<K, V>(this.sizePerDimension);
                for (Map.Entry<K, V> entry : outer.getValue().entrySet()) {
                    if (!excluded.contains(entry.getKey())) {
                        inner.put(entry.getKey(), entry.getValue());
                    }
                }
                // Kept even when empty, so key1 still appears in keySet().
                this.l1Map.put(key1, inner);
            }
        }

        public boolean isEmpty() {
            return this.l1Map.isEmpty();
        }

        /** Returns the value for the pair {key1, key2} in either order, or null if absent. */
        V get(K key1, K key2) {
            if (key1.compareTo(key2) > 0) {
                K tmp = key1;
                key1 = key2;
                key2 = tmp;
            }
            Map<K, V> inner = this.l1Map.get(key1);
            return inner == null ? null : inner.get(key2);
        }

        /** Stores {@code value} for the pair {key1, key2}; argument order does not matter. */
        public void put(K key1, K key2, V value) {
            if (key1.compareTo(key2) > 0) {
                K tmp = key1;
                key1 = key2;
                key2 = tmp;
            }
            Map<K, V> inner = this.l1Map.get(key1);
            if (inner == null) {
                inner = new HashMap<K, V>(this.sizePerDimension);
                this.l1Map.put(key1, inner);
            }
            inner.put(key2, value);
        }

        /** Returns a fresh set of every key that appears on either side of a stored pair. */
        public Set<K> keySet() {
            Set<K> keys = new HashSet<K>(this.l1Map.keySet());
            for (Map<K, V> inner : this.l1Map.values()) {
                keys.addAll(inner.keySet());
            }
            return keys;
        }

        /** Iterable view over all stored values; see {@link #valueIterator()}. */
        public Iterable<V> values() {
            return new Iterable<V>() {
                @Override
                public Iterator<V> iterator() {
                    return valueIterator();
                }
            };
        }

        /** Iterates all stored values, skipping empty inner maps; removal is not supported. */
        public Iterator<V> valueIterator() {
            final Iterator<Map<K, V>> outer = this.l1Map.values().iterator();
            return new Iterator<V>() {
                private Iterator<V> inner = null;

                @Override
                public boolean hasNext() {
                    while ((inner == null || !inner.hasNext()) && outer.hasNext()) {
                        inner = outer.next().values().iterator();
                    }
                    return inner != null && inner.hasNext();
                }

                @Override
                public V next() {
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    return inner.next();
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    @Override
    public MultiClassCrossValidationResults<L> getCrossValidationResults() {
        return this.crossValidationResults;
    }

    /**
     * Creates a copy of {@code multiClassModel} without the given labels.
     *
     * <p>Every one-vs-all model for a label in {@code collection} is dropped. Every one-vs-one model
     * involving such a label is dropped too. The support-vector array, index maps and scaling model
     * are shared with the source model, not copied.
     *
     * @param multiClassModel the model to copy
     * @param collection labels to exclude
     */
    public MultiClassModel(MultiClassModel<L, P> multiClassModel, Collection<L> collection) {
        this.allSVs = multiClassModel.allSVs;
        this.oneVsAllMode = multiClassModel.oneVsAllMode;
        this.oneVsAllThreshold = multiClassModel.oneVsAllThreshold;
        this.allVsAllMode = multiClassModel.allVsAllMode;
        this.minVoteProportion = multiClassModel.minVoteProportion;
        this.numberOfClasses = multiClassModel.numberOfClasses;
        this.svIndexMaps = multiClassModel.svIndexMaps;
        this.scalingModel = multiClassModel.scalingModel;
        this.oneVsOneModels = new SymmetricHashMap2d<L, BinaryModel<L, P>>(multiClassModel.oneVsOneModels, collection);
        this.oneVsAllModels = new HashMap<L, BinaryModel<L, P>>(multiClassModel.oneVsAllModels);
        for (L excluded : collection) {
            this.oneVsAllModels.remove(excluded);
        }
    }

    /**
     * Creates an empty model. Modes and thresholds are copied from {@code immutableSvmParameter}.
     *
     * @param immutableSvmParameter source of the one-vs-all/all-vs-all modes and thresholds
     * @param i expected number of classes; used for sizing
     */
    public MultiClassModel(ImmutableSvmParameter immutableSvmParameter, int i) {
        this.svIndexMaps = new HashMap<BinaryModel<L, P>, int[]>();
        this.numberOfClasses = i;
        this.oneVsOneModels = new SymmetricHashMap2d<L, BinaryModel<L, P>>(i);
        this.oneVsAllModels = new HashMap<L, BinaryModel<L, P>>(i);
        this.oneVsAllThreshold = immutableSvmParameter.oneVsAllThreshold;
        this.oneVsAllMode = immutableSvmParameter.oneVsAllMode;
        this.allVsAllMode = immutableSvmParameter.allVsAllMode;
        this.minVoteProportion = immutableSvmParameter.minVoteProportion;
    }

    @NotNull
    public ScalingModel<P> getScalingModel() {
        return this.scalingModel;
    }

    public void setScalingModel(@NotNull ScalingModel<P> scalingModel) {
        this.scalingModel = scalingModel;
    }

    /** Returns the best label from {@link #predictLabelWithQuality(Object)}. */
    @Override
    public L predictLabel(P p) {
        return predictLabelWithQuality(p).getBestLabel();
    }

    /**
     * Returns the key with the strictly largest value in {@code map}. Returns null if no value
     * exceeds 0.
     */
    public L bestProbabilityLabel(Map<L, Float> map) {
        L bestLabel = null;
        float bestProbability = 0f;
        for (Map.Entry<L, Float> entry : map.entrySet()) {
            float probability = entry.getValue();
            if (probability > bestProbability) {
                bestLabel = entry.getKey();
                bestProbability = probability;
            }
        }
        return bestLabel;
    }

    /**
     * Predicts a label for {@code p} and reports how decisive the prediction was.
     *
     * <p>The sample is first scaled with the model's {@link ScalingModel}. Kernel values against
     * the shared support vectors are then computed lazily, once per distinct kernel. Finally:
     * <ul>
     *   <li>In {@link OneVsAllMode#Best}, the label with the highest one-vs-all probability wins.
     *   <li>Otherwise, one-vs-one models vote and the label with the most votes wins. Which models
     *       vote depends on {@link AllVsAllMode}.
     * </ul>
     *
     * <p>An empty {@link VotingResult} is returned when any of these holds:
     * <ul>
     *   <li>a Best or veto mode finds no label at or above the one-vs-all threshold;
     *   <li>no vote is cast;
     *   <li>the winner's vote share is below {@code minVoteProportion};
     *   <li>in a veto mode, the winner's one-vs-all probability is below the threshold.
     * </ul>
     *
     * <p>Requires {@link #prepareModelSvMaps()} to have run.
     */
    @NotNull
    public VotingResult<L> predictLabelWithQuality(P p) {
        final P scaledCopy = this.scalingModel.scaledCopy(p);

        // Per-kernel vector of K(scaledCopy, sv) for every shared support vector. It is computed
        // on first use, so sub-models that share a kernel reuse the same evaluations.
        ConcurrentMap<KernelFunction<P>, float[]> kernelValues = new MapMaker().makeComputingMap(
                new Function<KernelFunction<P>, float[]>() {
                    @Override
                    public float[] apply(@NotNull KernelFunction<P> kernel) {
                        float[] values = new float[allSVs.length];
                        for (int s = 0; s < allSVs.length; s++) {
                            values[s] = (float) kernel.evaluate(scaledCopy, allSVs[s]);
                        }
                        return values;
                    }
                });

        Map<L, Float> oneVsAllProbabilities =
                this.oneVsAllMode == OneVsAllMode.None ? null : computeOneVsAllProbabilities(kernelValues);

        boolean needsActiveLabel = this.oneVsAllMode == OneVsAllMode.Veto
                || this.oneVsAllMode == OneVsAllMode.VetoAndBreakTies
                || this.oneVsAllMode == OneVsAllMode.Best;
        if (needsActiveLabel && oneVsAllProbabilities.isEmpty()) {
            return new VotingResult<L>();
        }
        if (this.oneVsAllMode == OneVsAllMode.Best) {
            return bestOneVsAllResult(oneVsAllProbabilities);
        }
        return tallyVotes(collectVotes(kernelValues, oneVsAllProbabilities), oneVsAllProbabilities);
    }

    /** Picks the label with the highest one-vs-all probability and tracks the runner-up. */
    private VotingResult<L> bestOneVsAllResult(Map<L, Float> probabilities) {
        L bestLabel = null;
        float bestProbability = 0f;
        float secondBestProbability = 0f;
        for (Map.Entry<L, Float> entry : probabilities.entrySet()) {
            float probability = entry.getValue();
            if (bestLabel == null || probability > bestProbability) {
                secondBestProbability = bestProbability;
                bestLabel = entry.getKey();
                bestProbability = probability;
            } else if (probability > secondBestProbability) {
                secondBestProbability = probability;
            }
        }
        return new VotingResult<L>(bestLabel, 0f, 0f, 0f, 0f, bestProbability, secondBestProbability);
    }

    /** Collects one vote from each one-vs-one model selected by {@link #allVsAllMode}. */
    private HashMultiset<L> collectVotes(Map<KernelFunction<P>, float[]> kernelValues,
                                         Map<L, Float> oneVsAllProbabilities) {
        int numLabels = this.oneVsOneModels.keySet().size();
        HashMultiset<L> votes = HashMultiset.create();

        if (this.allVsAllMode == AllVsAllMode.AllVsAll) {
            if (logger.isDebugEnabled()) {
                logger.debug("Sample voting using all pairs of " + numLabels + " labels (" + pairCount(numLabels) + " models)");
            }
            for (BinaryModel<L, P> model : this.oneVsOneModels.values()) {
                votes.add(predictWithCachedKernel(model, kernelValues));
            }
            return votes;
        }

        // Without one-vs-all filtering every pairwise label counts as active.
        Set<L> activeLabels = oneVsAllProbabilities != null ? oneVsAllProbabilities.keySet() : this.oneVsOneModels.keySet();
        int numActive = oneVsAllProbabilities != null ? oneVsAllProbabilities.size() : numLabels;
        // FilteredVsAll needs one active label per pair; FilteredVsFiltered (and None) need both.
        int requiredActive = this.allVsAllMode == AllVsAllMode.FilteredVsAll ? 1 : 2;

        if (logger.isDebugEnabled()) {
            if (requiredActive == 1) {
                int inactive = Math.max(0, numLabels - numActive);
                logger.debug("Sample voting with all " + numLabels + " vs. " + numActive + " active labels ("
                        + (pairCount(numLabels) - pairCount(inactive)) + " models)");
            } else {
                logger.debug("Sample voting using pairs of only " + numActive + " active labels (" + pairCount(numActive) + " models)");
            }
        }

        for (BinaryModel<L, P> model : this.oneVsOneModels.values()) {
            int active = (activeLabels.contains(model.getTrueLabel()) ? 1 : 0)
                    + (activeLabels.contains(model.getFalseLabel()) ? 1 : 0);
            if (active >= requiredActive) {
                votes.add(predictWithCachedKernel(model, kernelValues));
            }
        }
        return votes;
    }

    /**
     * Turns vote counts into a {@link VotingResult}.
     *
     * <p>Ties in vote count go to the higher one-vs-all probability in veto modes. In other modes,
     * the first label encountered wins a tie.
     */
    private VotingResult<L> tallyVotes(HashMultiset<L> votes, Map<L, Float> oneVsAllProbabilities) {
        boolean useProbabilities = this.oneVsAllMode == OneVsAllMode.Veto
                || this.oneVsAllMode == OneVsAllMode.VetoAndBreakTies;
        L bestLabel = null;
        int bestCount = 0;
        int secondBestCount = 0;
        int totalVotes = 0;
        float bestProbability = 0f;
        float secondBestProbability = 0f;

        for (L label : votes.elementSet()) {
            int count = votes.count(label);
            totalVotes += count;
            float probability = 1.0f;
            if (useProbabilities) {
                // Labels below the threshold were filtered out, so they count as probability 0.
                Float known = oneVsAllProbabilities.get(label);
                probability = known == null ? 0f : known;
            }
            if (count > bestCount || (count == bestCount && probability > bestProbability)) {
                secondBestCount = bestCount;
                secondBestProbability = bestProbability;
                bestLabel = label;
                bestCount = count;
                bestProbability = probability;
            } else if (count > secondBestCount || (count == secondBestCount && probability > secondBestProbability)) {
                secondBestCount = count;
                secondBestProbability = probability;
            }
        }

        if (totalVotes == 0) {
            return new VotingResult<L>();
        }
        double bestProportion = (double) bestCount / totalVotes;
        if (bestProportion < this.minVoteProportion) {
            return new VotingResult<L>();
        }
        if (useProbabilities && bestProbability < this.oneVsAllThreshold) {
            return new VotingResult<L>();
        }
        return new VotingResult<L>(bestLabel, (float) bestProportion, (float) secondBestCount / totalVotes,
                0f, 0f, bestProbability, secondBestProbability);
    }

    /** Asks a sub-model for its label, using the cached kernel values for that model's kernel. */
    private L predictWithCachedKernel(BinaryModel<L, P> model, Map<KernelFunction<P>, float[]> kernelValues) {
        return model.predictLabel(kernelValues.get(model.param.kernel), this.svIndexMaps.get(model));
    }

    /** Number of unordered pairs among {@code n} items, n(n-1)/2. */
    private static int pairCount(int n) {
        return n * (n - 1) / 2;
    }

    /**
     * Evaluates every one-vs-all model.
     *
     * @param map kernel values per kernel, as built in {@link #predictLabelWithQuality(Object)}
     * @return each label whose true-class probability is at least {@code oneVsAllThreshold},
     *     mapped to that probability
     */
    public Map<L, Float> computeOneVsAllProbabilities(Map<KernelFunction<P>, float[]> map) {
        Map<L, Float> probabilities = new HashMap<L, Float>();
        for (BinaryModel<L, P> model : this.oneVsAllModels.values()) {
            float trueProbability = model.getTrueProbability(map.get(model.param.kernel), this.svIndexMaps.get(model));
            if (trueProbability >= this.oneVsAllThreshold) {
                probabilities.put(model.getTrueLabel(), trueProbability);
            }
        }
        return probabilities;
    }

    /**
     * Estimates a probability for each pairwise label by coupling the one-vs-one sigmoid outputs.
     *
     * <p>NOTE(review): unlike {@link #predictLabelWithQuality(Object)}, {@code p} is passed to the
     * sub-models unscaled. Confirm whether {@link #scalingModel} should be applied here.
     *
     * @throws SvmException if the one-vs-one models carry no cross-validation (sigmoid) results
     */
    public Map<L, Float> predictProbability(P p) {
        if (!supportsOneVsOneProbability()) {
            throw new SvmException("Can't make probability predictions");
        }
        // The label count comes from the models themselves; a filtered copy has fewer labels
        // than numberOfClasses.
        List<L> labels = new ArrayList<L>(this.oneVsOneModels.keySet());
        int k = labels.size();
        float[][] pairwise = new float[k][k];
        for (int i = 0; i < k; i++) {
            L labelI = labels.get(i);
            for (int j = i + 1; j < k; j++) {
                BinaryModel<L, P> model = this.oneVsOneModels.get(labelI, labels.get(j));
                if (model == null) {
                    continue; // both entries stay 0, as before
                }
                float trueProbability = Math.min(Math.max(
                        model.crossValidationResults.getSigmoid().predict(model.predictValue(p).floatValue()),
                        MIN_PAIRWISE_PROBABILITY), 1.0f - MIN_PAIRWISE_PROBABILITY);
                // pairwise[i][j] must be P(labelI | labelI or labelJ). The sigmoid estimates the
                // model's true label, which may be either label of the pair.
                float probabilityI = labelI.equals(model.getTrueLabel()) ? trueProbability : 1.0f - trueProbability;
                pairwise[i][j] = probabilityI;
                pairwise[j][i] = 1.0f - probabilityI;
            }
        }
        float[] coupled = multiclassProbability(k, pairwise);
        Map<L, Float> result = new HashMap<L, Float>();
        for (int i = 0; i < k; i++) {
            result.put(labels.get(i), coupled[i]);
        }
        return result;
    }

    /** True when at least one one-vs-one model exists and it has cross-validation results. */
    public boolean supportsOneVsOneProbability() {
        Iterator<BinaryModel<L, P>> models = this.oneVsOneModels.valueIterator();
        return models.hasNext() && models.next().crossValidationResults != null;
    }

    /**
     * Pairwise coupling: solves for class probabilities from the pairwise matrix {@code r}
     * ({@code r[i][j]} ~ P(i | i or j)) by iterating on the quadratic form Q, following
     * LIBSVM's {@code multiclass_probability} step for step.
     */
    private float[] multiclassProbability(int k, float[][] r) {
        float[] p = new float[k];
        float[][] q = new float[k][k];
        float[] qp = new float[k];
        int maxIterations = Math.max(100, k);
        float epsilon = 0.005f / k;

        for (int t = 0; t < k; t++) {
            p[t] = 1.0f / k;
            q[t][t] = 0.0f;
            for (int j = 0; j < t; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = q[j][t];
            }
            for (int j = t + 1; j < k; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = -r[j][t] * r[t][j];
            }
        }

        int iteration;
        for (iteration = 0; iteration < maxIterations; iteration++) {
            // Stopping test: recompute Qp and pQp from scratch for numerical accuracy.
            float pqp = 0.0f;
            for (int t = 0; t < k; t++) {
                qp[t] = 0.0f;
                for (int j = 0; j < k; j++) {
                    qp[t] += q[t][j] * p[j];
                }
                pqp += p[t] * qp[t];
            }
            float maxError = 0.0f;
            for (int t = 0; t < k; t++) {
                float error = Math.abs(qp[t] - pqp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < epsilon) {
                break;
            }
            for (int t = 0; t < k; t++) {
                float diff = (-qp[t] + pqp) / q[t][t];
                p[t] += diff;
                pqp = ((pqp + (diff * ((diff * q[t][t]) + (2.0f * qp[t])))) / (1.0f + diff)) / (1.0f + diff);
                for (int j = 0; j < k; j++) {
                    qp[j] = (qp[j] + (diff * q[t][j])) / (1.0f + diff);
                    p[j] /= (1.0f + diff);
                }
            }
        }
        if (iteration >= maxIterations) {
            logger.error("Multiclass probability attempted too many iterations");
        }
        return p;
    }

    /**
     * Gathers the support vectors of all sub-models (one-vs-all first, then one-vs-one) into
     * {@link #allSVs}. Each distinct vector is stored once. Each model's index map into that
     * array is recorded in {@link #svIndexMaps}.
     */
    public void prepareModelSvMaps() {
        Map<P, Integer> svIndex = new HashMap<P, Integer>();
        indexSupportVectors(this.oneVsAllModels.values(), svIndex);
        indexSupportVectors(this.oneVsOneModels.values(), svIndex);

        @SuppressWarnings("unchecked") // P is erased; an Object[] is what the generic array really is
        P[] svs = (P[]) new Object[svIndex.size()];
        for (Map.Entry<P, Integer> entry : svIndex.entrySet()) {
            svs[entry.getValue()] = entry.getKey();
        }
        this.allSVs = svs;
    }

    /** Assigns each model's SVs an index in {@code svIndex} (new vectors appended) and records it. */
    private void indexSupportVectors(Iterable<BinaryModel<L, P>> models, Map<P, Integer> svIndex) {
        for (BinaryModel<L, P> model : models) {
            int[] indices = new int[model.SVs.length];
            for (int s = 0; s < model.SVs.length; s++) {
                P sv = model.SVs[s];
                Integer index = svIndex.get(sv);
                if (index == null) {
                    index = svIndex.size();
                    svIndex.put(sv, index);
                }
                indices[s] = index;
            }
            this.svIndexMaps.put(model, indices);
        }
    }

    public synchronized void putOneVsAllModel(L l, BinaryModel<L, P> binaryModel) {
        this.oneVsAllModels.put(l, binaryModel);
    }

    /** Registers the pairwise model for {l, l2}; argument order does not matter. */
    public synchronized void putOneVsOneModel(L l, L l2, BinaryModel<L, P> binaryModel) {
        this.oneVsOneModels.put(l, l2, binaryModel);
    }

    @Override
    protected void readSupportVectors(BufferedReader bufferedReader) {
        throw new UnsupportedOperationException();
    }

    /** Writes only a placeholder; saving multi-class support vectors is not implemented. */
    protected void writeSupportVectors(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeBytes("SV\n");
        dataOutputStream.writeBytes("Saving multi-class support vectors is not implemented yet");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always (previously a JDK-internal
     *     {@code sun.reflect} exception, which is not accessible on modern JDKs)
     */
    @Override
    public void writeToStream(DataOutputStream dataOutputStream) throws IOException {
        throw new UnsupportedOperationException("Writing a MultiClassModel is not implemented");
    }

    /**
     * Summarises the model. If cross-validation results exist, their info is returned. Otherwise
     * the result lists the C values and kernels of the sub-model families in use.
     */
    public String getInfo() {
        if (this.crossValidationResults != null) {
            return this.crossValidationResults.getInfo();
        }
        StringBuilder info = new StringBuilder();
        if (this.oneVsAllMode != OneVsAllMode.None) {
            HashMultiset<Float> cValues = HashMultiset.create();
            HashMultiset<KernelFunction<P>> kernels = HashMultiset.create();
            for (BinaryModel<L, P> model : this.oneVsAllModels.values()) {
                cValues.add(Float.valueOf(model.param.C));
                kernels.add(model.param.kernel);
            }
            info.append("OneVsAll:C=" + cValues + "; gamma=" + kernels + "   ");
        }
        if (this.allVsAllMode != AllVsAllMode.None) {
            HashMultiset<Float> cValues = HashMultiset.create();
            HashMultiset<KernelFunction<P>> kernels = HashMultiset.create();
            for (BinaryModel<L, P> model : this.oneVsOneModels.values()) {
                cValues.add(Float.valueOf(model.param.C));
                kernels.add(model.param.kernel);
            }
            info.append("AllVsAll:C=" + cValues + "; gamma=" + kernels + "   ");
        }
        return info.toString();
    }

    /**
     * Returns the labels known to this model. The one-vs-one models are preferred; if there are
     * none, the one-vs-all models are used.
     *
     * @throws SvmException if there are no sub-models at all
     */
    @Override
    public Collection<L> getLabels() {
        if (this.oneVsOneModels != null && !this.oneVsOneModels.isEmpty()) {
            return this.oneVsOneModels.keySet();
        }
        if (this.oneVsAllModels == null || this.oneVsAllModels.isEmpty()) {
            throw new SvmException("Can't get labels from a MultiClassModel with no subsidiary BinaryModels");
        }
        return this.oneVsAllModels.keySet();
    }
}
