package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Objects;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.class */
/**
 * Weight container for an LSTM layer op.
 *
 * <p>Every logical weight can be supplied in two forms: a SameDiff {@link SDVariable} and an
 * {@link INDArray}. Any field left {@code null} is treated as absent. {@link #args()} and
 * {@link #arrayArgs()} return only the non-null entries of each form. {@link #hasBias()} and
 * {@link #hasPH()} report whether the bias or peephole weights were supplied in either form.
 */
public class LSTMLayerWeights extends RNNWeights {
    /** Weights, SameDiff form. */
    private SDVariable weights;
    /** Weights, INDArray form. */
    private INDArray iWeights;
    /** Second weight set (presumably the recurrent weights — confirm against the op), SameDiff form. */
    private SDVariable rWeights;
    /** Second weight set, INDArray form. */
    private INDArray irWeights;
    /** Optional peephole weights, SameDiff form. */
    private SDVariable peepholeWeights;
    /** Optional peephole weights, INDArray form. */
    private INDArray iPeepholeWeights;
    /** Optional bias, SameDiff form. */
    private SDVariable bias;
    /** Optional bias, INDArray form. */
    private INDArray iBias;

    /**
     * Fluent builder for {@link LSTMLayerWeights}. Obtain one via {@link LSTMLayerWeights#builder()}.
     * Any field that is never set stays {@code null} in the built instance.
     */
    public static class LSTMLayerWeightsBuilder {
        private SDVariable weights;
        private INDArray iWeights;
        private SDVariable rWeights;
        private INDArray irWeights;
        private SDVariable peepholeWeights;
        private INDArray iPeepholeWeights;
        private SDVariable bias;
        private INDArray iBias;

        LSTMLayerWeightsBuilder() {
        }

        /** Sets the SameDiff form of the weights. */
        public LSTMLayerWeightsBuilder weights(SDVariable weights) {
            this.weights = weights;
            return this;
        }

        /** Sets the INDArray form of the weights. */
        public LSTMLayerWeightsBuilder iWeights(INDArray iWeights) {
            this.iWeights = iWeights;
            return this;
        }

        /** Sets the SameDiff form of the second ("r") weight set. */
        public LSTMLayerWeightsBuilder rWeights(SDVariable rWeights) {
            this.rWeights = rWeights;
            return this;
        }

        /** Sets the INDArray form of the second ("r") weight set. */
        public LSTMLayerWeightsBuilder irWeights(INDArray irWeights) {
            this.irWeights = irWeights;
            return this;
        }

        /** Sets the SameDiff form of the peephole weights. */
        public LSTMLayerWeightsBuilder peepholeWeights(SDVariable peepholeWeights) {
            this.peepholeWeights = peepholeWeights;
            return this;
        }

        /** Sets the INDArray form of the peephole weights. */
        public LSTMLayerWeightsBuilder iPeepholeWeights(INDArray iPeepholeWeights) {
            this.iPeepholeWeights = iPeepholeWeights;
            return this;
        }

        /** Sets the SameDiff form of the bias. */
        public LSTMLayerWeightsBuilder bias(SDVariable bias) {
            this.bias = bias;
            return this;
        }

        /** Sets the INDArray form of the bias. */
        public LSTMLayerWeightsBuilder iBias(INDArray iBias) {
            this.iBias = iBias;
            return this;
        }

        /** Creates a new {@link LSTMLayerWeights} from the values set so far. */
        public LSTMLayerWeights build() {
            return new LSTMLayerWeights(weights, iWeights, rWeights, irWeights,
                    peepholeWeights, iPeepholeWeights, bias, iBias);
        }

        @Override
        public String toString() {
            return "LSTMLayerWeights.LSTMLayerWeightsBuilder(weights=" + this.weights + ", iWeights=" + this.iWeights + ", rWeights=" + this.rWeights + ", irWeights=" + this.irWeights + ", peepholeWeights=" + this.peepholeWeights + ", iPeepholeWeights=" + this.iPeepholeWeights + ", bias=" + this.bias + ", iBias=" + this.iBias + ")";
        }
    }

    /**
     * Returns the SameDiff-form weights in the order weights, rWeights, peepholeWeights, bias.
     * The list is passed through {@code filterNonNull} (presumably dropping unset entries).
     */
    @Override
    public SDVariable[] args() {
        return (SDVariable[]) filterNonNull(weights, rWeights, peepholeWeights, bias);
    }

    /**
     * Returns the INDArray-form weights in the order iWeights, irWeights, iPeepholeWeights, iBias.
     * The list is passed through {@code filterNonNull} (presumably dropping unset entries).
     */
    @Override
    public INDArray[] arrayArgs() {
        return (INDArray[]) filterNonNull(iWeights, irWeights, iPeepholeWeights, iBias);
    }

    /**
     * Interleaves four caller-supplied inputs with the SameDiff-form weights.
     *
     * <p>Resulting order: inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3],
     * peepholeWeights. The array is then passed through {@code ArrayUtil.filterNull}. The meaning
     * of the four inputs is defined by the LSTM op, not here. NOTE(review): presumably the layer
     * input, sequence lengths and initial states; confirm against the op.
     *
     * @param inputs exactly four inputs
     * @return the combined argument array
     * @throws IllegalArgumentException if {@code inputs.length != 4}
     */
    @Override
    public SDVariable[] argsWithInputs(SDVariable... inputs) {
        Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length);
        return (SDVariable[]) ArrayUtil.filterNull(new SDVariable[]{
                inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3], peepholeWeights});
    }

    /**
     * INDArray counterpart of {@link #argsWithInputs(SDVariable...)}.
     *
     * <p>Resulting order: inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3],
     * iPeepholeWeights, passed through {@code ArrayUtil.filterNull}.
     *
     * @param inputs exactly four inputs
     * @return the combined argument array
     * @throws IllegalArgumentException if {@code inputs.length != 4}
     */
    @Override
    public INDArray[] argsWithInputs(INDArray... inputs) {
        Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length);
        return (INDArray[]) ArrayUtil.filterNull(new INDArray[]{
                inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3], iPeepholeWeights});
    }

    /** Returns {@code true} if a bias was supplied in either SameDiff or INDArray form. */
    public boolean hasBias() {
        return bias != null || iBias != null;
    }

    /** Returns {@code true} if peephole weights were supplied in either SameDiff or INDArray form. */
    public boolean hasPH() {
        return peepholeWeights != null || iPeepholeWeights != null;
    }

    LSTMLayerWeights(SDVariable weights, INDArray iWeights, SDVariable rWeights, INDArray irWeights,
                     SDVariable peepholeWeights, INDArray iPeepholeWeights, SDVariable bias, INDArray iBias) {
        this.weights = weights;
        this.iWeights = iWeights;
        this.rWeights = rWeights;
        this.irWeights = irWeights;
        this.peepholeWeights = peepholeWeights;
        this.iPeepholeWeights = iPeepholeWeights;
        this.bias = bias;
        this.iBias = iIBiasOrSelf(iBias);
    }

    /** Identity helper kept private; exists only so the constructor reads uniformly. */
    private static INDArray iIBiasOrSelf(INDArray value) {
        return value;
    }

    /** Returns a new, empty {@link LSTMLayerWeightsBuilder}. */
    public static LSTMLayerWeightsBuilder builder() {
        return new LSTMLayerWeightsBuilder();
    }

    /**
     * Compares this object with {@code obj}: the superclass equality first, then all eight weight
     * fields via {@link Objects#equals(Object, Object)}.
     *
     * <p>NOTE(review): {@code super.equals}/{@code super.hashCode} come from {@code RNNWeights},
     * which is not visible here. If it inherits {@code Object}'s identity-based versions,
     * {@code super.equals(obj)} is {@code false} for every {@code obj != this}. In that case the
     * field comparisons are never reached and equality is effectively identity. TODO confirm
     * whether value equality was intended.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LSTMLayerWeights)) {
            return false;
        }
        LSTMLayerWeights other = (LSTMLayerWeights) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getWeights(), other.getWeights())
                && Objects.equals(getIWeights(), other.getIWeights())
                && Objects.equals(getRWeights(), other.getRWeights())
                && Objects.equals(getIrWeights(), other.getIrWeights())
                && Objects.equals(getPeepholeWeights(), other.getPeepholeWeights())
                && Objects.equals(getIPeepholeWeights(), other.getIPeepholeWeights())
                && Objects.equals(getBias(), other.getBias())
                && Objects.equals(getIBias(), other.getIBias());
    }

    /**
     * Symmetry hook for {@link #equals(Object)}. Subclasses that add state should override it so
     * that instances of different classes do not compare equal.
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof LSTMLayerWeights;
    }

    /** Hash consistent with {@link #equals(Object)}: seeded by the superclass, then folds each field. */
    @Override
    public int hashCode() {
        int h = super.hashCode();
        h = mixHash(h, getWeights());
        h = mixHash(h, getIWeights());
        h = mixHash(h, getRWeights());
        h = mixHash(h, getIrWeights());
        h = mixHash(h, getPeepholeWeights());
        h = mixHash(h, getIPeepholeWeights());
        h = mixHash(h, getBias());
        h = mixHash(h, getIBias());
        return h;
    }

    /** Folds one field into a running hash (multiplier 59; null contributes 43). */
    private static int mixHash(int acc, Object field) {
        return acc * 59 + (field == null ? 43 : field.hashCode());
    }

    public SDVariable getWeights() {
        return this.weights;
    }

    public INDArray getIWeights() {
        return this.iWeights;
    }

    public SDVariable getRWeights() {
        return this.rWeights;
    }

    public INDArray getIrWeights() {
        return this.irWeights;
    }

    public SDVariable getPeepholeWeights() {
        return this.peepholeWeights;
    }

    public INDArray getIPeepholeWeights() {
        return this.iPeepholeWeights;
    }

    public SDVariable getBias() {
        return this.bias;
    }

    public INDArray getIBias() {
        return this.iBias;
    }

    public void setWeights(SDVariable weights) {
        this.weights = weights;
    }

    public void setIWeights(INDArray iWeights) {
        this.iWeights = iWeights;
    }

    public void setRWeights(SDVariable rWeights) {
        this.rWeights = rWeights;
    }

    public void setIrWeights(INDArray irWeights) {
        this.irWeights = irWeights;
    }

    public void setPeepholeWeights(SDVariable peepholeWeights) {
        this.peepholeWeights = peepholeWeights;
    }

    public void setIPeepholeWeights(INDArray iPeepholeWeights) {
        this.iPeepholeWeights = iPeepholeWeights;
    }

    public void setBias(SDVariable bias) {
        this.bias = bias;
    }

    public void setIBias(INDArray iBias) {
        this.iBias = iBias;
    }

    @Override
    public String toString() {
        return "LSTMLayerWeights(weights=" + getWeights() + ", iWeights=" + getIWeights() + ", rWeights=" + getRWeights() + ", irWeights=" + getIrWeights() + ", peepholeWeights=" + getPeepholeWeights() + ", iPeepholeWeights=" + getIPeepholeWeights() + ", bias=" + getBias() + ", iBias=" + getIBias() + ")";
    }
}
