package edu.berkeley.compbio.jlibsvm.legacyexec;

import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.MutableSvmProblem;
import edu.berkeley.compbio.jlibsvm.SVM;
import edu.berkeley.compbio.jlibsvm.SolutionModel;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationProblem;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.binary.C_SVC;
import edu.berkeley.compbio.jlibsvm.binary.MutableBinaryClassificationProblemImpl;
import edu.berkeley.compbio.jlibsvm.binary.Nu_SVC;
import edu.berkeley.compbio.jlibsvm.kernel.GammaKernel;
import edu.berkeley.compbio.jlibsvm.kernel.GaussianRBFKernel;
import edu.berkeley.compbio.jlibsvm.kernel.LinearKernel;
import edu.berkeley.compbio.jlibsvm.kernel.PolynomialKernel;
import edu.berkeley.compbio.jlibsvm.kernel.PrecomputedKernel;
import edu.berkeley.compbio.jlibsvm.kernel.SigmoidKernel;
import edu.berkeley.compbio.jlibsvm.labelinverter.StringLabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassificationSVM;
import edu.berkeley.compbio.jlibsvm.multi.MutableMultiClassProblemImpl;
import edu.berkeley.compbio.jlibsvm.oneclass.OneClassSVC;
import edu.berkeley.compbio.jlibsvm.regression.EpsilonSVR;
import edu.berkeley.compbio.jlibsvm.regression.MutableRegressionProblemImpl;
import edu.berkeley.compbio.jlibsvm.regression.Nu_SVR;
import edu.berkeley.compbio.jlibsvm.regression.RegressionSVM;
import edu.berkeley.compbio.jlibsvm.scaler.LinearScalingModelLearner;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModelLearner;
import edu.berkeley.compbio.jlibsvm.scaler.ScalingModelLearner;
import edu.berkeley.compbio.jlibsvm.scaler.ZscoreScalingModelLearner;
import edu.berkeley.compbio.jlibsvm.util.SparseVector;
import edu.berkeley.compbio.ml.CrossValidationResults;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.StringTokenizer;
import java.util.Vector;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/* loaded from: input_file:BOOT-INF/lib/jlibsvm-0.911.jar:edu/berkeley/compbio/jlibsvm/legacyexec/svm_train.class */
/**
 * Command-line entry point that parses libsvm-style options, reads a training file in the
 * sparse {@code label index:value ...} text format, trains the selected SVM, saves the model and
 * prints cross-validation results when they are available.
 *
 * <p>Run without a training file (or with an unknown option) to print the usage text.
 */
public class svm_train {
    // SVM type codes accepted by the -s option.
    static final int C_SVC = 0;
    static final int NU_SVC = 1;
    static final int ONE_CLASS = 2;
    static final int EPSILON_SVR = 3;
    static final int NU_SVR = 4;

    // Kernel type codes accepted by the -t option.
    static final int LINEAR = 0;
    static final int POLY = 1;
    static final int RBF = 2;
    static final int SIGMOID = 3;
    static final int PRECOMPUTED = 4;

    /** The SVM implementation chosen by -s (possibly wrapped for multi-class training in run). */
    SVM svm;

    /** The parameter set built from the command line. */
    ImmutableSvmParameter param;

    private MutableSvmProblem problem;
    private SolutionModel model;
    private String input_file_name;
    private String model_file_name;

    /** Set by the -q option; requests cross validation when the trained model carries no results. */
    private boolean crossValidation;

    /** Placeholder gamma meaning "not given on the command line"; replaced after the data is read. */
    private static final Float UNSPECIFIED_GAMMA = Float.valueOf(-1.0f);

    /**
     * Program entry point.
     *
     * @param strArr the command-line arguments
     * @throws IOException if the training file cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        new svm_train().run(strArr);
    }

    /**
     * Parses the arguments, loads the problem, trains, saves the model and reports results.
     *
     * @param strArr the command-line arguments
     * @throws IOException if the training file cannot be read
     */
    private void run(String[] strArr) throws IOException {
        parse_command_line(strArr);
        read_problem();

        long startMillis = System.currentTimeMillis();

        // A binary classifier cannot handle more than two labels directly; wrap it.
        if ((this.svm instanceof BinaryClassificationSVM) && this.problem.getLabels().size() > 2) {
            this.svm = new MultiClassificationSVM((BinaryClassificationSVM) this.svm);
        }

        this.model = this.svm.train(this.problem, this.param);
        this.model.save(this.model_file_name);

        // Prefer results produced during training; otherwise run cross validation only on request.
        CrossValidationResults cvResults = this.model.getCrossValidationResults();
        if (cvResults == null && this.crossValidation) {
            cvResults = this.svm.performCrossValidation(this.problem, this.param);
        }
        if (cvResults != null) {
            System.out.println(cvResults.toString());
        }

        float elapsedSecs = ((float) (System.currentTimeMillis() - startMillis)) / 1000.0f;
        System.out.println("Finished in " + elapsedSecs + " secs");
    }

    /**
     * Interprets a command-line switch value: {@code "1"} or anything accepted by
     * {@link Boolean#parseBoolean(String)} (i.e. "true", ignoring case) is true.
     */
    private static boolean parseFlag(String value) {
        return value.equals("1") || Boolean.parseBoolean(value);
    }

    /** Parses a comma-separated list of floats into a set. */
    private static HashSet<Float> parseFloatSet(String commaSeparated) {
        HashSet<Float> result = new HashSet<Float>();
        for (String item : commaSeparated.split(",")) {
            result.add(Float.valueOf(Float.parseFloat(item)));
        }
        return result;
    }

    /**
     * Parses the options, the training file name and the optional model file name, then builds
     * {@link #param} and instantiates {@link #svm}.
     *
     * <p>Every option takes exactly one value. Leading arguments that start with '-' are treated
     * as options; the first argument that does not is the training file. Prints the usage text
     * and exits on an unknown option, a missing option value or a missing training file.
     *
     * @param strArr the command-line arguments
     * @throws SvmException if the kernel type or SVM type code is not one of the known values
     */
    private void parse_command_line(String[] strArr) {
        ImmutableSvmParameterGrid.Builder builder = ImmutableSvmParameterGrid.builder();
        builder.nu = 0.5f;
        builder.cache_size = 100.0f;
        builder.eps = 0.001f;
        builder.p = 0.1f;
        builder.shrinking = true;
        builder.probability = false;
        builder.redistributeUnbalancedC = true;

        ScalingModelLearner scalingModelLearner = new NoopScalingModelLearner();
        String scalingMode = null;
        int scalingExamples = 1000;
        boolean normalizeL2 = false;
        int svmType = C_SVC;
        int kernelType = RBF;
        int degree = 3;
        HashSet<Float> gammaSet = new HashSet<Float>();
        float coef0 = 0.0f;

        int argIndex = 0;
        while (argIndex < strArr.length && strArr[argIndex].startsWith("-")) {
            String option = strArr[argIndex];
            int valueIndex = argIndex + 1;
            if (valueIndex >= strArr.length) {
                exit_with_help();
            }
            String value = strArr[valueIndex];

            // A bare "-" has no option letter; '\0' routes it to the unknown-option branch.
            char optionLetter = option.length() > 1 ? option.charAt(1) : '\0';
            switch (optionLetter) {
                case 'a':
                    builder.allVsAllMode = MultiClassModel.AllVsAllMode.valueOf(value);
                    break;
                case 'b':
                    builder.probability = parseFlag(value);
                    break;
                case 'c':
                    builder.Cset = parseFloatSet(value);
                    break;
                case 'd':
                    degree = Integer.parseInt(value);
                    break;
                case 'e':
                    builder.eps = Float.parseFloat(value);
                    break;
                case 'f':
                    scalingMode = value;
                    break;
                case 'g':
                    gammaSet = parseFloatSet(value);
                    break;
                case 'h':
                    builder.shrinking = parseFlag(value);
                    break;
                case 'j':
                    builder.minVoteProportion = Double.parseDouble(value);
                    break;
                case 'k':
                    builder.oneVsAllThreshold = Double.parseDouble(value);
                    break;
                case 'l':
                    if (Integer.parseInt(value) == 2) {
                        normalizeL2 = true;
                    } else {
                        System.err.print("-l must == 2\n");
                        exit_with_help();
                    }
                    break;
                case 'm':
                    builder.cache_size = Float.parseFloat(value);
                    break;
                case 'n':
                    builder.nu = Float.parseFloat(value);
                    break;
                case 'o':
                    builder.oneVsAllMode = MultiClassModel.OneVsAllMode.valueOf(value);
                    break;
                case 'p':
                    builder.p = Float.parseFloat(value);
                    break;
                case 'q':
                    this.crossValidation = parseFlag(value);
                    break;
                case 'r':
                    coef0 = Float.parseFloat(value);
                    break;
                case 's':
                    svmType = Integer.parseInt(value);
                    break;
                case 't':
                    kernelType = Integer.parseInt(value);
                    break;
                case 'u':
                    builder.redistributeUnbalancedC = parseFlag(value);
                    break;
                case 'v':
                    builder.crossValidationFolds = Integer.parseInt(value);
                    if (builder.crossValidationFolds < 2) {
                        System.err.print("n-fold cross validation: n must >= 2\n");
                        exit_with_help();
                    }
                    break;
                case 'w':
                    // "-w<i> weight": the class label follows the option letter directly.
                    builder.putWeight(Integer.valueOf(Integer.parseInt(option.substring(2))),
                            Float.valueOf(Float.parseFloat(value)));
                    break;
                case 'x':
                    scalingExamples = Integer.parseInt(value);
                    break;
                case 'y':
                    builder.gridsearchBinaryMachinesIndependently = parseFlag(value);
                    break;
                case 'z':
                    builder.scaleBinaryMachinesIndependently = parseFlag(value);
                    break;
                default:
                    System.err.print("Unknown option: " + option + "\n");
                    exit_with_help();
                    break;
            }
            argIndex = valueIndex + 1;
        }

        // Any scaling mode other than "linear" or "zscore" keeps the no-op learner.
        if (scalingMode != null) {
            if (scalingMode.equals("linear")) {
                scalingModelLearner = new LinearScalingModelLearner(scalingExamples, normalizeL2);
            } else if (scalingMode.equals("zscore")) {
                scalingModelLearner = new ZscoreScalingModelLearner(scalingExamples, normalizeL2);
            }
        }

        if (argIndex >= strArr.length) {
            exit_with_help();
        }
        this.input_file_name = strArr[argIndex];
        if (argIndex < strArr.length - 1) {
            this.model_file_name = strArr[argIndex + 1];
        } else {
            // Default model file: the training file's base name plus ".model".
            String inputName = strArr[argIndex];
            this.model_file_name = inputName.substring(inputName.lastIndexOf('/') + 1) + ".model";
        }

        // No -g given: use the placeholder, which read_problem() later replaces.
        if (gammaSet.isEmpty()) {
            gammaSet.add(UNSPECIFIED_GAMMA);
        }

        // One kernel per requested gamma for the gamma-based kernel types.
        builder.kernelSet = new HashSet();
        switch (kernelType) {
            case LINEAR:
                builder.kernelSet.add(new LinearKernel());
                break;
            case POLY:
                for (Float gamma : gammaSet) {
                    builder.kernelSet.add(new PolynomialKernel(degree, gamma.floatValue(), coef0));
                }
                break;
            case RBF:
                for (Float gamma : gammaSet) {
                    builder.kernelSet.add(new GaussianRBFKernel(gamma.floatValue()));
                }
                break;
            case SIGMOID:
                for (Float gamma : gammaSet) {
                    builder.kernelSet.add(new SigmoidKernel(gamma.floatValue(), coef0));
                }
                break;
            case PRECOMPUTED:
                builder.kernelSet.add(new PrecomputedKernel());
                break;
            default:
                throw new SvmException("Unknown kernel type: " + kernelType);
        }

        builder.scalingModelLearner = scalingModelLearner;
        this.param = builder.build();

        switch (svmType) {
            case C_SVC:
                this.svm = new C_SVC();
                break;
            case NU_SVC:
                this.svm = new Nu_SVC();
                break;
            case ONE_CLASS:
                this.svm = new OneClassSVC();
                break;
            case EPSILON_SVR:
                this.svm = new EpsilonSVR();
                break;
            case NU_SVR:
                this.svm = new Nu_SVR();
                break;
            default:
                // Report the SVM type here, not the kernel type.
                throw new SvmException("Unknown svm type: " + svmType);
        }
    }

    /** Prints the usage text to standard output and terminates the JVM with exit status 1. */
    private static void exit_with_help() {
        System.out.print("Usage: svm_train [options] training_set_file [model_file]\n"
                + "options:\n"
                + "-s svm_type : set type of SVM (default 0)\n"
                + "\t0 -- C-SVC\n"
                + "\t1 -- nu-SVC\n"
                + "\t2 -- one-class SVM\n"
                + "\t3 -- epsilon-SVR\n"
                + "\t4 -- nu-SVR\n"
                + "-t kernel_type : set type of kernel function (default 2)\n"
                + "\t0 -- linear: u'*v\n"
                + "\t1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
                + "\t2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
                + "\t3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
                + "\t4 -- precomputed kernel (kernel values in training_set_file)\n"
                + "-d degree : set degree in kernel function (default 3)\n"
                + "-g gamma : set gamma in kernel function (default 1/k)\n"
                + "-r coef0 : set coef0 in kernel function (default 0)\n"
                + "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
                + "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
                + "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
                + "-m cachesize : set cache memory size in MB (default 100)\n"
                + "-e epsilon : set tolerance of termination criterion (default 0.001)\n"
                + "-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
                + "-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
                + "-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
                + "-a allVsAllMode: None, AllVsAll, FilteredVsAll, FilteredVsFiltered\n"
                + "-j minVoteProportion: the chosen class must have at least this proportion of the total votes\n"
                + "-o oneVsAllMode: None, Best, Veto, BreakTies, VetoAndBreakTies \n"
                + "-k oneVsAllProb: the chosen class must have at least this one-vs-all probability; if -b is not set, probabilities are 0 or 1\n"
                + "-v n: n-fold cross validation mode\n"
                + "-f scalingmode : none (default), linear, zscore\n"
                + "-x scalinglimit : maximum examples to use for scaling (default 1000)\n"
                + "-l 2: project to unit sphere (normalize L2 distance)\n");
        System.exit(1);
    }

    /**
     * Reads the training file into {@link #problem} and resolves any unspecified kernel gamma.
     *
     * <p>Each non-blank line is {@code label index:value index:value ...}. The largest feature
     * index is taken from the last pair on each line, so pairs are expected in ascending index
     * order. Blank lines are skipped.
     *
     * @throws IOException if the training file cannot be opened or read
     */
    private void read_problem() throws IOException {
        Vector<Float> labels = new Vector<Float>();
        Vector<SparseVector> points = new Vector<SparseVector>();
        int maxIndex = 0;

        BufferedReader reader = new BufferedReader(new FileReader(this.input_file_name));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line, " \t\n\r\f:");
                if (!tokens.hasMoreTokens()) {
                    // Blank line (e.g. trailing newline at end of file): nothing to parse.
                    continue;
                }
                labels.addElement(Float.valueOf(Float.parseFloat(tokens.nextToken())));

                // The remaining tokens alternate between feature index and feature value.
                int featureCount = tokens.countTokens() / 2;
                SparseVector point = new SparseVector(featureCount);
                for (int k = 0; k < featureCount; k++) {
                    point.indexes[k] = Integer.parseInt(tokens.nextToken());
                    point.values[k] = Float.parseFloat(tokens.nextToken());
                }
                if (featureCount > 0) {
                    maxIndex = Math.max(maxIndex, point.indexes[featureCount - 1]);
                }
                points.addElement(point);
            }
        } finally {
            // Close even when a line fails to parse, so the file handle is never leaked.
            reader.close();
        }

        // Choose the problem representation from the SVM kind and the number of distinct labels.
        if (this.svm instanceof RegressionSVM) {
            this.problem = new MutableRegressionProblemImpl(labels.size());
        } else {
            int distinctLabels = new HashSet<Float>(labels).size();
            if (distinctLabels == 1) {
                this.problem = new MutableRegressionProblemImpl(labels.size());
            } else if (distinctLabels == 2) {
                this.problem = new MutableBinaryClassificationProblemImpl(String.class, labels.size());
            } else {
                this.problem = new MutableMultiClassProblemImpl(String.class, new StringLabelInverter(),
                        labels.size(), new NoopScalingModel());
            }
        }

        for (int k = 0; k < labels.size(); k++) {
            this.problem.addExampleFloat(points.elementAt(k), labels.elementAt(k));
        }
        if (this.problem instanceof BinaryClassificationProblem) {
            ((BinaryClassificationProblem) this.problem).setupLabels();
        }

        // Fill in the default gamma on every kernel that was left unspecified.
        if (this.param instanceof ImmutableSvmParameterGrid) {
            for (Object gridPoint : ((ImmutableSvmParameterGrid) this.param).getGridParams()) {
                updateKernelWithNumExamples((ImmutableSvmParameterPoint) gridPoint, maxIndex);
            }
        } else {
            updateKernelWithNumExamples((ImmutableSvmParameterPoint) this.param, maxIndex);
        }
    }

    /**
     * Replaces a placeholder gamma on the point's kernel with {@code 1 / i} and rejects
     * precomputed kernels, which this tool does not support.
     *
     * @param immutableSvmParameterPoint the parameter point whose kernel is inspected
     * @param i the divisor for the default gamma; read_problem() passes the largest feature
     *          index it saw (despite this method's name)
     */
    private void updateKernelWithNumExamples(ImmutableSvmParameterPoint immutableSvmParameterPoint, int i) {
        Object kernel = immutableSvmParameterPoint.kernel;
        if (kernel instanceof GammaKernel) {
            GammaKernel gammaKernel = (GammaKernel) kernel;
            if (gammaKernel.getGamma() == UNSPECIFIED_GAMMA.floatValue()) {
                gammaKernel.setGamma(1.0f / i);
            }
        }
        if (kernel instanceof PrecomputedKernel) {
            // NOTE(review): this is a JDK-internal exception class; consider
            // UnsupportedOperationException once the file-level import can be dropped.
            throw new NotImplementedException();
        }
    }
}
