package org.nd4j.linalg.api.ops.impl.transforms.custom;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.shade.guava.primitives.Ints;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.class */
/**
 * Op wrapper for the native {@code layer_norm} custom op.
 *
 * <p>Inputs are {@code input}, {@code gain} and an optional {@code bias}. The op is configured
 * with a non-empty set of dimensions, passed as integer arguments. It also takes a
 * {@code channelsFirst} flag, passed as boolean argument 0. Gradients are produced by
 * {@link LayerNormBp}.
 */
public class LayerNorm extends DynamicCustomOp {
    /** True when the op was constructed without a bias input. */
    private boolean noBias = false;
    /** Layout flag forwarded to the native op as boolean argument 0. */
    private boolean channelsFirst;

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff    owning graph; must not be null
     * @param sDVariable  input variable; must not be null
     * @param sDVariable2 gain variable; must not be null
     * @param sDVariable3 optional bias variable; {@code null} means the op has no bias input
     * @param z           channelsFirst flag forwarded to the native op
     * @param iArr        dimensions forwarded to the native op; must be non-null and non-empty
     * @throws NullPointerException     if {@code sameDiff}, the input or the gain is null
     * @throws IllegalArgumentException if {@code iArr} is null or empty
     */
    public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable sDVariable, @NonNull SDVariable sDVariable2, SDVariable sDVariable3, boolean z, int... iArr) {
        // A null bias is dropped from the input list, so the op ends up with 2 or 3 inputs.
        super(null, sameDiff, (SDVariable[]) wrapFilterNull(sDVariable, sDVariable2, sDVariable3), false);
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        if (sDVariable == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        if (sDVariable2 == null) {
            throw new NullPointerException("gain is marked non-null but is null");
        }
        this.noBias = sDVariable3 == null;
        // channelsFirst must be set before setDimensions(), which copies it into bArguments.
        this.channelsFirst = z;
        setDimensions(iArr);
    }

    /**
     * Creates the op inside a SameDiff graph without a bias input.
     *
     * @param sameDiff    owning graph
     * @param sDVariable  input variable
     * @param sDVariable2 gain variable
     * @param z           channelsFirst flag forwarded to the native op
     * @param iArr        dimensions forwarded to the native op
     */
    public LayerNorm(SameDiff sameDiff, SDVariable sDVariable, SDVariable sDVariable2, boolean z, int... iArr) {
        this(sameDiff, sDVariable, sDVariable2, (SDVariable) null, z, iArr);
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * @param iNDArray  input array
     * @param iNDArray2 gain array
     * @param iNDArray3 optional bias array; {@code null} means the op has no bias input
     * @param iNDArray4 optional pre-allocated output array; may be {@code null}
     * @param z         channelsFirst flag forwarded to the native op
     * @param iArr      dimensions forwarded to the native op; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code iArr} is null or empty
     */
    public LayerNorm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z, int... iArr) {
        super("layer_norm", (INDArray[]) wrapFilterNull(iNDArray, iNDArray2, iNDArray3), wrapOrNull(iNDArray4));
        this.noBias = iNDArray3 == null;
        this.channelsFirst = z;
        setDimensions(iArr);
    }

    /**
     * Creates the op for direct execution on arrays, with no bias and no explicit output array.
     *
     * @param iNDArray  input array; must not be null
     * @param iNDArray2 gain array; must not be null
     * @param z         channelsFirst flag forwarded to the native op
     * @param iArr      dimensions forwarded to the native op
     * @throws NullPointerException if the input or the gain is null
     */
    public LayerNorm(@NonNull INDArray iNDArray, @NonNull INDArray iNDArray2, boolean z, int... iArr) {
        // The null here is the *result* argument of the 5-arg constructor below.
        this(iNDArray, iNDArray2, (INDArray) null, z, iArr);
        // Java requires this(...) to come first, so the checks can only run after delegation.
        if (iNDArray == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        if (iNDArray2 == null) {
            throw new NullPointerException("gain is marked non-null but is null");
        }
    }

    /**
     * Creates the op for direct execution on arrays, with no bias.
     *
     * @param iNDArray  input array
     * @param iNDArray2 gain array
     * @param iNDArray3 optional pre-allocated output array; may be {@code null}
     * @param z         channelsFirst flag forwarded to the native op
     * @param iArr      dimensions forwarded to the native op
     */
    public LayerNorm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, boolean z, int... iArr) {
        this(iNDArray, iNDArray2, (INDArray) null, iNDArray3, z, iArr);
    }

    /** No-arg constructor, used when the op is created reflectively and configured afterwards. */
    public LayerNorm() {
    }

    /**
     * Stores the dimensions and rebuilds the native arguments from them.
     *
     * <p>The integer arguments are replaced by {@code iArr}. The boolean arguments are replaced
     * by the single current {@code channelsFirst} value.
     *
     * @param iArr dimensions; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code iArr} is null or empty
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.ReduceOp
    public void setDimensions(int[] iArr) {
        Preconditions.checkArgument(iArr != null, "LayerNorm: You have to provide dimensions");
        Preconditions.checkArgument(iArr.length > 0, "LayerNorm: You have to provide dimensions");
        this.dimensions = iArr;
        this.iArguments.clear();
        addIArgument(iArr);
        this.bArguments.clear();
        this.bArguments.add(Boolean.valueOf(this.channelsFirst));
    }

    /** Pass-through to {@link DynamicCustomOp#addBArgument(boolean...)}. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public void addBArgument(boolean... zArr) {
        super.addBArgument(zArr);
    }

    /**
     * Returns the op's configuration as a property map.
     *
     * @return map with {@code noBias}, {@code channelsFirst} and, when set, {@code dimensions}
     *         (as the stored {@code int[]})
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> props = new HashMap<>();
        props.put("noBias", Boolean.valueOf(this.noBias));
        props.put("channelsFirst", Boolean.valueOf(this.channelsFirst));
        if (this.dimensions != null) {
            props.put("dimensions", this.dimensions);
        }
        return props;
    }

    /**
     * Restores fields from the raw op arguments.
     *
     * <p>Boolean argument 0 becomes {@code channelsFirst}. Boolean argument 1, if present, becomes
     * {@code noBias}. Non-empty integer arguments become {@code dimensions}.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public void configureFromArguments() {
        if (this.bArguments.size() > 1) {
            this.noBias = this.bArguments.get(1).booleanValue();
        }
        if (!this.bArguments.isEmpty()) {
            this.channelsFirst = this.bArguments.get(0).booleanValue();
        }
        if (!this.iArguments.isEmpty()) {
            this.dimensions = Ints.toArray(this.iArguments);
        }
    }

    /**
     * Restores fields from a property map, such as one produced by {@link #propertiesForFunction()}.
     *
     * <p>{@code dimensions} is accepted as an {@code int[]}, a {@code long[]}, or a single
     * {@code Long}/{@code Integer}. Any other type is ignored. This method does not touch the
     * native integer/boolean arguments.
     *
     * @param map property map; absent keys leave the corresponding field unchanged
     * @throws IllegalArgumentException if a {@code long[]} dimension does not fit in an {@code int}
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void setPropertiesForFunction(Map<String, Object> map) {
        Boolean noBiasProp = getBooleanFromProperty("noBias", map);
        if (noBiasProp != null) {
            this.noBias = noBiasProp.booleanValue();
        }
        Boolean channelsFirstProp = getBooleanFromProperty("channelsFirst", map);
        if (channelsFirstProp != null) {
            this.channelsFirst = channelsFirstProp.booleanValue();
        }
        Object dims = map.get("dimensions");
        if (dims instanceof Long || dims instanceof Integer) {
            // A single scalar dimension; intValue() matches the previous Long handling.
            this.dimensions = new int[]{((Number) dims).intValue()};
        } else if (dims instanceof int[]) {
            this.dimensions = (int[]) dims;
        } else if (dims instanceof long[]) {
            long[] longDims = (long[]) dims;
            int[] converted = new int[longDims.length];
            for (int i = 0; i < longDims.length; i++) {
                // Fail loudly rather than silently truncating an out-of-range dimension.
                converted[i] = Ints.checkedCast(longDims[i]);
            }
            this.dimensions = converted;
        }
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "layer_norm";
    }

    /** Always throws: this op has no TensorFlow mapping. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow name found for op " + opName());
    }

    /** Always throws: this op has no ONNX mapping. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx name found for op " + opName());
    }

    /**
     * Builds the backward op and returns its outputs.
     *
     * <p>The backward op is a {@link LayerNormBp} built from the forward inputs (bias only when
     * present), the incoming gradient, {@code channelsFirst} and {@code dimensions}.
     *
     * @param list incoming gradients; only element 0 is used
     * @return outputs of the {@link LayerNormBp} op
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        SDVariable grad = list.get(0);
        LayerNormBp bp;
        if (this.noBias) {
            bp = new LayerNormBp(this.sameDiff, arg(0), arg(1), grad, this.channelsFirst, this.dimensions);
        } else {
            bp = new LayerNormBp(this.sameDiff, arg(0), arg(1), arg(2), grad, this.channelsFirst, this.dimensions);
        }
        return bp.outputs();
    }

    /**
     * Validates the input datatypes and returns the single output datatype.
     *
     * @param list input datatypes; 2 or 3 entries, all floating point and all identical
     * @return a one-element list holding the first input's datatype
     * @throws IllegalStateException if the count, floating-point or same-type check fails
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() >= 2 && list.size() <= 3, "Expected exactly 2 or 3 input datatypes, got %s", list);
        DataType first = list.get(0);
        for (DataType dt : list) {
            Preconditions.checkState(dt.isFPType(), "Input datatype %s must be a floating point type, got input datatypes %s", dt, list);
            Preconditions.checkState(first == dt, "All datatypes must be same type, got input datatypes %s", list);
        }
        return Collections.singletonList(first);
    }

    /**
     * Returns 2 without a bias input and 3 with one.
     *
     * <p>NOTE(review): this matches the input count, not the single output declared by
     * {@link #calculateOutputDataTypes(List)}. Kept unchanged because callers outside this file
     * may depend on it; confirm whether it should instead delegate to the superclass.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public int numOutputArguments() {
        return this.noBias ? 2 : 3;
    }
}
