package org.nd4j.linalg.learning;

import java.util.HashMap;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaBelief;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaBeliefUpdater.class */
/**
 * Stateful gradient updater that applies the AdaBelief update rule.
 *
 * <p>The updater keeps two state arrays, {@code m} (state-map key {@value #M_STATE}) and
 * {@code s} (state-map key {@value #S_STATE}). They are installed in one of two ways:
 * <ul>
 *   <li>directly, via {@link #setState(Map, boolean)};</li>
 *   <li>by splitting a single row-vector view array in half, via
 *       {@link #setStateViewArray(INDArray, long[], char, boolean)}. The first half backs
 *       {@code m} and the second half backs {@code s}.</li>
 * </ul>
 * The numeric update itself is delegated to the
 * {@code org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater} op.
 */
public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
    /** State-map key under which the {@code m} state array is stored. */
    public static final String M_STATE = "M";
    /** State-map key under which the {@code s} state array is stored. */
    public static final String S_STATE = "S";

    private AdaBelief config;
    private INDArray m;
    private INDArray s;
    /**
     * Order character passed to the most recent {@link #setStateViewArray} call. It keeps the char
     * default value (zero) if state was only installed through {@link #setState}.
     */
    private char gradientReshapeOrder;

    /**
     * Creates an updater for the given configuration.
     *
     * <p>State arrays must be installed with {@link #setState} or {@link #setStateViewArray}
     * before {@link #applyUpdater} is called. Otherwise {@code applyUpdater} throws.
     *
     * @param config supplies the learning rate, beta1, beta2 and epsilon used by
     *     {@link #applyUpdater}
     */
    public AdaBeliefUpdater(AdaBelief config) {
        this.config = config;
    }

    /**
     * Installs externally provided state arrays. The arrays are stored by reference, not copied.
     *
     * @param stateMap map that must contain exactly the keys {@value #M_STATE} and
     *     {@value #S_STATE}
     * @param initialize not used by this implementation
     * @throws NullPointerException if {@code stateMap} is null
     * @throws IllegalStateException if {@code stateMap} does not hold exactly the two expected keys
     */
    @Override
    public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
        // Explicit leading null check. Lombok does not generate a second one when this is present.
        if (stateMap == null) {
            throw new NullPointerException("stateMap is marked non-null but is null");
        }
        if (stateMap.size() != 2 || !stateMap.containsKey(M_STATE) || !stateMap.containsKey(S_STATE)) {
            throw new IllegalStateException("State map should contain only keys [M,S] but has keys " + stateMap.keySet());
        }
        this.m = stateMap.get(M_STATE);
        this.s = stateMap.get(S_STATE);
    }

    /**
     * Returns the current state arrays keyed by {@value #M_STATE} and {@value #S_STATE}.
     *
     * <p>A new map is returned on each call. The arrays inside it are the updater's own state
     * objects, not copies, and a value may be null if no state has been installed yet.
     *
     * @return a new mutable map holding the {@code m} and {@code s} state arrays
     */
    @Override
    public Map<String, INDArray> getState() {
        Map<String, INDArray> state = new HashMap<>();
        state.put(M_STATE, this.m);
        state.put(S_STATE, this.s);
        return state;
    }

    /**
     * Derives the {@code m} and {@code s} state arrays from one row-vector view array.
     *
     * <p>Elements {@code [0, length/2)} become {@code m} and {@code [length/2, length)} become
     * {@code s}. Each half is then reshaped to {@code gradientShape} with
     * {@code Shape.newShapeNoCopy}. Because both halves must reshape to the same shape, the view's
     * length is expected to be twice the element count of {@code gradientShape}.
     *
     * @param viewArray row vector holding the state for both arrays
     * @param gradientShape shape the state arrays are reshaped to
     * @param gradientOrder {@code 'f'} requests an f-order reshape; any other value is passed on as
     *     c-order. The value is also recorded as {@link #getGradientReshapeOrder()}.
     * @param initialize if true, the entire {@code viewArray} is set to zero first
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException if either half could not be reshaped (newShapeNoCopy returned
     *     null)
     */
    @Override
    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign((Number) 0);
        }
        long length = viewArray.length();
        long half = length / 2;
        boolean fOrder = gradientOrder == 'f';

        INDArray firstHalf = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, half));
        INDArray secondHalf = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(half, length));

        // Reshape each half to the shape of the gradients this updater will receive.
        this.m = Shape.newShapeNoCopy(firstHalf, gradientShape, fOrder);
        this.s = Shape.newShapeNoCopy(secondHalf, gradientShape, fOrder);
        if (this.m == null || this.s == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Runs one AdaBelief step by executing the native updater op on {@code gradient} and the
     * {@code m}/{@code s} state arrays.
     *
     * @param gradient gradient array. It is reshaped to the shape of {@code s} before being passed
     *     to the op.
     * @param iteration current iteration. It is passed to {@code config.getLearningRate} and to the
     *     op.
     * @param epoch current epoch, passed to {@code config.getLearningRate}
     * @throws IllegalStateException if no state arrays have been installed
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.m == null || this.s == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double beta1 = this.config.getBeta1();
        double beta2 = this.config.getBeta2();
        // NOTE(review): the op presumably writes its result into the gradient argument in place.
        // reshape() uses its default ordering and may return a copy when a view is impossible,
        // in which case the update would not reach the caller's array. Confirm that gradients
        // always arrive already in s's shape and order.
        INDArray reshapedGradient = gradient.reshape(this.s.shape());
        double learningRate = this.config.getLearningRate(iteration, epoch);
        double epsilon = this.config.getEpsilon();
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(
                reshapedGradient, this.s, this.m, learningRate, beta1, beta2, epsilon, iteration));
    }

    /** @return the hyper-parameter configuration of this updater */
    @Override
    public AdaBelief getConfig() {
        return this.config;
    }

    /** @return the {@code m} state array, or null if none has been installed */
    public INDArray getM() {
        return this.m;
    }

    /** @return the {@code s} state array, or null if none has been installed */
    public INDArray getS() {
        return this.s;
    }

    /** @return the order character recorded by the last {@link #setStateViewArray} call */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** @param config replacement hyper-parameter configuration */
    public void setConfig(AdaBelief config) {
        this.config = config;
    }

    /** @param m replacement {@code m} state array (stored by reference) */
    public void setM(INDArray m) {
        this.m = m;
    }

    /** @param s replacement {@code s} state array (stored by reference) */
    public void setS(INDArray s) {
        this.s = s;
    }

    /** @param gradientReshapeOrder order character to record */
    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /**
     * Two updaters are equal when they are mutually comparable (see {@link #canEqual}) and have
     * equal reshape order, config, {@code m} and {@code s}. Null components are equal only to null.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaBeliefUpdater)) {
            return false;
        }
        AdaBeliefUpdater other = (AdaBeliefUpdater) obj;
        if (!other.canEqual(this) || getGradientReshapeOrder() != other.getGradientReshapeOrder()) {
            return false;
        }
        return nullSafeEquals(getConfig(), other.getConfig())
                && nullSafeEquals(getM(), other.getM())
                && nullSafeEquals(getS(), other.getS());
    }

    /**
     * Symmetry hook for {@link #equals}. Subclasses that add state should override this.
     *
     * @param obj candidate object
     * @return true if {@code obj} is an {@code AdaBeliefUpdater}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaBeliefUpdater;
    }

    /** Hash over the same components as {@link #equals}, using multiplier 59 and 43 for null. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + getGradientReshapeOrder();
        result = result * prime + hashOrDefault(getConfig());
        result = result * prime + hashOrDefault(getM());
        result = result * prime + hashOrDefault(getS());
        return result;
    }

    @Override
    public String toString() {
        return "AdaBeliefUpdater(config=" + getConfig() + ", m=" + getM() + ", s=" + getS() + ", gradientReshapeOrder=" + getGradientReshapeOrder() + ")";
    }

    /** Null-safe equality: a null reference equals only another null; otherwise delegates to a.equals(b). */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash contribution of one component: 43 for null, otherwise its own hashCode. */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }
}
