package org.nd4j.linalg.learning;

import java.util.Collections;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;

/* loaded from: input_file:org/nd4j/linalg/learning/RmsPropUpdater.class */
public class RmsPropUpdater implements GradientUpdater<RmsProp> {
    /** Key under which the single piece of updater state is exposed by {@link #getState()}. */
    public static final String G_STATE = "G";

    /** Hyperparameters (learning rate schedule, RMS decay, epsilon) for this updater. */
    private final RmsProp config;

    /**
     * Updater state array, passed to the native RmsProp op on every {@link #applyUpdater} call.
     * When installed via {@link #setStateViewArray} it is a no-copy view of the supplied array,
     * reshaped to the gradient shape.
     */
    private INDArray lastGradient;

    /**
     * Element order ('c' or 'f') used to reshape the state view in {@link #setStateViewArray}.
     * It is left at its default value ('\0') when state is installed only via {@link #setState}.
     */
    private char gradientReshapeOrder;

    /**
     * Creates an updater for the given configuration. No state is allocated here; call
     * {@link #setStateViewArray} or {@link #setState} before {@link #applyUpdater}.
     *
     * @param rmsProp updater configuration (stored as-is; not null-checked here)
     */
    public RmsPropUpdater(RmsProp rmsProp) {
        this.config = rmsProp;
    }

    /**
     * Installs updater state from a map that must contain exactly one entry, keyed {@link #G_STATE}.
     *
     * @param map        the state map; must be non-null with exactly the key {@code "G"}
     * @param initialize not used by this implementation
     * @throws NullPointerException  if {@code map} is null
     * @throws IllegalStateException if the map does not contain exactly the {@code "G"} key
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public void setState(@NonNull Map<String, INDArray> map, boolean initialize) {
        if (map == null) {
            throw new NullPointerException("stateMap is marked non-null but is null");
        }
        if (!map.containsKey(G_STATE) || map.size() != 1) {
            throw new IllegalStateException("State map should contain only key [" + G_STATE + "] but has keys " + map.keySet());
        }
        this.lastGradient = map.get(G_STATE);
    }

    /**
     * Returns the current updater state as an immutable single-entry map keyed {@link #G_STATE}.
     * The value may be null if no state has been installed yet.
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public Map<String, INDArray> getState() {
        return Collections.singletonMap(G_STATE, this.lastGradient);
    }

    /**
     * Installs a (row-vector or scalar) view array as this updater's state.
     *
     * <p>The view is reshaped without copying to {@code gradientShape} using {@code gradientOrder}.
     * The reshape is validated BEFORE any state is mutated. A failed call therefore leaves the
     * previous state and the supplied view untouched.
     *
     * @param viewArray     backing state array; must be a row vector or scalar
     * @param gradientShape shape of the gradient this state will be paired with
     * @param gradientOrder element order for the reshape; 'f' selects Fortran order, any other
     *                      value is passed to the reshape as non-Fortran (c) order
     * @param initialize    if true, every element of {@code viewArray} is set to the config epsilon
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector or scalar
     * @throws IllegalStateException    if the view cannot be reshaped without a copy
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVectorOrScalar()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        // Reshape into a local first so a failure cannot null out previously installed state.
        INDArray reshaped = Shape.newShapeNoCopy(viewArray, gradientShape, gradientOrder == 'f');
        if (reshaped == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view array");
        }
        if (initialize) {
            // Writing through viewArray is visible through the reshaped view (no-copy reshape).
            viewArray.assign(Double.valueOf(this.config.getEpsilon()));
        }
        this.lastGradient = reshaped;
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Runs the native RmsProp updater op on {@code gradient} and the stored state.
     *
     * <p>The op receives the gradient reshaped to the state's shape, the state array, the learning
     * rate for ({@code iteration}, {@code epoch}), the RMS decay and epsilon from the config.
     * NOTE(review): the op presumably updates the state in place and writes the resulting update
     * into the gradient argument. If {@code reshape} ever has to return a copy (a non-viewable
     * layout), that write would not reach the caller's array — confirm against the op's docs.
     *
     * @param gradient  gradient to transform; must have the same length as the state
     * @param iteration current iteration, used for the learning-rate schedule
     * @param epoch     current epoch, used for the learning-rate schedule
     * @throws IllegalStateException if no state has been installed
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.lastGradient == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double learningRate = this.config.getLearningRate(iteration, epoch);
        double rmsDecay = this.config.getRmsDecay();
        double epsilon = this.config.getEpsilon();
        // Reshape the gradient with the same element order that was used to reshape the state view.
        // Element k of the gradient is then paired with element k of the state even when the two
        // shapes differ. Fall back to the default-order reshape when no order was recorded
        // (e.g. state installed via setState).
        char order = this.gradientReshapeOrder;
        INDArray shapedGradient = (order == 'c' || order == 'f')
                ? gradient.reshape(order, this.lastGradient.shape())
                : gradient.reshape(this.lastGradient.shape());
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(
                shapedGradient, this.lastGradient, learningRate, rmsDecay, epsilon));
    }

    /** Returns the configuration this updater was created with. */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public RmsProp getConfig() {
        return this.config;
    }

    /** Returns the current state array (may be null before initialization). */
    public INDArray getLastGradient() {
        return this.lastGradient;
    }

    /** Returns the element order recorded by the last successful {@link #setStateViewArray} call. */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** Replaces the state array directly, without validation or reshaping. */
    public void setLastGradient(INDArray lastGradient) {
        this.lastGradient = lastGradient;
    }

    /** Overrides the recorded reshape order. */
    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /**
     * Value equality over reshape order, config and state array (Lombok-style, with
     * {@link #canEqual} for subclass symmetry).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof RmsPropUpdater)) {
            return false;
        }
        RmsPropUpdater other = (RmsPropUpdater) obj;
        if (!other.canEqual(this) || getGradientReshapeOrder() != other.getGradientReshapeOrder()) {
            return false;
        }
        RmsProp thisConfig = getConfig();
        RmsProp otherConfig = other.getConfig();
        if (thisConfig == null ? otherConfig != null : !thisConfig.equals(otherConfig)) {
            return false;
        }
        INDArray thisState = getLastGradient();
        INDArray otherState = other.getLastGradient();
        return thisState == null ? otherState == null : thisState.equals(otherState);
    }

    /** Lombok-style symmetry hook for {@link #equals}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof RmsPropUpdater;
    }

    /** Hash over the same fields as {@link #equals}, using the Lombok 59/43 scheme. */
    @Override
    public int hashCode() {
        int result = (1 * 59) + getGradientReshapeOrder();
        RmsProp cfg = getConfig();
        result = (result * 59) + (cfg == null ? 43 : cfg.hashCode());
        INDArray state = getLastGradient();
        return (result * 59) + (state == null ? 43 : state.hashCode());
    }

    @Override
    public String toString() {
        return "RmsPropUpdater(config=" + getConfig() + ", lastGradient=" + getLastGradient() + ", gradientReshapeOrder=" + getGradientReshapeOrder() + ")";
    }
}
