package edu.berkeley.compbio.ml.mcmc.mcmcmc;

import com.davidsoergel.dsutils.math.MersenneTwisterFast;
import edu.berkeley.compbio.ml.mcmc.MonteCarlo;
import edu.berkeley.compbio.ml.mcmc.Move;
import edu.berkeley.compbio.ml.mcmc.ProbabilityMove;
import java.util.Objects;
import org.apache.log4j.Logger;

/* loaded from: input_file:lib/ml-0.921.jar:edu/berkeley/compbio/ml/mcmc/mcmcmc/MetropolisCoupledSwapMove.class */
/**
 * Metropolis-coupled MCMC move that proposes exchanging the heat factors of two adjacent chains
 * (by list index) in a {@link ChainList}.
 *
 * <p>A pair {@code (swap1, swap1 + 1)} is chosen at construction time and again on every call to
 * {@link #propose()}. {@link #doMove(double)} runs a Metropolis acceptance test on that pair. If the
 * test passes, the two chains exchange heat factors and the "coldest" flag.
 */
public class MetropolisCoupledSwapMove extends Move implements ProbabilityMove {
    private static final Logger logger = Logger.getLogger(MetropolisCoupledSwapMove.class);

    /** The coupled chains; replaced by a new list holding the same chains whenever a swap is accepted. */
    private ChainList chains;

    /** List index of the first chain of the proposed pair. */
    private int swap1;

    /** List index of the second chain of the proposed pair; always {@code swap1 + 1}. */
    private int swap2;

    /**
     * Creates a swap move over the given chains and immediately proposes a pair.
     *
     * @param chainList the coupled chains; must contain at least two chains
     * @throws NullPointerException if {@code chainList} is null
     * @throws IllegalArgumentException if {@code chainList} holds fewer than two chains
     */
    public MetropolisCoupledSwapMove(ChainList chainList) {
        Objects.requireNonNull(chainList, "chainList");
        // propose() needs indices swap1 and swap1 + 1 to exist. Fail here with a clear message
        // rather than later inside randomInt(...) or chains.get(swap2).
        if (chainList.size() < 2) {
            throw new IllegalArgumentException(
                    "MetropolisCoupledSwapMove requires at least two chains, got " + chainList.size());
        }
        this.chains = chainList;
        propose();
    }

    /**
     * Chooses a new pair of adjacent list indices {@code (swap1, swap1 + 1)}.
     *
     * <p>{@code swap1} comes from {@code MersenneTwisterFast.randomInt(chains.size() - 1)}. That is
     * presumably uniform over {@code [0, chains.size() - 1)}; verify against that helper.
     */
    @Override // edu.berkeley.compbio.ml.mcmc.Move
    public final void propose() {
        this.swap1 = MersenneTwisterFast.randomInt(this.chains.size() - 1);
        this.swap2 = this.swap1 + 1;
    }

    /**
     * Tests the currently proposed swap and applies it if accepted.
     *
     * <p>On acceptance:
     * <ul>
     *   <li>the two chains exchange heat factors;</li>
     *   <li>if either chain was flagged coldest, the flag moves to the other one;</li>
     *   <li>{@code chains} is replaced by a new {@link ChainList} holding the same chain objects
     *       in the same order.</li>
     * </ul>
     * On rejection nothing changes.
     *
     * <p>NOTE(review): only heat factors and the coldest flag move. The chain objects keep their
     * list positions, so list adjacency may stop matching heat-factor adjacency after a swap.
     * Confirm this is intended.
     *
     * @param d ignored by this move (required by {@link ProbabilityMove})
     * @return the chain list: a new list instance if the swap was accepted, otherwise the same
     *     instance as before
     */
    @Override // edu.berkeley.compbio.ml.mcmc.ProbabilityMove
    public ChainList doMove(double d) {
        if (!isAccepted()) {
            return this.chains;
        }
        MonteCarlo first = this.chains.get(this.swap1);
        MonteCarlo second = this.chains.get(this.swap2);
        if (logger.isDebugEnabled()) {
            logger.debug("SWAPPING CHAINS " + this.swap1 + " (" + first.getHeatFactor() + ") " + this.swap2 + " (" + second.getHeatFactor() + ") ");
        }

        double firstHeat = first.getHeatFactor();
        first.setHeatFactor(second.getHeatFactor());
        second.setHeatFactor(firstHeat);

        // The coldest flag follows the heat factor it was attached to.
        if (first.isColdest()) {
            first.setColdest(false);
            second.setColdest(true);
        } else if (second.isColdest()) {
            second.setColdest(false);
            first.setColdest(true);
        }

        // Return a fresh list object (same elements, same order) rather than the previous
        // instance. NOTE(review): callers may rely on the new identity to detect acceptance; keep.
        ChainList swapped = new ChainList();
        swapped.addAll(this.chains);
        this.chains = swapped;
        return this.chains;
    }

    /**
     * Runs the Metropolis acceptance test for exchanging the two selected chains.
     *
     * <p>Notation:
     * <ul>
     *   <li>{@code L_a(x)} is {@code first.unnormalizedLogLikelihood(x)};</li>
     *   <li>{@code L_b(x)} is {@code second.unnormalizedLogLikelihood(x)};</li>
     *   <li>{@code x_a} and {@code x_b} are the two chains' current states.</li>
     * </ul>
     * The log acceptance ratio is {@code [L_a(x_b) + L_b(x_a)] - [L_a(x_a) + L_b(x_b)]}. The swap is
     * accepted when {@code log(u) < ratio}, where {@code u = MersenneTwisterFast.random()}.
     *
     * @return true if the swap should be applied
     */
    private boolean isAccepted() {
        MonteCarlo first = this.chains.get(this.swap1);
        MonteCarlo second = this.chains.get(this.swap2);

        // Each likelihood is evaluated exactly once, in the original order. The debug message
        // reuses these values instead of recomputing them.
        double firstAtSecondState = first.unnormalizedLogLikelihood(second.getCurrentState());
        double secondAtFirstState = second.unnormalizedLogLikelihood(first.getCurrentState());
        double firstAtOwnState = first.unnormalizedLogLikelihood(first.getCurrentState());
        double secondAtOwnState = second.unnormalizedLogLikelihood(second.getCurrentState());

        double logSwapRatio = (firstAtSecondState + secondAtFirstState) - (firstAtOwnState + secondAtOwnState);

        if (logger.isDebugEnabled()) {
            // These are log-likelihoods, so they combine additively (the old message showed * and /).
            logger.debug(String.format("Swap log likelihood components: (%f + %f) - (%f + %f)",
                    firstAtSecondState, secondAtFirstState, firstAtOwnState, secondAtOwnState));
            logger.debug("swapLogLikelihoodRatio = " + logSwapRatio);
        }
        return Math.log(MersenneTwisterFast.random()) < logSwapRatio;
    }
}
