package org.nd4j.autodiff.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.custom.Invoke;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.collect.Sets;

/**
 * Static helpers for building control flow (if/else and while loops) inside a {@link SameDiff} graph.
 *
 * <p>Also hosts small value holders ({@link LoopArgs}, {@link LoopLambdaArgs}, {@link LoopParams}) used
 * when a while-loop body is implemented by invoking a separate {@link SameDiff} sub-function. Those
 * helpers share the positional loop-variable layout
 * {@code [iteration, maxIterations, condition, extraArgs...]}.
 */
public class ControlFlow {

    /**
     * Hash contribution used by the nested value classes: {@code 43} for {@code null}, otherwise the
     * value's own {@code hashCode()}.
     */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * Arguments for a counter-driven while loop. {@link #toArgs()} flattens them into the positional
     * layout {@code [startIter, maxIters, condIn, extraArgs...]}.
     */
    public static class LoopArgs {
        private SDVariable condIn;
        private SDVariable maxIters;
        private SDVariable startIter;
        private SDVariable[] extraArgs;

        /** Fluent builder for {@link LoopArgs}; fields that are never set stay {@code null}. */
        public static class LoopArgsBuilder {
            private SDVariable condIn;
            private SDVariable maxIters;
            private SDVariable startIter;
            private SDVariable[] extraArgs;

            LoopArgsBuilder() {
            }

            public LoopArgsBuilder condIn(SDVariable condIn) {
                this.condIn = condIn;
                return this;
            }

            public LoopArgsBuilder maxIters(SDVariable maxIters) {
                this.maxIters = maxIters;
                return this;
            }

            public LoopArgsBuilder startIter(SDVariable startIter) {
                this.startIter = startIter;
                return this;
            }

            public LoopArgsBuilder extraArgs(SDVariable[] extraArgs) {
                this.extraArgs = extraArgs;
                return this;
            }

            public LoopArgs build() {
                return new LoopArgs(this.condIn, this.maxIters, this.startIter, this.extraArgs);
            }

            @Override
            public String toString() {
                return "ControlFlow.LoopArgs.LoopArgsBuilder(condIn=" + this.condIn + ", maxIters=" + this.maxIters + ", startIter=" + this.startIter + ", extraArgs=" + Arrays.deepToString(this.extraArgs) + ")";
            }
        }

        LoopArgs(SDVariable condIn, SDVariable maxIters, SDVariable startIter, SDVariable[] extraArgs) {
            this.condIn = condIn;
            this.maxIters = maxIters;
            this.startIter = startIter;
            this.extraArgs = extraArgs;
        }

        public static LoopArgsBuilder builder() {
            return new LoopArgsBuilder();
        }

        /**
         * Flattens these arguments into {@code [startIter, maxIters, condIn, extraArgs...]}.
         *
         * @return a new array; a {@code null} {@code extraArgs} contributes no trailing elements
         */
        public SDVariable[] toArgs() {
            int extraCount = extraArgs == null ? 0 : extraArgs.length;
            SDVariable[] args = new SDVariable[3 + extraCount];
            args[0] = startIter;
            args[1] = maxIters;
            args[2] = condIn;
            if (extraCount > 0) {
                System.arraycopy(extraArgs, 0, args, 3, extraCount);
            }
            return args;
        }

        public SDVariable getCondIn() {
            return this.condIn;
        }

        public SDVariable getMaxIters() {
            return this.maxIters;
        }

        public SDVariable getStartIter() {
            return this.startIter;
        }

        public SDVariable[] getExtraArgs() {
            return this.extraArgs;
        }

        public void setCondIn(SDVariable condIn) {
            this.condIn = condIn;
        }

        public void setMaxIters(SDVariable maxIters) {
            this.maxIters = maxIters;
        }

        public void setStartIter(SDVariable startIter) {
            this.startIter = startIter;
        }

        public void setExtraArgs(SDVariable[] extraArgs) {
            this.extraArgs = extraArgs;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof LoopArgs)) {
                return false;
            }
            LoopArgs other = (LoopArgs) obj;
            return other.canEqual(this)
                    && Objects.equals(getCondIn(), other.getCondIn())
                    && Objects.equals(getMaxIters(), other.getMaxIters())
                    && Objects.equals(getStartIter(), other.getStartIter())
                    && Arrays.deepEquals(getExtraArgs(), other.getExtraArgs());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof LoopArgs;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashOrDefault(getCondIn());
            result = result * 59 + hashOrDefault(getMaxIters());
            result = result * 59 + hashOrDefault(getStartIter());
            result = result * 59 + Arrays.deepHashCode(getExtraArgs());
            return result;
        }

        @Override
        public String toString() {
            return "ControlFlow.LoopArgs(condIn=" + getCondIn() + ", maxIters=" + getMaxIters() + ", startIter=" + getStartIter() + ", extraArgs=" + Arrays.deepToString(getExtraArgs()) + ")";
        }
    }

    /**
     * Loop variables as seen by a loop-body lambda: {@code [iterStart, iterCount, condIn, extraArgs...]}.
     * {@link #invokeParams} turns them into parameters for invoking a sub-function.
     */
    public static class LoopLambdaArgs {
        private SDVariable iterStart;
        private SDVariable iterCount;
        private SDVariable condIn;
        private SDVariable[] extraArgs;

        /** Fluent builder for {@link LoopLambdaArgs}; {@link #build()} runs the constructor's validation. */
        public static class LoopLambdaArgsBuilder {
            private SDVariable iterStart;
            private SDVariable iterCount;
            private SDVariable[] extraArgs;
            private SDVariable condIn;

            LoopLambdaArgsBuilder() {
            }

            public LoopLambdaArgsBuilder iterStart(SDVariable iterStart) {
                this.iterStart = iterStart;
                return this;
            }

            public LoopLambdaArgsBuilder iterCount(SDVariable iterCount) {
                this.iterCount = iterCount;
                return this;
            }

            public LoopLambdaArgsBuilder extraArgs(SDVariable[] extraArgs) {
                this.extraArgs = extraArgs;
                return this;
            }

            public LoopLambdaArgsBuilder condIn(SDVariable condIn) {
                this.condIn = condIn;
                return this;
            }

            public LoopLambdaArgs build() {
                return new LoopLambdaArgs(this.iterStart, this.iterCount, this.extraArgs, this.condIn);
            }

            @Override
            public String toString() {
                return "ControlFlow.LoopLambdaArgs.LoopLambdaArgsBuilder(iterStart=" + this.iterStart + ", iterCount=" + this.iterCount + ", extraArgs=" + Arrays.deepToString(this.extraArgs) + ", condIn=" + this.condIn + ")";
            }
        }

        /**
         * @param iterStart current iteration variable
         * @param iterCount iteration bound; must have a numerical data type
         * @param extraArgs additional loop variables (may be {@code null})
         * @param condIn    loop condition; must be {@link DataType#BOOL}
         * @throws NullPointerException     if {@code condIn} or {@code iterCount} is {@code null}
         * @throws IllegalArgumentException if {@code condIn} is not boolean or {@code iterCount} is not numerical
         */
        public LoopLambdaArgs(SDVariable iterStart, SDVariable iterCount, SDVariable[] extraArgs, SDVariable condIn) {
            Objects.requireNonNull(condIn, "condIn must not be null");
            if (condIn.dataType() != DataType.BOOL) {
                throw new IllegalArgumentException("Data type for condition must be boolean!");
            }
            Objects.requireNonNull(iterCount, "iterCount must not be null");
            if (!iterCount.dataType().isNumerical()) {
                throw new IllegalArgumentException("Data type for iteration count must be numerical!");
            }
            this.iterCount = iterCount;
            this.extraArgs = extraArgs;
            this.condIn = condIn;
            this.iterStart = iterStart;
        }

        /**
         * Builds invoke parameters for sub-function {@code functionName}. The inputs are
         * {@code [iterStart, iterCount, condIn, extraArgs...]} and their current names are used as the
         * input variable names.
         *
         * @param functionName           name of the sub-function to invoke
         * @param subGraphInputVarNames  input variable names inside the sub-function
         * @param subGraphOutputVarNames output variable names inside the sub-function
         * @return the assembled {@link Invoke.InvokeParams}; a {@code null} {@code extraArgs} adds no inputs
         */
        public Invoke.InvokeParams invokeParams(String functionName, String[] subGraphInputVarNames, String[] subGraphOutputVarNames) {
            List<SDVariable> inputs = new ArrayList<>();
            inputs.add(this.iterStart);
            inputs.add(this.iterCount);
            inputs.add(this.condIn);
            if (this.extraArgs != null) {
                inputs.addAll(Arrays.asList(this.extraArgs));
            }
            return Invoke.InvokeParams.builder()
                    .functionName(functionName)
                    .inputs(inputs.toArray(new SDVariable[0]))
                    .subGraphInputVarNames(subGraphInputVarNames)
                    .subGraphOutputVarNames(subGraphOutputVarNames)
                    .inputVarNames(inputs.stream().map(SDVariable::name).toArray(String[]::new))
                    .build();
        }

        public static LoopLambdaArgsBuilder builder() {
            return new LoopLambdaArgsBuilder();
        }

        public SDVariable getIterStart() {
            return this.iterStart;
        }

        public SDVariable getIterCount() {
            return this.iterCount;
        }

        public SDVariable getCondIn() {
            return this.condIn;
        }

        public SDVariable[] getExtraArgs() {
            return this.extraArgs;
        }

        public void setIterStart(SDVariable iterStart) {
            this.iterStart = iterStart;
        }

        public void setIterCount(SDVariable iterCount) {
            this.iterCount = iterCount;
        }

        public void setCondIn(SDVariable condIn) {
            this.condIn = condIn;
        }

        public void setExtraArgs(SDVariable[] extraArgs) {
            this.extraArgs = extraArgs;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof LoopLambdaArgs)) {
                return false;
            }
            LoopLambdaArgs other = (LoopLambdaArgs) obj;
            return other.canEqual(this)
                    && Objects.equals(getIterStart(), other.getIterStart())
                    && Objects.equals(getIterCount(), other.getIterCount())
                    && Objects.equals(getCondIn(), other.getCondIn())
                    && Arrays.deepEquals(getExtraArgs(), other.getExtraArgs());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof LoopLambdaArgs;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashOrDefault(getIterStart());
            result = result * 59 + hashOrDefault(getIterCount());
            result = result * 59 + hashOrDefault(getCondIn());
            result = result * 59 + Arrays.deepHashCode(getExtraArgs());
            return result;
        }

        @Override
        public String toString() {
            return "ControlFlow.LoopLambdaArgs(iterStart=" + getIterStart() + ", iterCount=" + getIterCount() + ", condIn=" + getCondIn() + ", extraArgs=" + Arrays.deepToString(getExtraArgs()) + ")";
        }
    }

    /** Parameter object for {@link ControlFlow#loopWithConditions(LoopParams)}. */
    public static class LoopParams {
        private String[] outputVarNames;
        private String loopName;
        private SameDiff parent;
        private SameDiff functionBody;
        private String functionName;
        private SDVariable[] loopVars;
        private String[] functionBodyInputs;
        private String[] functionBodyOutputs;

        /** Fluent builder for {@link LoopParams}; fields that are never set stay {@code null}. */
        public static class LoopParamsBuilder {
            private String[] outputVarNames;
            private String loopName;
            private SameDiff parent;
            private SameDiff functionBody;
            private String functionName;
            private SDVariable[] loopVars;
            private String[] functionBodyInputs;
            private String[] functionBodyOutputs;

            LoopParamsBuilder() {
            }

            public LoopParamsBuilder outputVarNames(String[] outputVarNames) {
                this.outputVarNames = outputVarNames;
                return this;
            }

            public LoopParamsBuilder loopName(String loopName) {
                this.loopName = loopName;
                return this;
            }

            public LoopParamsBuilder parent(SameDiff parent) {
                this.parent = parent;
                return this;
            }

            public LoopParamsBuilder functionBody(SameDiff functionBody) {
                this.functionBody = functionBody;
                return this;
            }

            public LoopParamsBuilder functionName(String functionName) {
                this.functionName = functionName;
                return this;
            }

            public LoopParamsBuilder loopVars(SDVariable[] loopVars) {
                this.loopVars = loopVars;
                return this;
            }

            public LoopParamsBuilder functionBodyInputs(String[] functionBodyInputs) {
                this.functionBodyInputs = functionBodyInputs;
                return this;
            }

            public LoopParamsBuilder functionBodyOutputs(String[] functionBodyOutputs) {
                this.functionBodyOutputs = functionBodyOutputs;
                return this;
            }

            public LoopParams build() {
                return new LoopParams(this.outputVarNames, this.loopName, this.parent, this.functionBody, this.functionName, this.loopVars, this.functionBodyInputs, this.functionBodyOutputs);
            }

            @Override
            public String toString() {
                return "ControlFlow.LoopParams.LoopParamsBuilder(outputVarNames=" + Arrays.deepToString(this.outputVarNames) + ", loopName=" + this.loopName + ", parent=" + this.parent + ", functionBody=" + this.functionBody + ", functionName=" + this.functionName + ", loopVars=" + Arrays.deepToString(this.loopVars) + ", functionBodyInputs=" + Arrays.deepToString(this.functionBodyInputs) + ", functionBodyOutputs=" + Arrays.deepToString(this.functionBodyOutputs) + ")";
            }
        }

        LoopParams(String[] outputVarNames, String loopName, SameDiff parent, SameDiff functionBody, String functionName,
                   SDVariable[] loopVars, String[] functionBodyInputs, String[] functionBodyOutputs) {
            this.outputVarNames = outputVarNames;
            this.loopName = loopName;
            this.parent = parent;
            this.functionBody = functionBody;
            this.functionName = functionName;
            this.loopVars = loopVars;
            this.functionBodyInputs = functionBodyInputs;
            this.functionBodyOutputs = functionBodyOutputs;
        }

        public static LoopParamsBuilder builder() {
            return new LoopParamsBuilder();
        }

        public String[] getOutputVarNames() {
            return this.outputVarNames;
        }

        public String getLoopName() {
            return this.loopName;
        }

        public SameDiff getParent() {
            return this.parent;
        }

        public SameDiff getFunctionBody() {
            return this.functionBody;
        }

        public String getFunctionName() {
            return this.functionName;
        }

        public SDVariable[] getLoopVars() {
            return this.loopVars;
        }

        public String[] getFunctionBodyInputs() {
            return this.functionBodyInputs;
        }

        public String[] getFunctionBodyOutputs() {
            return this.functionBodyOutputs;
        }

        public void setOutputVarNames(String[] outputVarNames) {
            this.outputVarNames = outputVarNames;
        }

        public void setLoopName(String loopName) {
            this.loopName = loopName;
        }

        public void setParent(SameDiff parent) {
            this.parent = parent;
        }

        public void setFunctionBody(SameDiff functionBody) {
            this.functionBody = functionBody;
        }

        public void setFunctionName(String functionName) {
            this.functionName = functionName;
        }

        public void setLoopVars(SDVariable[] loopVars) {
            this.loopVars = loopVars;
        }

        public void setFunctionBodyInputs(String[] functionBodyInputs) {
            this.functionBodyInputs = functionBodyInputs;
        }

        public void setFunctionBodyOutputs(String[] functionBodyOutputs) {
            this.functionBodyOutputs = functionBodyOutputs;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof LoopParams)) {
                return false;
            }
            LoopParams other = (LoopParams) obj;
            return other.canEqual(this)
                    && Arrays.deepEquals(getOutputVarNames(), other.getOutputVarNames())
                    && Objects.equals(getLoopName(), other.getLoopName())
                    && Objects.equals(getParent(), other.getParent())
                    && Objects.equals(getFunctionBody(), other.getFunctionBody())
                    && Objects.equals(getFunctionName(), other.getFunctionName())
                    && Arrays.deepEquals(getLoopVars(), other.getLoopVars())
                    && Arrays.deepEquals(getFunctionBodyInputs(), other.getFunctionBodyInputs())
                    && Arrays.deepEquals(getFunctionBodyOutputs(), other.getFunctionBodyOutputs());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof LoopParams;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + Arrays.deepHashCode(getOutputVarNames());
            result = result * 59 + hashOrDefault(getLoopName());
            result = result * 59 + hashOrDefault(getParent());
            result = result * 59 + hashOrDefault(getFunctionBody());
            result = result * 59 + hashOrDefault(getFunctionName());
            result = result * 59 + Arrays.deepHashCode(getLoopVars());
            result = result * 59 + Arrays.deepHashCode(getFunctionBodyInputs());
            result = result * 59 + Arrays.deepHashCode(getFunctionBodyOutputs());
            return result;
        }

        @Override
        public String toString() {
            return "ControlFlow.LoopParams(outputVarNames=" + Arrays.deepToString(getOutputVarNames()) + ", loopName=" + getLoopName() + ", parent=" + getParent() + ", functionBody=" + getFunctionBody() + ", functionName=" + getFunctionName() + ", loopVars=" + Arrays.deepToString(getLoopVars()) + ", functionBodyInputs=" + Arrays.deepToString(getFunctionBodyInputs()) + ", functionBodyOutputs=" + Arrays.deepToString(getFunctionBodyOutputs()) + ")";
        }
    }

    /**
     * Creates two initial loop variables in {@code sameDiff}.
     *
     * @param names         exactly two names: {@code names[0]} for a variable initialised to
     *                      {@code Nd4j.zeros(1)}, {@code names[1]} for a constant holding {@code maxIterations}
     * @param sameDiff      graph to create the variables in
     * @param maxIterations value stored in the constant
     * @return {@code [constant(names[1]), var(names[0])]}. NOTE(review): the constant comes first here,
     *         whereas the three-name overload returns the zero variable first; confirm callers expect this order.
     * @throws IllegalStateException if {@code names} is {@code null} or does not have length 2
     */
    public static SDVariable[] initializeLoopBody(String[] names, SameDiff sameDiff, int maxIterations) {
        Preconditions.checkState(names != null && names.length == 2, "Number of input names must be 2.");
        SDVariable maxIterationsConstant = sameDiff.constant(names[1], maxIterations);
        SDVariable iteration = sameDiff.var(names[0], Nd4j.zeros(1));
        return new SDVariable[]{maxIterationsConstant, iteration};
    }

    /**
     * Creates the three initial loop variables {@code [iteration, maxIterations, condition]} in {@code sameDiff}.
     *
     * @param names         exactly three names, used in that order
     * @param sameDiff      graph to create the variables in
     * @param maxIterations value of the second (constant) variable
     * @param condition     value of the third (constant) variable
     * @return {@code [var(names[0]) = zeros(1), constant(names[1]) = maxIterations, constant(names[2]) = condition]}
     * @throws IllegalStateException if {@code names} is {@code null} or does not have length 3
     */
    public static SDVariable[] initializeLoopBody(String[] names, SameDiff sameDiff, int maxIterations, boolean condition) {
        Preconditions.checkState(names != null && names.length == 3, "Number of input names must be 3.");
        SDVariable iteration = sameDiff.var(names[0], Nd4j.zeros(1));
        SDVariable maxIterationsConstant = sameDiff.constant(names[1], maxIterations);
        SDVariable conditionConstant = sameDiff.constant(names[2], condition);
        return new SDVariable[]{iteration, maxIterationsConstant, conditionConstant};
    }

    /**
     * Flattens loop arguments into {@code [startIterations, maxIterations, condIn, extraArgs...]}.
     * Note the parameter order differs from the returned order.
     */
    public static SDVariable[] args(SDVariable maxIterations, SDVariable condIn, SDVariable startIterations, SDVariable[] extraArgs) {
        return LoopArgs.builder()
                .extraArgs(extraArgs)
                .condIn(condIn)
                .maxIters(maxIterations)
                .startIter(startIterations)
                .build()
                .toArgs();
    }

    /**
     * Builds an if/else construct in {@code sameDiff}.
     *
     * <p>All ops are created inside a new name scope based on {@code ifName} (default {@code "if"}). The
     * condition is defined in a {@code "cond"} sub-scope and must be {@link DataType#BOOL}. The branches are
     * defined in {@code "trueBody"} and {@code "falseBody"} sub-scopes. While a branch is defined, any op
     * argument that already existed before that branch is replaced by an output of
     * {@code switchOp(argument, cond)}: output 1 in the true branch and output 0 in the false branch. The
     * switch for a given name is created once and shared between the branches. The branch results are
     * merged and renamed to {@code outputName}.
     *
     * <p>Name scopes and argument interceptors are always removed, even when a lambda throws.
     *
     * @param sameDiff   graph to add the construct to
     * @param outputName name for the returned variable
     * @param ifName     base name for the block scope; {@code "if"} when {@code null}
     * @param cond       defines the boolean condition
     * @param trueBody   defines the true-branch value
     * @param falseBody  defines the false-branch value
     * @return the merged output of the two branches
     * @throws NullPointerException  if any lambda is {@code null}
     * @throws IllegalStateException if the condition is not boolean. The variables and ops already created
     *                               in the block scope are removed before throwing.
     */
    public static SDVariable ifCond(SameDiff sameDiff, String outputName, String ifName,
                                    @NonNull SameDiffNoArgSingleLambda cond,
                                    @NonNull SameDiffNoArgSingleLambda trueBody,
                                    @NonNull SameDiffNoArgSingleLambda falseBody) {
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (trueBody == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (falseBody == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }

        NameScope ifScope = sameDiff.withNameScope(sameDiff.newBlockName(ifName == null ? "if" : ifName));
        SDVariable merged;
        try {
            NameScope condScope = sameDiff.withNameScope("cond");
            SDVariable pred;
            try {
                pred = cond.define(sameDiff);
            } finally {
                condScope.close();
            }

            if (pred.dataType() != DataType.BOOL) {
                // Undo the partially built block; the enclosing finally then closes the if-scope.
                removeScopeContents(sameDiff, ifScope);
                throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean.");
            }

            // Switch outputs keyed by the name of the switched variable, shared by both branches.
            Map<String, SDVariable[]> switches = new HashMap<>();
            SDVariable trueOut = defineIfBranch(sameDiff, "trueBody", trueBody, pred, switches, 1);
            SDVariable falseOut = defineIfBranch(sameDiff, "falseBody", falseBody, pred, switches, 0);
            merged = sameDiff.merge(trueOut, falseOut);
        } finally {
            ifScope.close();
        }
        return sameDiff.updateVariableNameAndReference(merged, outputName);
    }

    /**
     * Removes every variable and op created within {@code scope}. Used to undo a partially built if-block.
     */
    private static void removeScopeContents(SameDiff sameDiff, NameScope scope) {
        for (SDVariable variable : sameDiff.getVariablesInScope(scope)) {
            sameDiff.getVariables().remove(variable.name());
        }
        for (SameDiffOp op : sameDiff.getOpsInScope(scope)) {
            for (String input : op.getInputsToOp()) {
                sameDiff.removeArgFromOp(input, op.getOp());
            }
            sameDiff.getOps().remove(op.getName());
        }
    }

    /**
     * Defines one branch of {@link #ifCond} and returns its output.
     *
     * <p>It snapshots the variables declared so far, then installs an argument interceptor. The interceptor
     * replaces each of those variables with output {@code branch} of a cached
     * {@code switchOp(variable, pred)}. It then defines {@code body} inside the {@code scopeName} scope. If
     * the body's output is itself a pre-existing variable, it is routed through a fresh switch. The
     * interceptor is removed and the scope closed even if {@code body} throws.
     */
    private static SDVariable defineIfBranch(SameDiff sameDiff, String scopeName, SameDiffNoArgSingleLambda body,
                                             SDVariable pred, Map<String, SDVariable[]> switches, int branch) {
        Set<String> declaredBefore = Sets.newHashSet(sameDiff.variableMap().keySet());
        sameDiff.addArgumentInterceptor(argument -> {
            if (argument == null) {
                return null;
            }
            String name = argument.name();
            if (!declaredBefore.contains(name)) {
                return argument;
            }
            if (switches.containsKey(name)) {
                return switches.get(name)[branch];
            }
            SDVariable[] created = sameDiff.switchOp(argument, pred);
            switches.put(name, created);
            return created[branch];
        });

        NameScope branchScope = sameDiff.withNameScope(scopeName);
        try {
            SDVariable out;
            try {
                out = body.define(sameDiff);
            } finally {
                sameDiff.removeArgumentInterceptor();
            }
            if (declaredBefore.contains(out.name())) {
                SDVariable[] created = sameDiff.switchOp(out, pred);
                switches.put(out.name(), created);
                out = created[branch];
            }
            return out;
        } finally {
            branchScope.close();
        }
    }

    /** Convenience overload that unpacks a {@link LoopParams}. */
    public static SDVariable[] loopWithConditions(LoopParams loopParams) {
        return loopWithConditions(loopParams.getOutputVarNames(), loopParams.getLoopName(), loopParams.getParent(),
                loopParams.getFunctionBody(), loopParams.getFunctionName(), loopParams.getLoopVars(),
                loopParams.getFunctionBodyInputs(), loopParams.getFunctionBodyOutputs());
    }

    /**
     * Builds a while loop in {@code parent} whose body invokes {@code functionBody}, registered as sub-function
     * {@code functionName}, and whose condition is {@link #condBody()}.
     *
     * @param outputVarNames      names for the loop outputs
     * @param loopName            loop name passed to {@code parent.whileLoop}
     * @param parent              graph that owns the loop; every loop variable must belong to it
     * @param functionBody        graph executed as the loop body
     * @param functionName        name under which {@code functionBody} is registered
     * @param loopVars            initial loop variables, in {@code [iteration, maxIterations, condition, extra...]} order
     * @param functionBodyInputs  input names inside {@code functionBody}; same length as {@code loopVars}
     * @param functionBodyOutputs output names inside {@code functionBody}; same length as the inputs
     * @return the loop outputs
     * @throws NullPointerException     if {@code loopVars} is {@code null}
     * @throws IllegalStateException    if the name arrays are missing or have mismatched lengths
     * @throws IllegalArgumentException if a loop variable belongs to a different {@link SameDiff}
     */
    public static SDVariable[] loopWithConditions(String[] outputVarNames, String loopName, SameDiff parent,
                                                  SameDiff functionBody, String functionName, SDVariable[] loopVars,
                                                  String[] functionBodyInputs, String[] functionBodyOutputs) {
        Preconditions.checkState(functionBodyInputs != null && functionBodyOutputs != null
                && functionBodyInputs.length == functionBodyOutputs.length,
                "Sub graph input and output names must  be defined and equal in length.");
        Objects.requireNonNull(loopVars, "loopVars must not be null");
        Preconditions.checkState(loopVars.length == functionBodyInputs.length, "Loop variables and function body inputs must be equal in length.");
        for (SDVariable loopVar : loopVars) {
            if (loopVar.getSameDiff() != parent) {
                throw new IllegalArgumentException("Variable named " + loopVar.name() + " does not have correct samediff instance. Must have parent outer samediff instance.");
            }
        }
        return parent.whileLoop(outputVarNames, loopName, loopVars, condBody(),
                loopBody(parent, functionBody, functionName, functionBodyInputs, functionBodyOutputs));
    }

    /**
     * Splits positional loop inputs into a {@link LoopLambdaArgs}: index 0 is {@code iterStart}, 1 is
     * {@code iterCount}, 2 is {@code condIn}, and the rest are {@code extraArgs}.
     *
     * @throws IllegalArgumentException if fewer than 3 inputs are given, or if the constructor's dtype
     *                                  checks fail
     */
    public static LoopLambdaArgs argsFromInputs(SDVariable[] inputs) {
        if (inputs.length < 3) {
            throw new IllegalArgumentException("Expected at least 3 loop inputs but got " + inputs.length);
        }
        SDVariable[] extraArgs = Arrays.copyOfRange(inputs, 3, inputs.length);
        return LoopLambdaArgs.builder()
                .iterCount(inputs[1])
                .iterStart(inputs[0])
                .condIn(inputs[2])
                .extraArgs(extraArgs)
                .build();
    }

    /**
     * Registers {@code functionBody} on {@code parent} as sub-function {@code functionName} and returns a loop
     * body that invokes it.
     *
     * <p>The returned lambda produces {@code [loopVars[0] + 1, loopVars[1], invoked[2], invoked[3..]]}.
     * Outputs 0 and 1 of the invoked function are not used.
     *
     * @throws IllegalStateException if the name arrays are missing or differ in length; the returned
     *                               lambda also throws this if the invocation yields fewer than 3 outputs
     */
    public static SameDiffLambda loopBody(SameDiff parent, SameDiff functionBody, String functionName,
                                          String[] functionBodyInputs, String[] functionBodyOutputs) {
        Preconditions.checkState(functionBodyInputs != null && functionBodyOutputs != null
                && functionBodyInputs.length == functionBodyOutputs.length,
                "Sub graph input and output names must  be defined and equal in length.");
        parent.putSubFunction(functionName, functionBody);
        return (sameDiff, loopVars) -> {
            SDVariable[] invoked = sameDiff.invoke(argsFromInputs(loopVars).invokeParams(functionName, functionBodyInputs, functionBodyOutputs));
            if (invoked.length < 3) {
                throw new IllegalStateException("Sub function " + functionName + " must return at least 3 outputs but returned " + invoked.length);
            }
            SDVariable[] next = new SDVariable[invoked.length];
            next[0] = loopVars[0].add(1.0d);
            next[1] = loopVars[1];
            next[2] = invoked[2];
            System.arraycopy(invoked, 3, next, 3, invoked.length - 3);
            return next;
        };
    }

    /**
     * Builds a while loop in {@code sameDiff} from Enter/Merge/Switch/Exit/NextIteration ops.
     *
     * <p>Each loop variable is entered into a frame named after the loop and merged. {@code cond} is defined
     * on the merged values in a {@code "cond"} sub-scope and must be boolean. Each merged value is then
     * switched on the condition. Output 1 feeds {@code body}, which is defined in a {@code "body"} sub-scope,
     * and output 0 goes to an Exit. While the body is defined, any argument that already existed (other than
     * the switch outputs) is wrapped once in a constant Enter for the frame. The body outputs are fed back
     * through NextIteration into the second input of each Merge.
     *
     * <p>Name scopes and the argument interceptor are always removed, even on failure.
     *
     * @param sameDiff    graph to add the loop to
     * @param outputNames names for the returned exit variables
     * @param loopName    base name for the loop frame/scope; {@code "while"} when {@code null}
     * @param loopVars    initial loop variables
     * @param cond        defines the boolean continuation condition from the merged loop variables
     * @param body        defines the next loop-variable values; must return one value per loop variable
     * @return one Exit output per loop variable, renamed to {@code outputNames}
     * @throws NullPointerException     if {@code loopVars}, {@code cond} or {@code body} is {@code null}
     * @throws IllegalStateException    if the condition is not boolean
     * @throws IllegalArgumentException if {@code body} returns a different number of outputs than there are loop variables
     */
    public static SDVariable[] whileLoop(SameDiff sameDiff, String[] outputNames, String loopName,
                                         @NonNull SDVariable[] loopVars,
                                         @NonNull SameDiffSingleLambda cond,
                                         @NonNull SameDiffLambda body) {
        if (loopVars == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (body == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }

        String frameName = sameDiff.newBlockName(loopName == null ? "while" : loopName);
        NameScope loopScope = sameDiff.withNameScope(frameName);
        SDVariable[] exits;
        try {
            SDVariable counter = sameDiff.scalar(sameDiff.generateNewVarName("counter", 0), 0);

            int count = loopVars.length;
            SDVariable[] entered = new SDVariable[count];
            for (int i = 0; i < count; i++) {
                entered[i] = new Enter(sameDiff, frameName, loopVars[i]).outputVariable();
            }

            // The second Merge input is a placeholder; it is replaced by the NextIteration output below.
            SDVariable[] merged = new SDVariable[count];
            Merge[] mergeOps = new Merge[count];
            for (int i = 0; i < count; i++) {
                mergeOps[i] = new Merge(sameDiff, entered[i], entered[i]);
                merged[i] = mergeOps[i].outputVariable();
            }
            Merge counterMerge = new Merge(sameDiff, counter, counter);
            SDVariable counterOut = counterMerge.outputVariable();

            NameScope condScope = sameDiff.withNameScope("cond");
            SDVariable condResult;
            try {
                condResult = cond.define(sameDiff, merged);
            } finally {
                condScope.close();
            }
            if (condResult.dataType() != DataType.BOOL) {
                throw new IllegalStateException("Can not use " + condResult.name() + " as the condition of an While loop, the condition must be a boolean.");
            }

            Set<String> alreadyEntered = Sets.newHashSet();
            SDVariable[] trueSwitches = new SDVariable[count];
            exits = new SDVariable[count];
            for (int i = 0; i < count; i++) {
                SDVariable[] switched = sameDiff.switchOp(merged[i], condResult);
                trueSwitches[i] = switched[1];
                alreadyEntered.add(switched[1].name());
                exits[i] = new Exit(sameDiff, switched[0]).outputVariable();
            }

            // Variables from outside the loop used in the body must be entered into the frame (as constants) once.
            Set<String> declared = Sets.newHashSet(sameDiff.variableMap().keySet());
            Map<String, SDVariable> enteredConstants = new HashMap<>();
            sameDiff.addArgumentInterceptor(argument -> {
                if (argument == null) {
                    return null;
                }
                String name = argument.name();
                if (!declared.contains(name) || alreadyEntered.contains(name)) {
                    return argument;
                }
                if (enteredConstants.containsKey(name)) {
                    return enteredConstants.get(name);
                }
                SDVariable enteredConstant = new Enter(sameDiff, frameName, argument, true).outputVariable();
                enteredConstants.put(name, enteredConstant);
                return enteredConstant;
            });

            SDVariable[] bodyOutputs;
            try {
                NameScope bodyScope = sameDiff.withNameScope("body");
                try {
                    bodyOutputs = body.define(sameDiff, trueSwitches);
                    if (bodyOutputs.length != mergeOps.length) {
                        throw new IllegalArgumentException("Number of loop variables must be equal to number of outputs.");
                    }
                } finally {
                    bodyScope.close();
                }
            } finally {
                sameDiff.removeArgumentInterceptor();
            }

            // NOTE(review): the result of this add is discarded and the counter Merge below is re-wired to its
            // own output rather than to an incremented value. Behavior is preserved as-is; confirm the intent.
            counterOut.add(1.0d);
            for (int i = 0; i < bodyOutputs.length; i++) {
                mergeOps[i].replaceArg(1, new NextIteration(sameDiff, bodyOutputs[i]).outputVariable());
            }
            counterMerge.replaceArg(1, counterOut);
        } finally {
            loopScope.close();
        }
        return sameDiff.updateVariableNamesAndReferences(exits, outputNames);
    }

    /**
     * Returns the default loop condition {@code inputs[0] < inputs[1] AND inputs[2]}.
     *
     * <p>{@code inputs[1]} is cast to the data type of {@code inputs[0]} before the comparison. The AND is a
     * bitwise AND over INT64 casts of both operands, and the result is cast back to {@link DataType#BOOL}.
     */
    public static SameDiffSingleLambda condBody() {
        return (sameDiff, inputs) -> {
            SDVariable iteration = inputs[0];
            SDVariable belowMax = iteration.lt(inputs[1].castTo(iteration.dataType()));
            return sameDiff.bitwise()
                    .and(belowMax.castTo(DataType.INT64), inputs[2].castTo(DataType.INT64))
                    .castTo(DataType.BOOL);
        };
    }
}
