package org.nd4j.autodiff.samediff.optimize.optimizations;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;

/* loaded from: input_file:org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.class */
/**
 * Static helpers for structurally editing a {@link SameDiff} graph (rewiring op inputs, removing ops and
 * variables) while keeping the op/variable cross-reference bookkeeping consistent.
 */
public final class OptimizationUtils {
    private OptimizationUtils() {
    }

    /**
     * Rewires every op that consumes {@code str} so that it consumes {@code str2} instead, and moves the
     * consumer bookkeeping ({@code inputsForOp}) from the old variable to the new one.
     *
     * <p>Consumers already recorded on {@code str2} are kept; the consumers of {@code str} are appended
     * (without duplicating names already present). Afterwards {@code str} has an empty consumer list.
     *
     * @param sameDiff the graph to modify in place
     * @param str      name of the variable being replaced as an op input
     * @param str2     name of the variable that takes its place
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if either name is not a variable of {@code sameDiff}; in that case the
     *                                  graph is left unmodified
     */
    public static void replaceOpInputsWith(@NonNull SameDiff sameDiff, @NonNull String str, @NonNull String str2) {
        if (sameDiff == null) {
            throw new NullPointerException("sd is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("replaceInput is marked non-null but is null");
        }
        if (str2 == null) {
            throw new NullPointerException("newInput is marked non-null but is null");
        }
        if (str.equals(str2)) {
            return;
        }

        // Resolve both variables before mutating anything, so a bad name cannot leave the graph half-rewired.
        Variable oldVariable = (Variable) sameDiff.getVariables().get(str);
        if (oldVariable == null) {
            throw new IllegalArgumentException("No variable named \"" + str + "\" in SameDiff instance");
        }
        Variable newVariable = (Variable) sameDiff.getVariables().get(str2);
        if (newVariable == null) {
            throw new IllegalArgumentException("No variable named \"" + str2 + "\" in SameDiff instance");
        }

        for (SameDiffOp op : sameDiff.getOps().values()) {
            List<String> inputs = op.getInputsToOp();
            if (inputs == null) {
                continue;
            }
            // Single pass; replaces every occurrence, since an op may take the same variable more than once.
            // Only matching positions are written, so lists without a match are never modified.
            for (int i = 0; i < inputs.size(); i++) {
                if (str.equals(inputs.get(i))) {
                    inputs.set(i, str2);
                }
            }
        }

        // Merge rather than overwrite: ops that already consumed str2 must stay recorded as its consumers.
        List<String> mergedConsumers = new ArrayList<>();
        List<String> existingConsumers = newVariable.getInputsForOp();
        if (existingConsumers != null) {
            mergedConsumers.addAll(existingConsumers);
        }
        List<String> movedConsumers = oldVariable.getInputsForOp();
        if (movedConsumers != null) {
            for (String opName : movedConsumers) {
                if (!mergedConsumers.contains(opName)) {
                    mergedConsumers.add(opName);
                }
            }
        }
        newVariable.setInputsForOp(mergedConsumers);
        oldVariable.setInputsForOp(new ArrayList<>());
    }

    /**
     * Removes the named op from the graph and unregisters it as a consumer of each of its input variables.
     *
     * <p>Removing an op name that is not present is a no-op, mirroring {@link #removeVariable}. Inputs that
     * have no corresponding variable, or no consumer list, are skipped.
     *
     * @param sameDiff the graph to modify in place
     * @param str      name of the op to remove
     * @throws NullPointerException if any argument is null
     */
    public static void removeOp(@NonNull SameDiff sameDiff, @NonNull String str) {
        if (sameDiff == null) {
            throw new NullPointerException("sd is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("opToRemove is marked non-null but is null");
        }
        SameDiffOp removed = sameDiff.getOps().remove(str);
        if (removed == null) {
            return;
        }
        List<String> inputs = removed.getInputsToOp();
        if (inputs == null) {
            return;
        }
        for (String inputName : inputs) {
            Variable input = (Variable) sameDiff.getVariables().get(inputName);
            if (input == null) {
                continue;
            }
            List<String> consumers = input.getInputsForOp();
            if (consumers != null) {
                // List.remove(Object) drops one occurrence; an input listed twice is visited twice.
                consumers.remove(removed.getName());
            }
        }
    }

    /**
     * Removes the named variable's entry from the graph's variable map.
     *
     * <p>Only the map entry is removed; ops that still reference the variable by name are not updated here.
     * Removing a name that is not present is a no-op.
     *
     * @param sameDiff the graph to modify in place
     * @param str      name of the variable to remove
     * @throws NullPointerException if any argument is null
     */
    public static void removeVariable(@NonNull SameDiff sameDiff, @NonNull String str) {
        if (sameDiff == null) {
            throw new NullPointerException("sd is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("varToRemove is marked non-null but is null");
        }
        sameDiff.getVariables().remove(str);
    }
}
