package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Base class for listeners that run {@link IEvaluation}s on variables during training and
 * training-time validation.
 *
 * <p>Subclasses describe the evaluations to run via {@link #evaluations()}. The evaluations listed
 * there are used only as templates. At every epoch start, fresh copies are created with
 * {@link IEvaluation#newInstance()}. Each time an evaluated variable's activation becomes
 * available, the copies are fed that activation together with the configured labels and label
 * mask.
 *
 * <p>The {@link Listener} callbacks handled here are {@code final}. Subclasses hook in through the
 * corresponding {@code ...Evaluations} methods instead.
 */
public abstract class BaseEvaluationListener extends BaseListener {

    /** Per-variable training evaluations for the current epoch, rebuilt in {@link #epochStart}. */
    private Map<String, List<IEvaluation>> trainingEvaluations = new HashMap<>();

    /** Per-variable validation evaluations, rebuilt in {@link #epochStart}. */
    private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();

    /**
     * Returns the evaluation configuration for this listener.
     *
     * <p>The evaluations contained in the result act as templates; new instances are created from
     * them at the start of every epoch. This method is invoked from several callbacks (including
     * per activation), so implementations should be cheap.
     *
     * @return the training/validation evaluations, their label indices and required variables
     */
    public abstract ListenerEvaluations evaluations();

    /**
     * Merges the variables required by {@link #evaluations()} with
     * {@link #otherRequiredVariables(SameDiff)}.
     */
    @Override
    public final ListenerVariables requiredVariables(SameDiff sd) {
        return evaluations().requiredVariables().merge(otherRequiredVariables(sd));
    }

    /**
     * Returns any variables a subclass needs that are not already required by its evaluations.
     *
     * @param sd the SameDiff instance being trained
     * @return additional required variables; empty by default
     */
    public ListenerVariables otherRequiredVariables(SameDiff sd) {
        return ListenerVariables.empty();
    }

    /**
     * Resets both training and validation evaluations to fresh instances, then delegates to
     * {@link #epochStartEvaluations(SameDiff, At)}.
     */
    @Override
    public final void epochStart(SameDiff sd, At at) {
        ListenerEvaluations config = evaluations();
        // Build each map completely before publishing it, so a failure in newInstance()
        // never leaves a half-populated map in the field.
        trainingEvaluations = freshInstances(config.trainEvaluations());
        validationEvaluations = freshInstances(config.validationEvaluations());
        epochStartEvaluations(sd, at);
    }

    /**
     * Subclass hook invoked at the end of {@link #epochStart(SameDiff, At)}, after the
     * evaluations have been reset. No-op by default.
     */
    public void epochStartEvaluations(SameDiff sd, At at) {
    }

    /**
     * Delegates to {@link #epochEndEvaluations} with a record of the training evaluations
     * accumulated since the last epoch start.
     */
    @Override
    public final ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        return epochEndEvaluations(sd, at, lossCurve, epochTimeMillis, new EvaluationRecord(trainingEvaluations));
    }

    /**
     * Subclass hook for {@link #epochEnd}.
     *
     * @param evaluations the training evaluations accumulated since the last epoch start
     * @return {@link ListenerResponse#CONTINUE} by default
     */
    public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis,
                                                EvaluationRecord evaluations) {
        return ListenerResponse.CONTINUE;
    }

    /**
     * Delegates to {@link #validationDoneEvaluations} with a record of the validation evaluations.
     * Validation evaluations are only reset at epoch start.
     */
    @Override
    public final ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
        return validationDoneEvaluations(sd, at, validationTimeMillis, new EvaluationRecord(validationEvaluations));
    }

    /**
     * Subclass hook for {@link #validationDone}.
     *
     * @param evaluations the validation evaluations accumulated since the last epoch start
     * @return {@link ListenerResponse#CONTINUE} by default
     */
    public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis,
                                                      EvaluationRecord evaluations) {
        return ListenerResponse.CONTINUE;
    }

    /**
     * Feeds the activation of an evaluated variable into its evaluations.
     *
     * <p>During {@link Operation#TRAINING} the training evaluations are used, and during
     * {@link Operation#TRAINING_VALIDATION} the validation evaluations are used. Other operations
     * and non-evaluated variables are ignored. Afterwards,
     * {@link #activationAvailableEvaluations} is always invoked.
     *
     * @throws NullPointerException if a variable has evaluations but no configured label index
     */
    @Override
    public final void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
                                          INDArray activation) {
        Operation operation = at.operation();
        if (operation == Operation.TRAINING) {
            List<IEvaluation> evals = trainingEvaluations.get(varName);
            if (evals != null) {
                evalBatch(evals, evaluations().trainEvaluationLabels(), varName, batch, activation);
            }
        } else if (operation == Operation.TRAINING_VALIDATION) {
            List<IEvaluation> evals = validationEvaluations.get(varName);
            if (evals != null) {
                evalBatch(evals, evaluations().validationEvaluationLabels(), varName, batch, activation);
            }
        }
        activationAvailableEvaluations(sd, at, batch, op, varName, activation);
    }

    /**
     * Subclass hook invoked at the end of every {@link #activationAvailable} call. No-op by default.
     */
    public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                               String varName, INDArray activation) {
    }

    /** Creates a new map holding a fresh {@code newInstance()} copy of every template evaluation. */
    private static Map<String, List<IEvaluation>> freshInstances(Map<String, List<IEvaluation>> templates) {
        Map<String, List<IEvaluation>> copies = new HashMap<>();
        for (Map.Entry<String, List<IEvaluation>> entry : templates.entrySet()) {
            List<IEvaluation> evals = new ArrayList<>(entry.getValue().size());
            for (IEvaluation template : entry.getValue()) {
                evals.add(template.newInstance());
            }
            copies.put(entry.getKey(), evals);
        }
        return copies;
    }

    /** Evaluates one batch: looks up the variable's label index once and feeds all evaluations. */
    private static void evalBatch(List<IEvaluation> evals, Map<String, Integer> labelIndices, String varName,
                                  MultiDataSet batch, INDArray activation) {
        // Previously a missing entry failed with an anonymous NPE on unboxing; keep the NPE type
        // but make the cause explicit.
        int labelIdx = Objects.requireNonNull(labelIndices.get(varName),
                () -> "No label index configured for evaluated variable \"" + varName + "\"");
        INDArray labels = batch.getLabels(labelIdx);
        INDArray mask = batch.getLabelsMaskArray(labelIdx);
        for (IEvaluation evaluation : evals) {
            evaluation.eval(labels, activation, mask);
        }
    }
}
