package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/MultiNormalizerStandardize.class */
/**
 * Normalizer for multi-input / multi-output data sets that standardizes each feature (and,
 * optionally, each label) array using its own {@link DistributionStats}.
 *
 * <p>The transform itself is delegated to {@link StandardizeStrategy}. This class adds
 * accessors for the fitted per-array mean and standard deviation, and file-based persistence
 * of those statistics.
 */
public class MultiNormalizerStandardize extends AbstractMultiDataSetNormalizer<DistributionStats> {

    /** Creates a normalizer backed by a {@link StandardizeStrategy}. */
    public MultiNormalizerStandardize() {
        super(new StandardizeStrategy());
    }

    /** Returns a builder that accumulates {@link DistributionStats} while fitting. */
    @Override
    protected NormalizerStats.Builder newBuilder() {
        return new DistributionStats.Builder();
    }

    /**
     * Returns the mean of the feature array at index {@code i}.
     *
     * @param i index of the feature array
     * @return the fitted mean for that feature array
     */
    public INDArray getFeatureMean(int i) {
        return getFeatureStats(i).getMean();
    }

    /**
     * Returns the mean of the label array at index {@code i}.
     *
     * @param i index of the label array
     * @return the fitted mean for that label array
     */
    public INDArray getLabelMean(int i) {
        return getLabelStats(i).getMean();
    }

    /**
     * Returns the standard deviation of the feature array at index {@code i}.
     *
     * @param i index of the feature array
     * @return the fitted standard deviation for that feature array
     */
    public INDArray getFeatureStd(int i) {
        return getFeatureStats(i).getStd();
    }

    /**
     * Returns the standard deviation of the label array at index {@code i}.
     *
     * @param i index of the label array
     * @return the fitted standard deviation for that label array
     */
    public INDArray getLabelStd(int i) {
        return getLabelStats(i).getStd();
    }

    /**
     * Restores previously saved statistics from disk.
     *
     * <p>Each list holds two files per array, stored consecutively. Files {@code 2k} and
     * {@code 2k + 1} are handed to {@code DistributionStats.load} for array {@code k}
     * (presumably mean file then std file; confirm against {@code DistributionStats}).
     * Label statistics are only loaded when {@link #isFitLabel()} is {@code true}. Otherwise
     * {@code labelFiles} is not read, but it must still be non-null.
     *
     * @param featureFiles file pairs for each feature array
     * @param labelFiles file pairs for each label array
     * @throws IOException if a statistics file cannot be read
     * @throws IllegalArgumentException if a list that is read contains an odd number of files
     * @throws NullPointerException if either list is {@code null}
     */
    public void load(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles is marked non-null but is null");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles is marked non-null but is null");
        }
        setFeatureStats(loadStats(featureFiles));
        if (isFitLabel()) {
            setLabelStats(loadStats(labelFiles));
        }
    }

    /**
     * Loads one {@link DistributionStats} per consecutive pair of files.
     *
     * @param files file pairs; must have even length
     * @return the loaded statistics, one entry per pair, in order
     * @throws IOException if a file cannot be read
     * @throws IllegalArgumentException if {@code files} has an odd number of elements
     */
    private List<DistributionStats> loadStats(List<File> files) throws IOException {
        // A trailing unpaired file would otherwise be silently ignored; reject it, just as
        // saveStats rejects a file count that does not match the number of statistics.
        if (files.size() % 2 != 0) {
            throw new IllegalArgumentException(String.format(
                    "Need an even number of files (two per input / output), got %d", files.size()));
        }
        int pairCount = files.size() / 2;
        List<DistributionStats> stats = new ArrayList<>(pairCount);
        for (int i = 0; i < pairCount; i++) {
            stats.add(DistributionStats.load(files.get(2 * i), files.get(2 * i + 1)));
        }
        return stats;
    }

    /**
     * Writes the fitted statistics to disk, two files per array.
     *
     * <p>Label statistics are only written when {@link #isFitLabel()} is {@code true}.
     *
     * @param featureFiles destination file pairs for the feature statistics; the size must be
     *     exactly twice the number of feature arrays
     * @param labelFiles destination file pairs for the label statistics; the size must be
     *     exactly twice the number of label arrays (checked only when labels are fitted)
     * @throws IOException if a file cannot be written
     * @throws IllegalArgumentException if a file list has the wrong size
     * @throws NullPointerException if either list is {@code null}
     */
    public void save(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles is marked non-null but is null");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles is marked non-null but is null");
        }
        saveStats(getFeatureStats(), featureFiles);
        if (isFitLabel()) {
            saveStats(getLabelStats(), labelFiles);
        }
    }

    /**
     * Saves {@code stats.get(i)} to files {@code 2i} and {@code 2i + 1} of {@code files}.
     *
     * @throws IllegalArgumentException if {@code files.size() != 2 * stats.size()}
     */
    private void saveStats(List<DistributionStats> stats, List<File> files) throws IOException {
        int requiredFiles = stats.size() * 2;
        if (files.size() != requiredFiles) {
            throw new IllegalArgumentException(String.format(
                    "Need twice as many files as inputs / outputs (%d), got %d", requiredFiles, files.size()));
        }
        for (int i = 0; i < stats.size(); i++) {
            stats.get(i).save(files.get(2 * i), files.get(2 * i + 1));
        }
    }

    /** Identifies this normalizer as {@link NormalizerType#MULTI_STANDARDIZE}. */
    @Override
    public NormalizerType getType() {
        return NormalizerType.MULTI_STANDARDIZE;
    }

    /**
     * Equality is decided by the superclass state. The {@code canEqual} handshake keeps the
     * comparison symmetric with subclasses.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiNormalizerStandardize)) {
            return false;
        }
        return ((MultiNormalizerStandardize) obj).canEqual(this) && super.equals(obj);
    }

    /** Returns whether {@code obj} may be considered equal to an instance of this class. */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof MultiNormalizerStandardize;
    }

    /** Delegates to the superclass, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return super.hashCode();
    }
}
