package org.nd4j.autodiff.listeners.records;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/autodiff/listeners/records/LossCurve.class */
/**
 * A history of loss values, one row per recorded {@link Loss}.
 *
 * <p>{@link #getLossValues()} is a rank-2 FLOAT array: row {@code r} holds the values of the
 * r-th recorded loss and column {@code c} corresponds to {@code getLossNames().get(c)}. Methods
 * taking an integer row index accept negative values, which count back from the most recent row
 * ({@code -1} is the last row, {@code -2} the one before it).
 */
public class LossCurve {
    private List<String> lossNames;
    private INDArray lossValues;

    /**
     * Builds a curve from a sequence of losses, one row per element, in list order.
     *
     * @param list losses to record; every element must have the same loss names and the same
     *     number of values as the first element
     * @throws IllegalArgumentException if {@code list} is empty, or if any element's names or
     *     value count differ from the first element's
     */
    public LossCurve(List<Loss> list) {
        Preconditions.checkArgument(!list.isEmpty(), "Cannot create a LossCurve from an empty list of losses");
        Loss first = list.get(0);
        this.lossNames = Collections.unmodifiableList(first.getLossNames());
        int length = first.lossValues().length;
        this.lossValues = Nd4j.create(DataType.FLOAT, list.size(), length);
        for (int i = 0; i < list.size(); i++) {
            Loss loss = list.get(i);
            Preconditions.checkArgument(loss.getLossNames().equals(this.lossNames), "Loss names for loss %s differ from others.  Expected %s, got %s", Integer.valueOf(i), this.lossNames, loss.getLossNames());
            Preconditions.checkArgument(loss.getLosses().length == length, "Number of loss values for loss %s differ from others.  Expected %s, got %s", i, length, loss.getLosses().length);
            this.lossValues = this.lossValues.putRow(i, Nd4j.createFromArray(loss.getLosses()).castTo(DataType.FLOAT));
        }
    }

    /**
     * Builds a single-row curve.
     *
     * @param dArr the loss values of the only row; stored as FLOAT
     * @param list the loss names, one per value (stored as given, not copied)
     */
    public LossCurve(double[] dArr, List<String> list) {
        // Wrap the vector as a 1 x n matrix so the curve is always rank 2.
        this.lossValues = Nd4j.createFromArray(new double[][]{dArr}).castTo(DataType.FLOAT);
        this.lossNames = list;
    }

    /**
     * Wraps an existing value matrix without copying it.
     *
     * @param iNDArray rank-2 FLOAT array, one row per recorded loss
     * @param list the loss names, one per column (stored as given, not copied)
     * @throws IllegalArgumentException if {@code iNDArray} is not rank 2 or not FLOAT
     */
    protected LossCurve(INDArray iNDArray, List<String> list) {
        Preconditions.checkArgument(iNDArray.rank() == 2, "lossValues must have a rank of 2, got %s", iNDArray.rank());
        Preconditions.checkArgument(iNDArray.dataType() == DataType.FLOAT, "lossValues must be type FLOAT, got %s", iNDArray.dataType());
        this.lossValues = iNDArray;
        this.lossNames = list;
    }

    /**
     * Returns every recorded row as a {@link Loss}, oldest first.
     *
     * @return a new mutable list with one {@code Loss} per row
     */
    public List<Loss> losses() {
        int rows = (int) this.lossValues.size(0);
        List<Loss> result = new ArrayList<>(rows);
        for (int r = 0; r < rows; r++) {
            result.add(new Loss(this.lossNames, this.lossValues.getRow(r).toDoubleVector()));
        }
        return result;
    }

    /**
     * Returns the loss recorded at row {@code i}.
     *
     * @param i row index; negative values count back from the last row
     * @return the values of that row, paired with this curve's loss names
     */
    public Loss meanLoss(int i) {
        return new Loss(this.lossNames, this.lossValues.getRow(resolveRow(i)).toDoubleVector());
    }

    /** Returns the most recently recorded loss (row {@code -1}). */
    public Loss lastMeanLoss() {
        return meanLoss(-1);
    }

    /**
     * Returns the full history of a single named loss.
     *
     * @param str the loss name
     * @return one value per recorded row, oldest first
     * @throws IllegalArgumentException if this curve has no loss named {@code str}
     */
    public float[] meanLoss(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("lossName is marked non-null but is null");
        }
        int column = lossIndex(str);
        int rows = (int) this.lossValues.size(0);
        float[] history = new float[rows];
        for (int r = 0; r < rows; r++) {
            history[r] = this.lossValues.getFloat(r, column);
        }
        return history;
    }

    /** Same as {@link #meanLoss(String)}, using the variable's name. */
    public float[] meanLoss(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("loss is marked non-null but is null");
        }
        return meanLoss(sDVariable.name());
    }

    /**
     * Returns a single named loss value at row {@code i}.
     *
     * @param str the loss name
     * @param i row index; negative values count back from the last row
     * @throws IllegalArgumentException if this curve has no loss named {@code str}
     */
    public float meanLoss(@NonNull String str, int i) {
        if (str == null) {
            throw new NullPointerException("lossName is marked non-null but is null");
        }
        int column = lossIndex(str);
        return this.lossValues.getFloat(resolveRow(i), column);
    }

    /** Same as {@link #meanLoss(String, int)}, using the variable's name. */
    public float meanLoss(@NonNull SDVariable sDVariable, int i) {
        if (sDVariable == null) {
            throw new NullPointerException("loss is marked non-null but is null");
        }
        return meanLoss(sDVariable.name(), i);
    }

    /**
     * Returns the most recent value of a named loss.
     *
     * @throws IllegalArgumentException if this curve has no loss named {@code str}
     */
    public float lastMeanLoss(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("lossName is marked non-null but is null");
        }
        int column = lossIndex(str);
        return this.lossValues.getFloat(this.lossValues.rows() - 1, column);
    }

    /** Same as {@link #lastMeanLoss(String)}, using the variable's name. */
    public float lastMeanLoss(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("loss is marked non-null but is null");
        }
        return lastMeanLoss(sDVariable.name());
    }

    /**
     * Returns the last recorded loss minus the one before it.
     *
     * @throws IllegalStateException if fewer than two rows have been recorded
     */
    public Loss lastMeanDelta() {
        // A delta needs two rows; without this, meanLoss(-2) would request row -1.
        Preconditions.checkState(this.lossValues.rows() >= 2, "Need at least 2 recorded losses to compute a delta, got %s", this.lossValues.rows());
        return lastMeanLoss().sub(meanLoss(-2));
    }

    /** Returns the last-minus-previous delta for a single named loss. */
    public double lastMeanDelta(String str) {
        return lastMeanDelta().getLoss(str);
    }

    /** Same as {@link #lastMeanDelta(String)}, using the variable's name. */
    public double lastMeanDelta(SDVariable sDVariable) {
        return lastMeanDelta(sDVariable.name());
    }

    /**
     * Returns a new curve with {@code loss} appended as the last row; this curve is unchanged.
     */
    public LossCurve addLossAndCopy(Loss loss) {
        return addLossAndCopy(loss.getLosses(), loss.lossNames());
    }

    /**
     * Returns a new curve with {@code dArr} appended as the last row; this curve is unchanged.
     *
     * @param dArr the values to append, one per loss name
     * @param list the names of those values; must equal this curve's loss names
     * @throws IllegalArgumentException if the names or the number of values do not match
     */
    public LossCurve addLossAndCopy(double[] dArr, List<String> list) {
        // Appending under different names would silently relabel every historical column.
        Preconditions.checkArgument(list.equals(this.lossNames), "Loss names differ from this curve's.  Expected %s, got %s", this.lossNames, list);
        Preconditions.checkArgument(dArr.length == this.lossValues.size(1), "Number of loss values differs from this curve's.  Expected %s, got %s", this.lossValues.size(1), dArr.length);
        INDArray newRow = Nd4j.createFromArray(new double[][]{dArr}).castTo(DataType.FLOAT);
        return new LossCurve(Nd4j.concat(0, this.lossValues, newRow), list);
    }

    /** Returns the loss names, one per column of {@link #getLossValues()}. */
    public List<String> getLossNames() {
        return this.lossNames;
    }

    /** Returns the backing rank-2 FLOAT array (not a copy). */
    public INDArray getLossValues() {
        return this.lossValues;
    }

    /** Maps a possibly-negative row index to an absolute one (-1 is the last row). */
    private int resolveRow(int i) {
        return i >= 0 ? i : this.lossValues.rows() + i;
    }

    /** Returns the column of the named loss, failing with the list of known names if absent. */
    private int lossIndex(String name) {
        int index = this.lossNames.indexOf(name);
        Preconditions.checkArgument(index >= 0, "No loss value for %s.  Existing losses: %s", name, this.lossNames);
        return index;
    }
}
