package com.stripe.rainier.sampler;

import scala.MatchError;
import scala.None$;
import scala.Some;
import scala.Tuple3;
import scala.collection.immutable.List;
import scala.collection.mutable.ListBuffer;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Driver.scala */
/* loaded from: input_file:com/stripe/rainier/sampler/Driver$.class */
/**
 * Singleton that drives one sampling run: it initializes a {@link LeapFrog} integrator,
 * runs the warmup loop (adapting step size and mass matrix), then collects samples.
 *
 * <p>Access the single instance through {@link #MODULE$}.
 */
public final class Driver$ {
    /** The single instance; assigned by the private constructor during class initialization. */
    public static Driver$ MODULE$;

    static {
        new Driver$();
    }

    /**
     * Runs warmup followed by sampling and reports progress along the way.
     *
     * @param i               identifier forwarded unchanged to every {@link Progress} callback
     *                        (presumably the chain index — confirm against callers)
     * @param samplerConfig   supplies the sampler, both tuners, the stats window and the
     *                        warmup / sampling iteration counts
     * @param densityFunction density handed to the {@link LeapFrog} integrator
     * @param progress        receives start / refresh / finish notifications
     * @param rng             random number generator passed to the integrator and sampler
     * @return a tuple of (collected samples, mass matrix produced by warmup, integrator stats)
     */
    public Tuple3<List<double[]>, MassMatrix, Stats> sample(int i, SamplerConfig samplerConfig, DensityFunction densityFunction, Progress progress, RNG rng) {
        Sampler sampler = samplerConfig.sampler();
        StepSizeTuner stepSizeTuner = samplerConfig.stepSizeTuner();
        MassMatrixTuner massMatrixTuner = samplerConfig.massMatrixTuner();
        LeapFrog leapFrog = new LeapFrog(densityFunction, samplerConfig.statsWindow());

        progress.start(i);

        Log$.MODULE$.FINE().log("Starting warmup");
        // The same state array is handed to both warmup and sampling.
        double[] state = leapFrog.initialize(IdentityMassMatrix$.MODULE$, rng);
        MassMatrix massMatrix = warmup(i, state, leapFrog, sampler, stepSizeTuner, massMatrixTuner, samplerConfig.warmupIterations(), progress, rng);

        // Stats gathered during warmup are discarded before sampling starts.
        leapFrog.resetStats();

        Log$.MODULE$.FINE().log("Starting sampling");
        List<double[]> samples = collectSamples(i, state, leapFrog, sampler, stepSizeTuner.stepSize(), massMatrix, samplerConfig.iterations(), progress, rng);
        Log$.MODULE$.FINE().log("Finished sampling");

        progress.finish(i, "Complete", leapFrog.stats(), massMatrix);
        return new Tuple3<>(samples, massMatrix, leapFrog.stats());
    }

    /**
     * Runs the warmup loop, adapting the step size on every iteration and replacing the
     * mass matrix whenever the mass matrix tuner produces a new one.
     *
     * @param id               identifier forwarded to {@link Progress#refresh}
     * @param state            state array passed to the sampler and integrator on every iteration
     * @param leapFrog         integrator used by the sampler and tuners
     * @param sampler          sampler being warmed up
     * @param stepSizeTuner    produces the initial, updated and reset step sizes
     * @param massMatrixTuner  produces the initial mass matrix and optional replacements
     * @param warmupIterations number of warmup iterations to run
     * @param progress         receives throttled "Warmup" refresh notifications
     * @param rng              random number generator passed to the sampler
     * @return the mass matrix in effect after the last warmup iteration
     * @throws MatchError if the mass matrix tuner returns something that is neither
     *                    {@code Some} nor {@code None}
     */
    private MassMatrix warmup(int id, double[] state, LeapFrog leapFrog, Sampler sampler, StepSizeTuner stepSizeTuner, MassMatrixTuner massMatrixTuner, int warmupIterations, Progress progress, RNG rng) {
        // The first deadline is "now", so the first refresh fires as soon as the clock advances.
        long refreshDeadline = System.nanoTime();

        sampler.initialize(state, leapFrog, rng);
        double stepSize = stepSizeTuner.initialize(state, leapFrog);
        MassMatrix massMatrix = massMatrixTuner.initialize(leapFrog, warmupIterations);
        Log$.MODULE$.FINER().log("Initial step size %f", BoxesRunTime.boxToDouble(stepSize));

        // Scratch buffer reused on every iteration; the tuner is handed the same array each time.
        double[] variables = new double[leapFrog.nVars()];

        for (int iteration = 0; iteration < warmupIterations; iteration++) {
            // The sampler's result feeds the step size tuner; its exponential is what gets
            // logged as the accept probability.
            double warmupResult = sampler.warmup(state, leapFrog, stepSize, massMatrix, rng);
            stepSize = stepSizeTuner.update(warmupResult);
            Log$.MODULE$.FINEST().log("Accept probability %f", BoxesRunTime.boxToDouble(Math.exp(warmupResult)));
            Log$.MODULE$.FINEST().log("Adapted step size %f", BoxesRunTime.boxToDouble(stepSize));

            leapFrog.variables(state, variables);

            // Held as Object so that both the Some and the None outcome are representable;
            // the cast to Some happens only after the instanceof check.
            Object tunerResult = massMatrixTuner.mo21update(variables);
            if (tunerResult instanceof Some) {
                // A new mass matrix is paired with a reset of the step size tuner.
                massMatrix = (MassMatrix) ((Some<?>) tunerResult).value();
                stepSize = stepSizeTuner.reset();
            } else if (!None$.MODULE$.equals(tunerResult)) {
                throw new MatchError(tunerResult);
            }

            if (refreshDue(refreshDeadline)) {
                progress.refresh(id, "Warmup", leapFrog.stats(), massMatrix);
                refreshDeadline = nextRefreshDeadline(progress);
            }
        }
        return massMatrix;
    }

    /**
     * Runs the sampler for a fixed number of iterations with a fixed step size and mass
     * matrix, recording the integrator's variables after every iteration.
     *
     * @param id         identifier forwarded to {@link Progress#refresh}
     * @param state      state array passed to the sampler and integrator on every iteration
     * @param leapFrog   integrator used by the sampler
     * @param sampler    sampler to run
     * @param stepSize   step size used for every iteration
     * @param massMatrix mass matrix used for every iteration
     * @param iterations number of samples to collect
     * @param progress   receives throttled "Sampling" refresh notifications
     * @param rng        random number generator passed to the sampler
     * @return the collected samples in iteration order, one array per iteration
     */
    private List<double[]> collectSamples(int id, double[] state, LeapFrog leapFrog, Sampler sampler, double stepSize, MassMatrix massMatrix, int iterations, Progress progress, RNG rng) {
        long refreshDeadline = System.nanoTime();
        ListBuffer<double[]> samples = new ListBuffer<>();

        for (int iteration = 0; iteration < iterations; iteration++) {
            sampler.run(state, leapFrog, stepSize, massMatrix, rng);

            // Each sample is stored, so it needs its own array rather than a shared buffer.
            double[] sample = new double[leapFrog.nVars()];
            leapFrog.variables(state, sample);
            samples.$plus$eq(sample);

            if (refreshDue(refreshDeadline)) {
                progress.refresh(id, "Sampling", leapFrog.stats(), massMatrix);
                refreshDeadline = nextRefreshDeadline(progress);
            }
        }
        return samples.toList();
    }

    /**
     * Reports whether the given {@link System#nanoTime()} deadline has passed.
     *
     * <p>Compares via subtraction, as the {@code System.nanoTime} documentation requires,
     * so the result stays correct if the counter overflows.
     */
    private static boolean refreshDue(long deadlineNanos) {
        return System.nanoTime() - deadlineNanos > 0;
    }

    /** Computes the next refresh deadline: now plus {@code progress.outputEverySeconds()} seconds. */
    private static long nextRefreshDeadline(Progress progress) {
        return System.nanoTime() + ((long) (progress.outputEverySeconds() * 1.0E9d));
    }

    /** Private: the only instance is created by the static initializer and published as MODULE$. */
    private Driver$() {
        MODULE$ = this;
    }
}
