package org.nd4j.autodiff.listeners.impl;

import java.text.DecimalFormat;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training listener that logs the loss every {@code frequency} iterations and, optionally, a
 * per-epoch summary (batches, examples, throughput and ETL time).
 *
 * <p>The listener is only active for {@link Operation#TRAINING}. Instances keep mutable per-run
 * counters without any synchronization, so an instance should not be shared between concurrently
 * running training loops.
 */
public class ScoreListener extends BaseListener {
    private static final Logger log = LoggerFactory.getLogger(ScoreListener.class);

    private static final long MS_PER_SECOND = 1000L;
    private static final long MS_PER_MINUTE = 60000L;
    private static final long MS_PER_HOUR = 3600000L;

    /** Per-thread formatters (DecimalFormat is not thread-safe); lazily created on first use. */
    protected static final ThreadLocal<DecimalFormat> DF_2DP = new ThreadLocal<>();
    protected static final ThreadLocal<DecimalFormat> DF_2DP_SCI = new ThreadLocal<>();
    protected static final ThreadLocal<DecimalFormat> DF_5DP = new ThreadLocal<>();
    protected static final ThreadLocal<DecimalFormat> DF_5DP_SCI = new ThreadLocal<>();

    private final int frequency;
    private final boolean reportEpochs;
    private final boolean reportIterPerformance;

    // Epoch-level counters; reset in epochStart only when epoch reporting is enabled.
    private long epochExampleCount;
    private int epochBatchCount;
    private long etlTotalTimeEpoch;

    // Iteration-level counters; reset after every loss report.
    private long lastIterTime;
    private long etlTimeSumSinceLastReport;
    private long iterTimeSumSinceLastReport; // accumulated but not currently reported
    private long examplesSinceLastReportIter;
    /** Wall-clock time (ms) of the previous loss report, or -1 if none in this epoch yet. */
    private long lastReportTime = -1L;

    /** Reports every 10 iterations, with epoch summaries and iteration throughput. */
    public ScoreListener() {
        this(10, true);
    }

    /**
     * @param frequency report the loss every {@code frequency} iterations; must be > 0
     */
    public ScoreListener(int frequency) {
        this(frequency, true);
    }

    /**
     * @param frequency    report the loss every {@code frequency} iterations; must be > 0
     * @param reportEpochs whether to log a summary at the end of each epoch
     */
    public ScoreListener(int frequency, boolean reportEpochs) {
        this(frequency, reportEpochs, true);
    }

    /**
     * @param frequency             report the loss every {@code frequency} iterations; must be > 0
     * @param reportEpochs          whether to log a summary at the end of each epoch
     * @param reportIterPerformance whether loss reports include batches/sec and examples/sec
     */
    public ScoreListener(int frequency, boolean reportEpochs, boolean reportIterPerformance) {
        Preconditions.checkArgument(frequency > 0, "ScoreListener frequency must be > 0, got %s", frequency);
        this.frequency = frequency;
        this.reportEpochs = reportEpochs;
        this.reportIterPerformance = reportIterPerformance;
    }

    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    @Override
    public void epochStart(SameDiff sameDiff, At at) {
        if (reportEpochs) {
            epochExampleCount = 0L;
            epochBatchCount = 0;
            etlTotalTimeEpoch = 0L;
        }
        lastReportTime = -1L;
        examplesSinceLastReportIter = 0L;
    }

    /**
     * Logs the epoch summary when epoch reporting is enabled.
     *
     * @param epochTimeMillis duration of the epoch in milliseconds
     * @return always {@link ListenerResponse#CONTINUE}
     */
    @Override
    public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (reportEpochs) {
            // Clamp to 1 ms so a zero-length epoch does not produce Infinity/NaN rates.
            long safeMillis = Math.max(epochTimeMillis, 1L);
            double epochSeconds = safeMillis / (double) MS_PER_SECOND;

            String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time";
            if (etlTotalTimeEpoch > 0) {
                etl += "(" + format2dp((100.0d * etlTotalTimeEpoch) / safeMillis) + " %)";
            }

            log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
                    at.epoch(), at.iteration(), epochBatchCount, epochExampleCount,
                    formatDurationMs(epochTimeMillis),
                    format2dp(epochBatchCount / epochSeconds),
                    format2dp(epochExampleCount / epochSeconds),
                    etl);
        }
        return ListenerResponse.CONTINUE;
    }

    /**
     * Records the iteration start time and accumulates the supplied ETL time.
     *
     * @param etlTimeMs ETL time for this iteration in ms (presumably the time spent obtaining the
     *                  batch — confirm against the caller)
     */
    @Override
    public void iterationStart(SameDiff sameDiff, At at, MultiDataSet multiDataSet, long etlTimeMs) {
        lastIterTime = System.currentTimeMillis();
        etlTimeSumSinceLastReport += etlTimeMs;
        etlTotalTimeEpoch += etlTimeMs;
    }

    /**
     * Updates counters and, on every {@code frequency}-th iteration (iteration > 0), logs the loss.
     */
    @Override
    public void iterationDone(SameDiff sameDiff, At at, MultiDataSet multiDataSet, Loss loss) {
        long now = System.currentTimeMillis();
        iterTimeSumSinceLastReport += now - lastIterTime;
        epochBatchCount++;

        // The example count is taken from the leading dimension of the first features array.
        if (multiDataSet != null && multiDataSet.numFeatureArrays() > 0 && multiDataSet.getFeatures(0) != null) {
            long batchSize = multiDataSet.getFeatures(0).size(0);
            examplesSinceLastReportIter += batchSize;
            epochExampleCount += batchSize;
        }

        int iteration = at.iteration();
        if (iteration <= 0 || iteration % frequency != 0) {
            return;
        }

        String lossStr = format5dp(loss.totalLoss());
        String etl = "";
        if (etlTimeSumSinceLastReport > 0) {
            etl = "(" + formatDurationMs(etlTimeSumSinceLastReport) + " ETL"
                    + (frequency == 1 ? ")" : " in " + frequency + " iter)");
        }

        if (reportIterPerformance && lastReportTime > 0) {
            // Floating-point rates, clamped to >= 1 ms: integer division both truncated the rates and
            // threw ArithmeticException when two reports landed in the same millisecond.
            double elapsedSeconds = Math.max(now - lastReportTime, 1L) / (double) MS_PER_SECOND;
            log.info("Loss at epoch {}, iteration {}: {}{}, batches/sec: {}, examples/sec: {}",
                    at.epoch(), iteration, lossStr, etl,
                    format5dp(frequency / elapsedSeconds),
                    format5dp(examplesSinceLastReportIter / elapsedSeconds));
        } else {
            log.info("Loss at epoch {}, iteration {}: {}{}", at.epoch(), iteration, lossStr, etl);
        }
        if (reportIterPerformance) {
            lastReportTime = now;
        }

        iterTimeSumSinceLastReport = 0L;
        etlTimeSumSinceLastReport = 0L;
        examplesSinceLastReportIter = 0L;
    }

    /**
     * Formats a duration as ms (<= 100 ms), seconds (<= 1 min), minutes (<= 1 h) or hours.
     *
     * @param ms duration in milliseconds
     * @return human-readable duration
     */
    protected String formatDurationMs(long ms) {
        if (ms <= 100) {
            return ms + " ms";
        } else if (ms <= MS_PER_MINUTE) {
            return format2dp(ms / (double) MS_PER_SECOND) + " sec";
        } else if (ms <= MS_PER_HOUR) {
            return format2dp(ms / (double) MS_PER_MINUTE) + " min";
        } else {
            // Previously divided by 360000 (a tenth of an hour), over-reporting hours by 10x.
            return format2dp(ms / (double) MS_PER_HOUR) + " hr";
        }
    }

    /**
     * Formats with 2 decimal places, switching to scientific notation for values below 0.01.
     */
    protected String format2dp(double d) {
        if (d < 0.01d) {
            return cachedFormat(DF_2DP_SCI, "0.00E0").format(d);
        }
        return cachedFormat(DF_2DP, "#.00").format(d);
    }

    /**
     * Formats with 5 decimal places, switching to scientific notation outside [1e-4, 1e4].
     */
    protected String format5dp(double d) {
        if (d < 1.0E-4d || d > 10000.0d) {
            return cachedFormat(DF_5DP_SCI, "0.00000E0").format(d);
        }
        return cachedFormat(DF_5DP, "0.00000").format(d);
    }

    /**
     * Returns this thread's formatter from {@code cache}, creating and storing one for
     * {@code pattern} on first use. A single helper ensures the formatter is stored back into the
     * same ThreadLocal it was read from.
     */
    private static DecimalFormat cachedFormat(ThreadLocal<DecimalFormat> cache, String pattern) {
        DecimalFormat format = cache.get();
        if (format == null) {
            format = new DecimalFormat(pattern);
            cache.set(format);
        }
        return format;
    }
}
