package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.class */
/**
 * Wrapper for the native {@code lstmLayer} custom op.
 *
 * <p>Op arguments produced by this class:
 * <ul>
 *   <li>integer args: ordinals of data format, direction mode, gate activation, output activation
 *       and cell activation (in that order, see {@link #iArgs()});</li>
 *   <li>floating-point args: the cell clip value (see {@link #tArgs()});</li>
 *   <li>boolean args: hasBias, hasSeqLen, hasYLast, hasCLast, hasPeephole, retFullSequence,
 *       retLastH, retLastC (see {@code bArgs}).</li>
 * </ul>
 *
 * <p>Only the outputs enabled in the {@link LSTMLayerConfig} are produced; {@link #doDiff(List)}
 * consumes their gradients in the order full sequence, last h, last c.
 */
public class LSTMLayer extends DynamicCustomOp {
    private LSTMLayerConfig configuration;
    private LSTMLayerWeights weights;
    private SDVariable cLast;
    private SDVariable yLast;
    /** Variable name for {@link #cLast}, captured from properties and resolved in configureWithSameDiff. */
    private String cLastName;
    /** Variable name for {@link #yLast}, captured from properties and resolved in configureWithSameDiff. */
    private String yLastName;
    private SDVariable maxTSLength;

    public LSTMLayer() {
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff      graph that owns this op; must not be null
     * @param x             layer input
     * @param cLast         optional cell-state input ({@code null} if absent)
     * @param yLast         optional output-state input ({@code null} if absent)
     * @param maxTSLength   optional sequence-length input ({@code null} if absent)
     * @param weights       layer weights
     * @param configuration layer configuration; at least one output must be enabled
     * @throws NullPointerException  if {@code sameDiff} is null
     * @throws IllegalStateException if the configuration enables no output
     */
    public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
        // NOTE(review): inputs are passed here as (x, maxTSLength, cLast, yLast), while bArgs flags
        // yLast before cLast and the INDArray constructor passes maxTSLength first. Confirm against
        // LSTMLayerWeights.argsWithInputs and the native op's expected input order.
        super((String) null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        this.configuration = configuration;
        this.weights = weights;
        this.cLast = cLast;
        this.yLast = yLast;
        this.maxTSLength = maxTSLength;
        addIArgument(iArgs());
        addTArgument(tArgs());
        addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
        validateOutputsRequested();
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * @param x               layer input
     * @param cLast           optional cell-state input ({@code null} if absent)
     * @param yLast           optional output-state input ({@code null} if absent)
     * @param maxTSLength     optional sequence-length input ({@code null} if absent)
     * @param lstmWeights     layer weights
     * @param lstmLayerConfig layer configuration; at least one output must be enabled
     * @throws IllegalStateException if the configuration enables no output
     */
    public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig lstmLayerConfig) {
        // NOTE(review): this passes maxTSLength before x to argsWithInputs, unlike the SameDiff
        // constructor, which passes x first. Verify which order argsWithInputs expects.
        super((String) null, (INDArray[]) null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
        this.configuration = lstmLayerConfig;
        this.weights = lstmWeights;
        addIArgument(iArgs());
        addTArgument(tArgs());
        addBArgument(bArgs(this.weights, maxTSLength, yLast, cLast));
        validateOutputsRequested();
    }

    /** Fails fast when the configuration would make the op produce no outputs at all. */
    private void validateOutputsRequested() {
        Preconditions.checkState(
                this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
                "You have to specify at least one output you want to return. Use retLastC, retLastH and retFullSequence methods in LSTMLayerConfig builder to specify them");
    }

    /**
     * Every enabled output takes the data type of input 1, which must be floating point.
     *
     * @param dataTypes data types of the op inputs (between 3 and 8 entries)
     * @return one data type per enabled output
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && 3 <= dataTypes.size() && dataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", dataTypes);
        DataType dataType = dataTypes.get(1);
        Preconditions.checkState(dataType.isFPType(), "Input type 1 must be a floating point type, got %s", dataType);
        int outputCount = getNumOutputs();
        List<DataType> result = new ArrayList<>(outputCount);
        for (int i = 0; i < outputCount; i++) {
            result.add(dataType);
        }
        return result;
    }

    /**
     * Builds the backward op.
     *
     * @param gradients gradients of the enabled outputs, ordered full sequence, last h, last c
     * @return the output variables of the corresponding {@link LSTMLayerBp} op
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        int next = 0;
        SDVariable gradFullSequence = null;
        SDVariable gradLastH = null;
        SDVariable gradLastC = null;
        if (this.configuration.isRetFullSequence()) {
            gradFullSequence = gradients.get(next++);
        }
        if (this.configuration.isRetLastH()) {
            gradLastH = gradients.get(next++);
        }
        if (this.configuration.isRetLastC()) {
            gradLastC = gradients.get(next);
        }
        return Arrays.asList(new LSTMLayerBp(this.sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength, this.weights, this.configuration, gradFullSequence, gradLastH, gradLastC).outputVariables());
    }

    @Override
    public String opName() {
        return "lstmLayer";
    }

    /**
     * Exposes the configuration, plus the optional {@code cLast}/{@code yLast} variables, as properties.
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> properties = this.configuration.toProperties(true, true);
        // NOTE(review): these are stored as SDVariable instances, but setPropertiesForFunction reads
        // them back via getStringFromProperty. Confirm the serializer converts them to variable names.
        if (this.cLast != null) {
            properties.put("cLast", this.cLast);
        }
        if (this.yLast != null) {
            properties.put("yLast", this.yLast);
        }
        return properties;
    }

    /** Ordinals of data format, direction mode, gate, output and cell activation, in that order. */
    @Override
    public long[] iArgs() {
        return new long[]{this.configuration.getLstmdataformat().ordinal(), this.configuration.getDirectionMode().ordinal(), this.configuration.getGateAct().ordinal(), this.configuration.getOutAct().ordinal(), this.configuration.getCellAct().ordinal()};
    }

    /** The cell clip value. */
    @Override
    public double[] tArgs() {
        return new double[]{this.configuration.getCellClip()};
    }

    /**
     * Builds the eight boolean op arguments.
     *
     * @param weights     layer weights (bias / peephole presence)
     * @param maxTSLength optional sequence-length input
     * @param yLast       optional output-state input
     * @param cLast       optional cell-state input
     * @return {hasBias, hasSeqLen, hasYLast, hasCLast, hasPeephole, retFullSequence, retLastH, retLastC}
     */
    protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
        return new boolean[]{
                weights.hasBias(),
                maxTSLength != null,
                yLast != null,
                cLast != null,
                weights.hasPH(),
                this.configuration.isRetFullSequence(),
                this.configuration.isRetLastH(),
                this.configuration.isRetLastC()
        };
    }

    /**
     * Rebuilds {@link #configuration} from raw op arguments when no configuration is set yet.
     * Boolean indices follow the layout produced by {@code bArgs}.
     */
    @Override
    public void configureFromArguments() {
        if (this.configuration != null || this.bArguments.isEmpty() || this.iArguments.isEmpty() || this.tArguments.isEmpty()) {
            return;
        }
        LSTMLayerConfig.LSTMLayerConfigBuilder builder = LSTMLayerConfig.builder();
        builder.retFullSequence(this.bArguments.get(5).booleanValue());
        builder.retLastH(this.bArguments.get(6).booleanValue());
        // Index 7 is retLastC; index 4 is the peephole flag.
        builder.retLastC(this.bArguments.get(7).booleanValue());
        builder.cellClip(this.tArguments.get(0).doubleValue());
        builder.lstmdataformat(LSTMDataFormat.values()[this.iArguments.get(0).intValue()]);
        builder.directionMode(LSTMDirectionMode.values()[this.iArguments.get(1).intValue()]);
        builder.gateAct(LSTMActivations.values()[this.iArguments.get(2).intValue()]);
        builder.outAct(LSTMActivations.values()[this.iArguments.get(3).intValue()]);
        builder.cellAct(LSTMActivations.values()[this.iArguments.get(4).intValue()]);
        this.configuration = builder.build();
    }

    /**
     * Binds this op to {@code sameDiff}, rebuilds {@link #weights} from the op's input variables,
     * and resolves {@code cLast}/{@code yLast} from any names captured by setPropertiesForFunction.
     *
     * <p>Input positions used: 1 = input weights, 2 = recurrent weights, 3 = bias (when boolean
     * arg 0 is set), last = peephole weights (when boolean arg 4 is set).
     */
    @Override
    public void configureWithSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
        String[] inputNames = sameDiff.getInputsForOp(this);
        boolean hasBias = this.bArguments.get(0).booleanValue();
        boolean hasPeephole = this.bArguments.get(4).booleanValue();

        LSTMLayerWeights.LSTMLayerWeightsBuilder builder = LSTMLayerWeights.builder();
        if (inputNames.length > 1) {
            builder.weights(sameDiff.getVariable(inputNames[1]));
        }
        if (inputNames.length > 2) {
            builder.rWeights(sameDiff.getVariable(inputNames[2]));
        }
        if (hasBias) {
            builder.bias(sameDiff.getVariable(inputNames[3]));
        }
        if (hasPeephole) {
            builder.peepholeWeights(sameDiff.getVariable(inputNames[inputNames.length - 1]));
        }
        this.weights = builder.build();

        // NOTE(review): maxTSLength is not restored here, so doDiff passes null for it on an op
        // configured this way. Confirm whether that is intended.
        if (this.yLastName != null) {
            this.yLast = sameDiff.getVariable(this.yLastName);
        }
        if (this.cLastName != null) {
            this.cLast = sameDiff.getVariable(this.cLastName);
        }
    }

    /**
     * Rebuilds {@link #configuration} from a property map when no configuration is set yet.
     * Also records the {@code cLast}/{@code yLast} variable names for later resolution.
     */
    @Override
    public void setPropertiesForFunction(Map<String, Object> map) {
        if (this.configuration != null) {
            return;
        }
        LSTMLayerConfig.LSTMLayerConfigBuilder builder = LSTMLayerConfig.builder();
        Boolean retFullSequence = getBooleanFromProperty("retFullSequence", map);
        if (retFullSequence != null) {
            builder.retFullSequence(retFullSequence.booleanValue());
        }
        String outAct = getStringFromProperty("outAct", map);
        if (outAct != null) {
            builder.outAct(LSTMActivations.valueOf(outAct));
        }
        String directionMode = getStringFromProperty("directionMode", map);
        if (directionMode != null) {
            builder.directionMode(LSTMDirectionMode.valueOf(directionMode));
        }
        Double cellClip = getDoubleValueFromProperty("cellClip", map);
        if (cellClip != null) {
            builder.cellClip(cellClip.doubleValue());
        }
        String cellAct = getStringFromProperty("cellAct", map);
        if (cellAct != null) {
            builder.cellAct(LSTMActivations.valueOf(cellAct));
        }
        Boolean retLastC = getBooleanFromProperty("retLastC", map);
        if (retLastC != null) {
            builder.retLastC(retLastC.booleanValue());
        }
        Boolean retLastH = getBooleanFromProperty("retLastH", map);
        if (retLastH != null) {
            builder.retLastH(retLastH.booleanValue());
        }
        String gateAct = getStringFromProperty("gateAct", map);
        if (gateAct != null) {
            builder.gateAct(LSTMActivations.valueOf(gateAct));
        }
        String dataFormat = getStringFromProperty("lstmdataformat", map);
        if (dataFormat != null) {
            builder.lstmdataformat(LSTMDataFormat.valueOf(dataFormat));
        }
        String cLastVarName = getStringFromProperty("cLast", map);
        if (cLastVarName != null) {
            this.cLastName = cLastVarName;
        }
        // Read the "yLast" key, the same key propertiesForFunction writes.
        String yLastVarName = getStringFromProperty("yLast", map);
        if (yLastVarName != null) {
            this.yLastName = yLastVarName;
        }
        this.configuration = builder.build();
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "configuration";
    }

    /** Number of outputs enabled in the configuration (full sequence, last h, last c). */
    @Override
    public int getNumOutputs() {
        return Booleans.countTrue(new boolean[]{this.configuration.isRetFullSequence(), this.configuration.isRetLastH(), this.configuration.isRetLastC()});
    }

    public LSTMLayerConfig getConfiguration() {
        return this.configuration;
    }

    public LSTMLayerWeights getWeights() {
        return this.weights;
    }
}
