package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.autodiff.samediff.config.SDValueType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.autodiff.samediff.internal.memory.HashDependencyTracker;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.VariableUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.custom.Invoke;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.CreateView;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRemove;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.wstx.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession.class */
public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp, OpContext>> {
    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(InferenceSession.class);
    /** Guidance text explaining how to deal with placeholder arrays that are attached to a closed or outdated workspace. */
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
    /** Name suffix of the input that {@code preprocessPlaceholders} fills in with the training-phase flag when the caller does not supply it. */
    protected static final String KERAS_TRAIN_TEST = "keras_learning_phase";
    /** Ids of arrays that have already been handed to {@code mmgr.release(...)}; checked before every release so an array is not released twice. */
    protected Set<Long> freedArrays;
    /** Memory manager used to allocate and release arrays for this session. */
    private SessionMemMgr mmgr;
    /** Records, per value, the {@link Dep}s that still need it; values whose dependencies are all satisfied are released. */
    private AbstractDependencyTracker<SDValue, Dep> arrayUseTracker;
    /** Op contexts held by this session. NOTE(review): presumably keyed by op name - not shown in this part of the file, confirm. */
    private Map<String, OpContext> opContexts;

    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$ConstantDep.class */
    /**
     * Dependency key registered for the array of a {@code CONSTANT} variable
     * (see {@code preprocessPlaceholders}). Two instances are equal when the
     * inherited frame information and the constant name both match.
     */
    protected static class ConstantDep extends Dep {
        protected String constName;

        public ConstantDep(String str) {
            this.constName = str;
        }

        public String getConstName() {
            return this.constName;
        }

        public void setConstName(String str) {
            this.constName = str;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ConstantDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof ConstantDep) {
                ConstantDep other = (ConstantDep) obj;
                // Both sides must agree that they are comparable, and the frame data must match first.
                if (other.canEqual(this) && super.equals(obj)) {
                    String mine = getConstName();
                    String theirs = other.getConstName();
                    return mine != null ? mine.equals(theirs) : theirs == null;
                }
            }
            return false;
        }

        @Override
        public int hashCode() {
            int result = super.hashCode();
            String name = getConstName();
            // 43 stands in for the hash of a null name.
            int nameHash = (name != null) ? name.hashCode() : 43;
            return 59 * result + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.ConstantDep(constName=" + getConstName() + ")";
        }
    }

    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$Dep.class */
    /**
     * Base type of the keys under which {@code arrayUseTracker} records why a
     * value is still needed. Carries an optional frame name and parent frame;
     * equality and hashing are defined over exactly those two properties, and
     * subclasses narrow {@link #canEqual(Object)} so that different kinds of
     * dependency never compare equal.
     */
    public static abstract class Dep {
        protected String frame;
        protected FrameIter parentFrame;

        public String getFrame() {
            return this.frame;
        }

        public void setFrame(String str) {
            this.frame = str;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setParentFrame(FrameIter frameIter) {
            this.parentFrame = frameIter;
        }

        /** Hook letting subclasses restrict which runtime types they may be equal to. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Dep;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Dep)) {
                return false;
            }
            Dep other = (Dep) obj;
            if (!other.canEqual(this)) {
                return false;
            }

            String myFrame = getFrame();
            String otherFrame = other.getFrame();
            boolean sameFrame = (myFrame != null) ? myFrame.equals(otherFrame) : otherFrame == null;
            if (!sameFrame) {
                return false;
            }

            FrameIter myParent = getParentFrame();
            FrameIter otherParent = other.getParentFrame();
            if (myParent != null) {
                return myParent.equals(otherParent);
            }
            return otherParent == null;
        }

        public int hashCode() {
            // Accumulate with multiplier 59; 43 stands in for the hash of a null property.
            int result = 1;
            String frameName = getFrame();
            result = 59 * result + (frameName != null ? frameName.hashCode() : 43);
            FrameIter parent = getParentFrame();
            result = 59 * result + (parent != null ? parent.hashCode() : 43);
            return result;
        }

        public String toString() {
            return "InferenceSession.Dep(frame=" + getFrame() + ", parentFrame=" + getParentFrame() + ")";
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$ExecDoneDep.class */
    /**
     * Dependency key that {@code postProcessOutputValues} marks satisfied after
     * execution; it adds no state beyond the inherited frame information.
     */
    public static class ExecDoneDep extends Dep {
        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ExecDoneDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ExecDoneDep)) {
                return false;
            }
            ExecDoneDep other = (ExecDoneDep) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            // No own fields: equality is decided by the inherited frame data.
            return super.equals(obj);
        }

        @Override
        public int hashCode() {
            return super.hashCode();
        }

        @Override
        public String toString() {
            return "InferenceSession.ExecDoneDep()";
        }
    }

    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$OpDep.class */
    /**
     * Dependency key identifying one execution of an op: the op name together
     * with the frame, iteration and parent frame it runs in. Equality covers
     * the inherited frame information, the iteration and the op name.
     */
    public static class OpDep extends Dep {
        protected String opName;
        protected int iter;

        /**
         * Creates a key without frame information.
         *
         * @param str op name
         * @param i   iteration number
         */
        public OpDep(String str, int i) {
            this.opName = str;
            this.iter = i;
        }

        /**
         * Creates a fully qualified key.
         *
         * @param str       op name, must not be null
         * @param str2      frame name, must not be null
         * @param i         iteration number
         * @param frameIter parent frame, may be null
         * @throws NullPointerException if the op name or the frame name is null
         */
        protected OpDep(@NonNull String str, @NonNull String str2, int i, FrameIter frameIter) {
            if (str == null) {
                throw new NullPointerException("opName is marked non-null but is null");
            }
            if (str2 == null) {
                throw new NullPointerException("frame is marked non-null but is null");
            }
            this.opName = str;
            this.iter = i;
            this.frame = str2;
            this.parentFrame = frameIter;
        }

        public String getOpName() {
            return this.opName;
        }

        public void setOpName(String str) {
            this.opName = str;
        }

        public int getIter() {
            return this.iter;
        }

        public void setIter(int i) {
            this.iter = i;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof OpDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof OpDep)) {
                return false;
            }
            OpDep other = (OpDep) obj;
            boolean comparable = other.canEqual(this) && super.equals(obj) && getIter() == other.getIter();
            if (!comparable) {
                return false;
            }
            String mine = getOpName();
            String theirs = other.getOpName();
            return mine != null ? mine.equals(theirs) : theirs == null;
        }

        @Override
        public int hashCode() {
            int result = super.hashCode();
            result = 59 * result + getIter();
            String name = getOpName();
            // 43 stands in for the hash of a null op name.
            result = 59 * result + (name != null ? name.hashCode() : 43);
            return result;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("OpDep(");
            text.append(this.opName)
                    .append(",frame=").append(this.frame)
                    .append(",iter=").append(this.iter);
            // The parent frame is only printed when there is one.
            if (this.parentFrame != null) {
                text.append(",parent=").append(this.parentFrame);
            }
            return text.append(")").toString();
        }
    }

    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$PlaceholderDep.class */
    /**
     * Dependency key registered for a placeholder array whose data type already
     * matches the variable (see {@code preprocessPlaceholders}). Equality covers
     * the inherited frame information and the placeholder name.
     */
    protected static class PlaceholderDep extends Dep {
        protected String phName;

        public PlaceholderDep(String str) {
            this.phName = str;
        }

        public String getPhName() {
            return this.phName;
        }

        public void setPhName(String str) {
            this.phName = str;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof PlaceholderDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof PlaceholderDep) {
                PlaceholderDep other = (PlaceholderDep) obj;
                if (other.canEqual(this) && super.equals(obj)) {
                    String mine = getPhName();
                    String theirs = other.getPhName();
                    return mine != null ? mine.equals(theirs) : theirs == null;
                }
            }
            return false;
        }

        @Override
        public int hashCode() {
            int result = super.hashCode();
            String name = getPhName();
            // 43 stands in for the hash of a null name.
            int nameHash = (name != null) ? name.hashCode() : 43;
            return 59 * result + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.PlaceholderDep(phName=" + getPhName() + ")";
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$ReqOutputDep.class */
    /**
     * Dependency key carrying the name of an output variable. Equality covers
     * the inherited frame information and the output name.
     * NOTE(review): presumably keeps arrays of user-requested outputs from being
     * released - confirm against the callers.
     */
    public static class ReqOutputDep extends Dep {
        protected String outputName;

        public ReqOutputDep(String str) {
            this.outputName = str;
        }

        public String getOutputName() {
            return this.outputName;
        }

        public void setOutputName(String str) {
            this.outputName = str;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ReqOutputDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof ReqOutputDep) {
                ReqOutputDep other = (ReqOutputDep) obj;
                if (other.canEqual(this) && super.equals(obj)) {
                    String mine = getOutputName();
                    String theirs = other.getOutputName();
                    return mine != null ? mine.equals(theirs) : theirs == null;
                }
            }
            return false;
        }

        @Override
        public int hashCode() {
            int result = super.hashCode();
            String name = getOutputName();
            // 43 stands in for the hash of a null name.
            int nameHash = (name != null) ? name.hashCode() : 43;
            return 59 * result + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.ReqOutputDep(outputName=" + getOutputName() + ")";
        }
    }

    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession$VariableDep.class */
    /**
     * Dependency key registered for the array of a {@code VARIABLE} variable
     * (see {@code preprocessPlaceholders}). Equality covers the inherited frame
     * information and the variable name.
     */
    protected static class VariableDep extends Dep {
        protected String varName;

        public VariableDep(String str) {
            this.varName = str;
        }

        public String getVarName() {
            return this.varName;
        }

        public void setVarName(String str) {
            this.varName = str;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof VariableDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof VariableDep) {
                VariableDep other = (VariableDep) obj;
                if (other.canEqual(this) && super.equals(obj)) {
                    String mine = getVarName();
                    String theirs = other.getVarName();
                    return mine != null ? mine.equals(theirs) : theirs == null;
                }
            }
            return false;
        }

        @Override
        public int hashCode() {
            int result = super.hashCode();
            String name = getVarName();
            // 43 stands in for the hash of a null name.
            int nameHash = (name != null) ? name.hashCode() : 43;
            return 59 * result + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.VariableDep(varName=" + getVarName() + ")";
        }
    }

    /**
     * Creates an inference session for the given graph, with empty tracking
     * state and an {@link ArrayCacheMemoryMgr} as memory manager.
     *
     * @param sameDiff graph to execute, must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        this.mmgr = new ArrayCacheMemoryMgr();
        this.opContexts = new HashMap<>();
        this.arrayUseTracker = new HashDependencyTracker<>();
        this.freedArrays = new LinkedHashSet<>();
    }

    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    /**
     * Resets array-use tracking and validates / normalizes the placeholder arrays
     * before execution.
     * <ul>
     *   <li>Registers every CONSTANT and VARIABLE array with the tracker.</li>
     *   <li>For each graph input whose name ends with {@link #KERAS_TRAIN_TEST} and that
     *       the caller did not supply, allocates a BOOL scalar holding
     *       {@code at.operation().isTrainingPhase()} and adds it to a copy of the map
     *       (the caller's map is never modified).</li>
     *   <li>Rejects placeholder arrays attached to a non-circular workspace that is
     *       closed or has been reopened since the array was created.</li>
     *   <li>Copies placeholders whose data type differs from the variable's into a
     *       newly allocated array of the variable's type.</li>
     * </ul>
     *
     * @param map placeholder name to array; may be null or empty
     * @param at  current execution position, used for the training-phase flag
     * @return the processed placeholders, or {@code map} itself when there is nothing to process
     * @throws ND4JIllegalStateException if a placeholder array uses a leaked or outdated workspace pointer
     * @throws IllegalStateException     if a placeholder name does not exist in the graph
     *                                   (presumably what {@code Preconditions.checkState} throws)
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> map, At at) {
        this.arrayUseTracker.clear();

        for (SDVariable variable : this.sameDiff.variables()) {
            VariableType type = variable.getVariableType();
            if (type == VariableType.CONSTANT) {
                this.arrayUseTracker.addDependency(SDValue.create(variable.getArr()), new ConstantDep(variable.name()));
            } else if (type == VariableType.VARIABLE) {
                this.arrayUseTracker.addDependency(SDValue.create(variable.getArr()), new VariableDep(variable.name()));
            }
        }

        boolean learningPhaseAdded = false;
        List<String> inputs = this.sameDiff.inputs();
        if (inputs != null && !inputs.isEmpty()) {
            for (String input : inputs) {
                // A null map simply means "nothing supplied"; it must not cause an NPE here.
                boolean supplied = map != null && map.containsKey(input);
                if (input.endsWith(KERAS_TRAIN_TEST) && !supplied) {
                    INDArray phase = this.mmgr.allocate(false, DataType.BOOL, new long[0]).assign(at.operation().isTrainingPhase());
                    if (!learningPhaseAdded) {
                        // Copy once so the caller's map is left untouched.
                        Map<String, INDArray> copy = new HashMap<>();
                        if (map != null) {
                            copy.putAll(map);
                        }
                        map = copy;
                        learningPhaseAdded = true;
                    }
                    map.put(input, phase);
                }
            }
        }

        if (map == null || map.isEmpty()) {
            return map;
        }

        Map<String, INDArray> processed = new HashMap<>();
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String name = entry.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", name);
            INDArray value = entry.getValue();

            if (value.isAttached()) {
                MemoryWorkspace workspace = value.data() == null ? null : value.data().getParentWorkspace();
                if (workspace != null && workspace.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {
                    if (!workspace.isScopeActive()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace ["
                                + workspace.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: "
                                + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                    long arrayGeneration = value.data().getGenerationId();
                    long workspaceGeneration = workspace.getGenerationId();
                    if (workspaceGeneration != arrayGeneration) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace ["
                                + workspace.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: "
                                + arrayGeneration + ". Workspace current iteration: " + workspaceGeneration
                                + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                }
            }

            DataType expectedType = this.sameDiff.getVariable(name).dataType();
            if (learningPhaseAdded && name.endsWith(KERAS_TRAIN_TEST)) {
                // Array was allocated by this session above: it may be released once execution is done.
                this.arrayUseTracker.addDependency(SDValue.create(value), new ExecDoneDep());
            } else if (value.dataType() == expectedType) {
                this.arrayUseTracker.addDependency(SDValue.create(value), new PlaceholderDep(name));
            } else {
                INDArray cast = this.mmgr.allocate(false, expectedType, value.shape());
                cast.assign(value);
                value = cast;
                // Track the session-allocated copy (not the caller's original array) so that the copy,
                // and only the copy, becomes releasable once execution is done.
                this.arrayUseTracker.addDependency(SDValue.create(cast), new ExecDoneDep());
            }
            processed.put(name, value);
        }
        return processed;
    }

    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    /**
     * Finishes array-use tracking after execution: marks every newly satisfied op
     * step and the {@link ExecDoneDep} as satisfied, then releases each array that
     * is no longer needed, provided it was not released before and the graph has
     * caching enabled.
     *
     * @param map output values of the execution
     * @return {@code map}, unchanged
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    protected Map<String, SDValue> postProcessOutputValues(Map<String, SDValue> map) {
        if (this.dt.hasNewAllSatisfied()) {
            for (AbstractSession.ExecStep step : this.dt.getNewAllSatisfiedList()) {
                if (step.getType() != AbstractSession.ExecType.OP) {
                    continue;
                }
                OpDep finishedOp = new OpDep(step.getName(), step.getFrameIter().getFrame(), step.getFrameIter().getIteration(), step.getFrameIter().getParentFrame());
                this.arrayUseTracker.markSatisfied(finishedOp, true);
            }
        }

        this.arrayUseTracker.markSatisfied(new ExecDoneDep(), true);
        if (!this.arrayUseTracker.hasNewAllSatisfied()) {
            return map;
        }

        for (SDValue releasable : this.arrayUseTracker.getNewAllSatisfiedList()) {
            switch (releasable.getSdValueType()) {
                case TENSOR:
                    boolean tensorAlreadyFreed = this.freedArrays.contains(Long.valueOf(releasable.getTensorValue().getId()));
                    if (!tensorAlreadyFreed && this.sameDiff.isEnableCache()) {
                        this.mmgr.release(releasable.getTensorValue());
                        this.freedArrays.add(Long.valueOf(releasable.getTensorValue().getId()));
                    }
                    break;
                case LIST:
                    for (INDArray element : releasable.getListValue()) {
                        // Null entries in a list value are skipped.
                        if (element == null) {
                            continue;
                        }
                        boolean elementAlreadyFreed = this.freedArrays.contains(Long.valueOf(element.getId()));
                        if (!elementAlreadyFreed && this.sameDiff.isEnableCache()) {
                            this.mmgr.release(element);
                            this.freedArrays.add(Long.valueOf(element.getId()));
                        }
                    }
                    break;
                default:
                    break;
            }
        }
        return map;
    }

    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    /**
     * Post-processing hook for array outputs; this session returns the given map unchanged.
     *
     * @param map output arrays
     * @return the same {@code map} instance
     */
    protected Map<String, INDArray> postProcessOutput(Map<String, INDArray> map) {
        return map;
    }

    /* renamed from: getOutputs, reason: avoid collision after fix types in other method */
    public ExecutionResult getOutputs2(Pair<SameDiffOp, OpContext> pair, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, List<Listener> list, At at, MultiDataSet multiDataSet, Set<String> set4, Map<String, SDValue> map) {
        SameDiffOp sameDiffOp = (SameDiffOp) pair.getFirst();
        at.setFrameIter(frameIter);
        if (list != null && list.size() > 0) {
            SameDiffOp sameDiffOp2 = this.sameDiff.getOps().get(sameDiffOp.getOp().getOwnName());
            for (Listener listener : list) {
                if (listener.isActive(at.operation())) {
                    listener.preOpExecution(this.sameDiff, at, sameDiffOp2, (OpContext) pair.getSecond());
                }
            }
        }
        if (this.sameDiff.isDebugMode()) {
            log.info("Executing samediff op: " + sameDiffOp.getName());
        }
        ExecutionResult doExec = doExec(sameDiffOp.getOp(), (OpContext) pair.getRight(), frameIter, set, set2, set3, map);
        List<String> outputsOfOp = sameDiffOp.getOutputsOfOp();
        if (log.isTraceEnabled()) {
            StringBuilder sb = new StringBuilder();
            sb.append(sameDiffOp.getName()).append(" - ").append(frameIter).append(" outputs: ");
            for (int i = 0; i < doExec.numResults(); i++) {
                if (i > 0) {
                    sb.append(", ");
                }
                if (doExec.hasSingle()) {
                    sb.append("(").append(i).append(" - ").append(outputsOfOp.get(i)).append(" = ").append(doExec.resultAt(i) == null ? null : Long.valueOf(doExec.resultAt(i).getId())).append(")");
                } else if (doExec.hasValues()) {
                    SDValue valueWithKeyAtIndex = doExec.valueWithKeyAtIndex(i, false);
                    sb.append("(").append(i).append(" - ").append(outputsOfOp.get(i)).append(" = ").append(valueWithKeyAtIndex == null ? null : (valueWithKeyAtIndex == null || valueWithKeyAtIndex.getSdValueType() != SDValueType.LIST) ? valueWithKeyAtIndex != null ? String.valueOf(valueWithKeyAtIndex.getTensorValue().getId()) : null : StringUtil.concatEntries((Collection) valueWithKeyAtIndex.getListValue().stream().map(iNDArray -> {
                        return Long.valueOf(iNDArray.getId());
                    }).collect(Collectors.toList()), ",", ",")).append(")");
                }
            }
            log.trace(sb.toString());
        }
        if (list != null && list.size() > 0) {
            Map map2 = null;
            for (Listener listener2 : list) {
                if (listener2.isActive(at.operation())) {
                    if (map2 == null) {
                        HashMap hashMap = new HashMap();
                        for (int i2 = 0; i2 < doExec.numResults(); i2++) {
                            hashMap.put(sameDiffOp.outputsOfOp.get(i2), doExec.resultAt(i2));
                        }
                        map2 = Collections.unmodifiableMap(hashMap);
                    }
                    listener2.opExecution(this.sameDiff, at, multiDataSet, sameDiffOp, (OpContext) pair.getSecond(), doExec.outputsToArray(outputsOfOp));
                    for (String str : map2.keySet()) {
                        listener2.activationAvailable(this.sameDiff, at, multiDataSet, sameDiffOp, str, (INDArray) map2.get(str));
                    }
                }
            }
        }
        sameDiffOp.getOp().clearArrays();
        if (pair.getSecond() != null) {
            ((OpContext) pair.getSecond()).purge();
        }
        SameDiffOp sameDiffOp3 = this.sameDiff.getOps().get(sameDiffOp.getName());
        List<String> outputsOfOp2 = sameDiffOp3.getOutputsOfOp();
        for (int i3 = 0; i3 < doExec.numResults(); i3++) {
            if ((!doExec.hasSingle() || doExec.resultAt(i3) != null) && (!doExec.hasValues() || doExec.valueWithKeyAtIndex(i3, false) != null || !(sameDiffOp3.getOp() instanceof Switch))) {
                String str2 = outputsOfOp2.get(i3);
                List<String> inputsForOp = ((Variable) this.sameDiff.getVariables().get(str2)).getInputsForOp();
                if (inputsForOp != null) {
                    for (String str3 : inputsForOp) {
                        if (this.subgraphOps.contains(str3)) {
                            SameDiffOp sameDiffOp4 = this.sameDiff.getOps().get(str3);
                            if (sameDiffOp4.getOp() instanceof Enter) {
                                Enter enter = (Enter) sameDiffOp4.getOp();
                                if (enter.isConstant()) {
                                    addToArrayTracker(doExec, i3, new ExecDoneDep());
                                } else {
                                    addToArrayTracker(doExec, i3, new OpDep(str3, enter.getFrameName(), 0, frameIter));
                                }
                            } else if (sameDiffOp4.getOp() instanceof NextIteration) {
                                addToArrayTracker(doExec, i3, new OpDep(str3, frameIter.getFrame(), frameIter.getIteration() + 1, frameIter.getParentFrame()));
                            } else if (sameDiffOp4.getOp() instanceof Exit) {
                                FrameIter parentFrame = frameIter.getParentFrame();
                                addToArrayTracker(doExec, i3, new OpDep(str3, parentFrame.getFrame(), parentFrame.getIteration(), parentFrame.getParentFrame()));
                            } else {
                                addToArrayTracker(doExec, i3, new OpDep(str3, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()));
                            }
                        }
                    }
                }
                if (AbstractSession.OUTER_FRAME.equals(frameIter.getFrame()) && set4.contains(str2)) {
                    addToArrayTracker(doExec, i3, new ReqOutputDep(str2));
                } else if ((inputsForOp == null || inputsForOp.isEmpty()) && doExec.getValueOutputs() != null && !this.arrayUseTracker.hasDependency(doExec.valueWithKeyAtIndex(i3, false))) {
                    SDValue valueWithKeyAtIndex2 = doExec.valueWithKeyAtIndex(i3, false);
                    if (log.isTraceEnabled() && valueWithKeyAtIndex2 != null && valueWithKeyAtIndex2.getTensorValue() != null) {
                        log.trace("Found array id {} (output of {}) not required anywhere, deallocating", Long.valueOf(valueWithKeyAtIndex2.getTensorValue().getId()), sameDiffOp3.getName());
                    }
                    if (valueWithKeyAtIndex2 != null && valueWithKeyAtIndex2.getTensorValue() != null && !this.freedArrays.contains(Long.valueOf(valueWithKeyAtIndex2.getTensorValue().getId()))) {
                        this.mmgr.release(valueWithKeyAtIndex2.getTensorValue());
                        this.freedArrays.add(Long.valueOf(valueWithKeyAtIndex2.getTensorValue().getId()));
                    }
                } else if ((inputsForOp == null || inputsForOp.isEmpty()) && doExec.getOutputs() != null && !this.arrayUseTracker.hasDependency(SDValue.create(doExec.resultAt(i3)))) {
                    INDArray resultAt = doExec.resultAt(i3);
                    if (log.isTraceEnabled() && resultAt != null && resultAt != null) {
                        log.trace("Found array id {} (output of {}) not required anywhere, deallocating", Long.valueOf(resultAt.getId()), sameDiffOp3.getName());
                    }
                    if (resultAt != null && !this.freedArrays.contains(Long.valueOf(resultAt.getId()))) {
                        this.mmgr.release(resultAt);
                        this.freedArrays.add(Long.valueOf(resultAt.getId()));
                    }
                }
            }
        }
        this.arrayUseTracker.markSatisfied(new OpDep(sameDiffOp.getName(), frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()), true);
        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            for (SDValue sDValue : this.arrayUseTracker.getNewAllSatisfiedList()) {
                if (log.isTraceEnabled() && sDValue.getSdValueType() == SDValueType.TENSOR) {
                    INDArray tensorValue = sDValue.getTensorValue();
                    log.trace("Closing array... id={}, {}", Long.valueOf(tensorValue.getId()), tensorValue.shapeInfoToString());
                }
                if (!(sameDiffOp.getOp() instanceof Switch)) {
                    switch (sDValue.getSdValueType()) {
                        case LIST:
                            for (INDArray iNDArray2 : sDValue.getListValue()) {
                                if (iNDArray2 != null && !this.freedArrays.contains(Long.valueOf(iNDArray2.getId())) && this.sameDiff.isEnableCache()) {
                                    this.mmgr.release(iNDArray2);
                                    this.freedArrays.add(Long.valueOf(iNDArray2.getId()));
                                }
                            }
                            break;
                        case TENSOR:
                            if (!this.freedArrays.contains(Long.valueOf(sDValue.getTensorValue().getId())) && this.sameDiff.isEnableCache()) {
                                this.mmgr.release(sDValue.getTensorValue());
                                this.freedArrays.add(Long.valueOf(sDValue.getTensorValue().getId()));
                                break;
                            }
                            break;
                    }
                }
            }
        }
        return doExec;
    }

    /**
     * Registers {@code dep} as a dependency of the value found at position {@code i}
     * of the given execution result.
     *
     * @param executionResult result whose {@code i}-th entry is being tracked
     * @param i               index of the entry within the result
     * @param dep             dependency to record against that entry
     */
    private void addToArrayTracker(ExecutionResult executionResult, int i, Dep dep) {
        final SDValue tracked;
        if (!executionResult.hasSingle()) {
            // Keyed value outputs: use the stored value as-is.
            tracked = executionResult.valueWithKeyAtIndex(i, false);
        } else {
            // hasSingle() results: wrap whatever resultOrValueAt yields in an SDValue.
            tracked = SDValue.create(executionResult.resultOrValueAt(i, false));
        }
        this.arrayUseTracker.addDependency(tracked, dep);
    }

    /**
     * Executes a single op and returns its outputs, dispatching on the concrete op type.
     *
     * <p>Control-flow style ops (Identity, Switch, Enter, Exit, NextIteration, Merge, LoopCond)
     * forward previously stored values instead of running a kernel. Tensor array ops are
     * delegated to {@link #getOutputsHelperTensorArrayOps}. Assign, CreateView,
     * GradientBackwardsMarker, ExternalErrorsFunction, Invoke and Assert get dedicated
     * handling; any other {@code CustomOp} or {@code Op} is run through {@code Nd4j.exec}.
     *
     * @param differentialFunction op to execute
     * @param opContext            context handed to {@code Nd4j.exec}; also used to read input arrays
     * @param frameIter            frame/iteration used to turn argument names into {@code VarId}s
     * @param set                  {@code VarId} inputs of the op; iterated directly by the Assign,
     *                             CreateView and Invoke branches
     * @param set2                 additional {@code VarId} inputs; when non-empty its first element is
     *                             preferred over {@code set} for Enter, Exit and NextIteration
     * @param set3                 variable names that are resolved in the outer frame at iteration 0
     * @param map                  values keyed by variable name; consulted for Invoke inputs and
     *                             passed on to the tensor array helper
     * @return the outputs produced (or forwarded) by the op
     * @throws IllegalStateException         if a required value/array is missing or has an
     *                                       unexpected type, or if an Assert op fails
     * @throws IllegalArgumentException      if an Invoke op's inputs do not cover all of its arguments
     * @throws UnsupportedOperationException if the op is none of the supported kinds
     */
    public ExecutionResult doExec(DifferentialFunction differentialFunction, OpContext opContext, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, Map<String, SDValue> map) {
        int totalInputs = (set == null ? 0 : set.size()) + (set3 == null ? 0 : set3.size()) + (set2 == null ? 0 : set2.size());

        if (differentialFunction instanceof Identity) {
            String[] argNames = differentialFunction.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", argNames);
            AbstractSession.VarId inputId = frameIter.toVarId(argNames[0]);
            return ExecutionResult.createValue(inputId.getVariable(), getSdValue(inputId));
        }

        if (differentialFunction instanceof Switch) {
            String[] argNames = differentialFunction.argNames();
            AbstractSession.VarId predicateId = frameIter.toVarId(argNames[1]);
            SDValue predicateValue = getSdValue(predicateId);
            // The predicate may not be stored as a node value; guard so the outer-frame fallback below is reachable.
            INDArray predicate = null;
            if (predicateValue != null) {
                predicate = predicateValue.getSdValueType() == SDValueType.LIST ? predicateValue.getListValue().get(0) : predicateValue.getTensorValue();
            }
            if (predicate != null && predicate.isEmpty()) {
                predicate = Nd4j.scalar(false);
            }
            if (predicate == null && set3 != null && set3.contains(argNames[1])) {
                predicate = getTensorFromOutputs(new AbstractSession.VarId(argNames[1], AbstractSession.OUTER_FRAME, 0, null));
            }
            Preconditions.checkNotNull(predicate, "Error during graph execution: Predicate array was null. VarId=%s", predicateId);
            Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate);

            AbstractSession.VarId dataId = frameIter.toVarId(argNames[0]);
            SDValue dataValue = getSdValue(dataId);
            LinkedHashMap<String, SDValue> outputs = new LinkedHashMap<>();
            ExecutionResult.ExecutionResultBuilder builder = ExecutionResult.builder().valueOutputs(outputs);
            boolean sameVariable = dataId.getVariable().equals(predicateId.getVariable());
            if (predicate.getDouble(0L) == 0.0d) {
                // Predicate is false: the data value goes to the first output slot.
                if (sameVariable) {
                    List<INDArray> branches = Arrays.asList(dataValue.getTensorValue(), null);
                    SDValue both = SDValue.create(branches);
                    outputs.put(predicateId.getVariable(), both);
                    putNodeValue(both, dataId);
                    putNodeValue(both, new AbstractSession.VarId(dataId.getVariable() + ":1", dataId.getFrame(), dataId.getIteration(), dataId.getParentFrame()));
                } else {
                    outputs.put(dataId.getVariable(), dataValue);
                    outputs.put(predicateId.getVariable(), null);
                }
            } else if (sameVariable) {
                List<INDArray> branches = Arrays.asList(null, dataValue.getTensorValue());
                SDValue both = SDValue.create(branches);
                outputs.put(predicateId.getVariable(), both);
                outputs.put(predicateId.getVariable() + ":1", both);
            } else {
                // Predicate is true: the data value goes to the second output slot.
                outputs.put(dataId.getVariable(), null);
                outputs.put(predicateId.getVariable(), dataValue);
            }
            return builder.build();
        }

        if (differentialFunction instanceof Enter) {
            String[] argNames = differentialFunction.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name for enter op: got %s", argNames);
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", differentialFunction.getOwnName(), set, set3);
            AbstractSession.VarId inputId = doExecSingleInput(set, set2, set3);
            inputId.setVariable(VariableUtils.stripVarSuffix(inputId.getVariable()));
            if (this.nodeValueOutputs.containsKey(inputId)) {
                SDValue stored = getSdValue(inputId);
                if (stored != null && stored.getSdValueType() == SDValueType.LIST) {
                    return ExecutionResult.createValue(inputId.getVariable(), stored);
                }
                if (stored == null || stored.getSdValueType() != SDValueType.TENSOR) {
                    // Null-safe message: previously this line itself threw a NullPointerException for a null value.
                    throw new IllegalStateException("Illegal value type " + (stored == null ? null : stored.getSdValueType()) + " for input " + inputId);
                }
            }
            INDArray entered = doExecRequireTensor("Enter", differentialFunction, frameIter, inputId);
            return ExecutionResult.createFrom((List<String>) Arrays.asList(inputId.getVariable()), new INDArray[]{entered});
        }

        if (differentialFunction instanceof Exit) {
            AbstractSession.VarId inputId = doExecSingleInput(set, set2, set3);
            return ExecutionResult.createValue(inputId.getVariable(), getSdValue(inputId));
        }

        if (differentialFunction instanceof NextIteration) {
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", set, set3);
            AbstractSession.VarId inputId = (set2 == null || set2.isEmpty()) ? set.iterator().next() : set2.iterator().next();
            Preconditions.checkState(frameIter.getFrame().equals(inputId.getFrame()), "Expected same frame for NextIteration input vs. output: got input %s, output %s", inputId, frameIter);
            Preconditions.checkState(frameIter.getIteration() == inputId.getIteration() + 1, "Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", inputId, frameIter);
            SDValue stored = this.nodeValueOutputs.containsKey(inputId) ? getSdValue(inputId) : null;
            if (stored != null) {
                if (stored.getSdValueType() == SDValueType.LIST) {
                    return ExecutionResult.createValue(inputId.getVariable(), stored);
                }
                if (stored.getSdValueType() != SDValueType.TENSOR) {
                    throw new IllegalStateException("Illegal value type " + stored.getSdValueType() + " for input " + inputId);
                }
            }
            INDArray carried = doExecRequireTensor("NextIteration", differentialFunction, frameIter, inputId);
            return ExecutionResult.createFrom((List<String>) Arrays.asList(inputId.getVariable()), new INDArray[]{carried});
        }

        if (differentialFunction instanceof Merge) {
            String[] mergeInputs = this.sameDiff.getInputsForOp(differentialFunction);
            AbstractSession.VarId firstId = frameIter.toVarId(mergeInputs[0]);
            AbstractSession.VarId secondId = frameIter.toVarId(mergeInputs[1]);
            SDValue firstValue = getSdValue(firstId);
            SDValue secondValue = getSdValue(secondId);
            if (firstValue == null && secondValue == null) {
                throw new IllegalStateException("Merge node " + differentialFunction.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(mergeInputs) + ") - should not be executed at this point");
            }
            // The second input wins whenever it is available.
            boolean useSecond = secondValue != null;
            String chosenName = useSecond ? mergeInputs[1] : mergeInputs[0];
            AbstractSession.VarId chosenId = useSecond ? secondId : firstId;
            log.trace("Returning input \"{}\" for merge node \"{}\"", chosenName, differentialFunction.getOwnName());
            SDValue chosen = getSdValue(chosenId);
            if (chosen.getSdValueType() == SDValueType.LIST) {
                return ExecutionResult.createValue(chosenId.getVariable(), chosen);
            }
            if (chosen.getSdValueType() != SDValueType.TENSOR) {
                throw new IllegalStateException("Illegal value type " + chosen.getSdValueType() + " for input " + Arrays.toString(mergeInputs));
            }
            INDArray merged = doExecRequireTensor("Merge", differentialFunction, frameIter, chosenId);
            return ExecutionResult.createFrom((List<String>) Arrays.asList(chosenId.getVariable()), new INDArray[]{merged});
        }

        if (differentialFunction instanceof LoopCond) {
            String[] argNames = differentialFunction.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", argNames);
            AbstractSession.VarId inputId = frameIter.toVarId(argNames[0]);
            SDValue input = getSdValue(inputId);
            // Check for null before the first dereference so the intended message is the one reported.
            Preconditions.checkNotNull(input, "Input to LoopCond op must not be null");
            INDArray condition = input.getTensorValue();
            if (condition == null) {
                throw new IllegalStateException("Node value output at " + inputId.getVariable() + " was not a boolean tensor!");
            }
            Preconditions.checkState(condition.isScalar() && condition.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape", condition);
            return ExecutionResult.createValue(inputId.getVariable(), input);
        }

        if (differentialFunction instanceof BaseTensorOp) {
            return getOutputsHelperTensorArrayOps(differentialFunction, frameIter, set, set2, map);
        }

        if (differentialFunction instanceof Assign) {
            List<AbstractSession.VarId> inputIds = new ArrayList<>(set);
            if (inputIds.size() > 1) {
                SDValue target = getSdValue(inputIds.get(0));
                SDValue source = getSdValue(inputIds.get(1));
                switch (target.getSdValueType()) {
                    case LIST:
                        // Assigning to a list simply forwards the second input's value.
                        return ExecutionResult.createValue(differentialFunction.outputVariablesNames()[0], source);
                    case TENSOR:
                        Assign assign = (Assign) differentialFunction;
                        Nd4j.exec(assign, opContext);
                        return ExecutionResult.createFrom(assign, opContext);
                    default:
                        break;
                }
            }
            return ExecutionResult.createValue(differentialFunction.outputVariablesNames()[0], getSdValue(inputIds.get(0)));
        }

        if (differentialFunction instanceof GradientBackwardsMarker) {
            INDArray one = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(Float.valueOf(1.0f));
            return ExecutionResult.createFrom((List<String>) Arrays.asList("gradientbackwardsmarker"), new INDArray[]{one});
        }

        if (differentialFunction instanceof CreateView) {
            LinkedHashMap<String, AbstractSession.VarId> idsByName = new LinkedHashMap<>();
            String[] argNames = differentialFunction.argNames();
            for (AbstractSession.VarId id : set) {
                idsByName.put(id.getVariable(), id);
            }
            SDValue base = getSdValue(idsByName.get(argNames[0]));
            if (base == null) {
                base = SDValue.create(opContext.getInputArray(0));
            }
            // Every argument after the first is passed to CreateView.createFrom alongside the base array.
            INDArray[] viewArgs = new INDArray[argNames.length - 1];
            for (int i = 1; i < argNames.length; i++) {
                viewArgs[i - 1] = getSdValue(idsByName.get(argNames[i])).getTensorValue();
            }
            return ExecutionResult.createFrom(differentialFunction.outputVariablesNames()[0], CreateView.createFrom(base.getTensorValue(), viewArgs));
        }

        if (differentialFunction instanceof ExternalErrorsFunction) {
            String gradPlaceholderName = ((ExternalErrorsFunction) differentialFunction).getGradPlaceholderName();
            INDArray externalGrad = getTensorFromOutputs(new AbstractSession.VarId(gradPlaceholderName, AbstractSession.OUTER_FRAME, 0, null));
            // Report the placeholder name; the array itself is null whenever this check fails.
            Preconditions.checkState(externalGrad != null, "Could not find external errors placeholder array: %s", gradPlaceholderName);
            INDArray copy = this.mmgr.allocate(false, externalGrad.dataType(), externalGrad.shape());
            copy.assign(externalGrad);
            return ExecutionResult.createFrom((List<String>) Arrays.asList(gradPlaceholderName), new INDArray[]{copy});
        }

        if (differentialFunction instanceof Invoke) {
            Invoke invoke = (Invoke) differentialFunction;
            // Decide whether inputs are passed as SDValues: true if any input already has a stored
            // node value, or if any key of `map` is one of the outer-frame names.
            boolean useValueInputs = false;
            for (AbstractSession.VarId id : set) {
                if (this.nodeValueOutputs.containsKey(id)) {
                    useValueInputs = true;
                    break;
                }
            }
            if (!useValueInputs) {
                for (String key : map.keySet()) {
                    if (set3.contains(key)) {
                        useValueInputs = true;
                        break;
                    }
                }
            }

            LinkedHashMap<String, INDArray> arrayInputs = new LinkedHashMap<>();
            LinkedHashMap<String, SDValue> valueInputs = new LinkedHashMap<>();
            if (useValueInputs) {
                Map<String, AbstractSession.VarId> idsByName = new HashMap<>();
                for (AbstractSession.VarId id : set) {
                    idsByName.put(id.getVariable(), id);
                }
                String[] inputNames = invoke.getInputVarNames();
                for (int i = 0; i < inputNames.length; i++) {
                    String inputName = inputNames[i];
                    AbstractSession.VarId inputId = idsByName.get(inputName);
                    if (set3.contains(inputName)) {
                        if (map.containsKey(inputName)) {
                            valueInputs.put(inputName, map.get(inputName));
                        } else if (arrayInputs.containsKey(inputName)) {
                            valueInputs.put(inputName, SDValue.create(arrayInputs.get(inputName)));
                        }
                    } else {
                        INDArray graphArray = this.sameDiff.getArrForVarName(inputName);
                        if (graphArray != null) {
                            valueInputs.put(inputName, SDValue.create(graphArray));
                        } else if (this.nodeValueOutputs.containsKey(inputId)) {
                            valueInputs.put(inputId.getVariable(), getSdValue(inputId));
                        } else {
                            valueInputs.put(inputId.getVariable(), SDValue.create(opContext.getInputArray(i)));
                        }
                    }
                }
            } else {
                // Plain array inputs, taken from the op context in the iteration order of `set`.
                int position = 0;
                for (AbstractSession.VarId id : set) {
                    arrayInputs.put(id.getVariable(), opContext.getInputArray(position));
                    position++;
                }
            }
            if (valueInputs.size() + arrayInputs.size() != differentialFunction.args().length) {
                throw new IllegalArgumentException("Value inputs and inputs combined did not fulfill all arguments. Inputs were: " + Arrays.toString(differentialFunction.argNames()) + " for op name " + differentialFunction.getOwnName());
            }
            return Invoke.doInvoke(invoke, arrayInputs, valueInputs);
        }

        if (differentialFunction instanceof Assert) {
            Assert assertOp = (Assert) differentialFunction;
            if (opContext.getInputArray(0).getDouble(0L) != 0.0d) {
                return ExecutionResult.createFrom(assertOp, opContext);
            }
            String message = "Assertion failed for operation \"" + differentialFunction.getOwnName() + "\" during execution";
            if (assertOp.numInputArguments() >= 3) {
                INDArray detail = opContext.getInputArray(2);
                if (detail != null && detail.dataType() == DataType.UTF8) {
                    message = message + ": " + detail.getString(0L);
                }
            }
            if (assertOp.numInputArguments() >= 5) {
                message = message + "\n" + opContext.getInputArray(4);
            }
            throw new IllegalStateException(message);
        }

        if (differentialFunction instanceof CustomOp) {
            Nd4j.exec((CustomOp) differentialFunction, opContext);
            return ExecutionResult.createFrom(differentialFunction, opContext);
        }

        if (differentialFunction instanceof Op) {
            Nd4j.exec((Op) differentialFunction, opContext);
            return ExecutionResult.createFrom(differentialFunction, opContext);
        }

        throw new UnsupportedOperationException("Execution not yet implemented for: " + differentialFunction.getClass().getName());
    }

    /**
     * Picks the single input of an Enter/Exit op: the first outer-frame name when there are no
     * {@code VarId} inputs at all, otherwise the first element of {@code set2} if it is non-empty,
     * otherwise the first element of {@code set}.
     */
    private AbstractSession.VarId doExecSingleInput(Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3) {
        boolean noVarIdInputs = (set == null || set.isEmpty()) && (set2 == null || set2.isEmpty());
        if (noVarIdInputs) {
            return new AbstractSession.VarId(set3.iterator().next(), AbstractSession.OUTER_FRAME, 0, null);
        }
        if (set2 == null || set2.isEmpty()) {
            return set.iterator().next();
        }
        return set2.iterator().next();
    }

    /**
     * Looks up the tensor for {@code inputId} and raises a state exception naming the op kind
     * when no array can be found.
     */
    private INDArray doExecRequireTensor(String opKind, DifferentialFunction differentialFunction, FrameIter frameIter, AbstractSession.VarId inputId) {
        INDArray tensor = getTensorFromOutputs(inputId);
        if (tensor == null) {
            Preconditions.throwStateEx("Could not find array for " + opKind + " operation %s with output %s (frame=%s, iteration=%s)", new Object[]{differentialFunction.getOwnName(), this.sameDiff.getOps().get(differentialFunction.getOwnName()).getOutputsOfOp().get(0), frameIter.getFrame(), Integer.valueOf(frameIter.getIteration())});
        }
        return tensor;
    }

    /**
     * Returns the value stored for the same variable and frame one iteration before {@code varId}.
     *
     * @param varId identifier whose immediately preceding iteration is looked up
     * @return the stored value, as returned by the two-argument overload
     */
    private SDValue getPreviousValue(AbstractSession.VarId varId) {
        final int oneIterationBack = 1;
        return getPreviousValue(varId, oneIterationBack);
    }

    /**
     * Returns the value stored for the same variable, frame and parent frame as {@code varId},
     * but {@code i} iterations earlier.
     *
     * @param varId identifier to step back from
     * @param i     number of iterations to subtract from {@code varId}'s iteration
     * @return whatever {@code nodeValueOutputs} holds for the earlier identifier
     */
    private SDValue getPreviousValue(AbstractSession.VarId varId, int i) {
        int earlierIteration = varId.getIteration() - i;
        AbstractSession.VarId earlierId = new AbstractSession.VarId(varId.getVariable(), varId.getFrame(), earlierIteration, varId.getParentFrame());
        return this.nodeValueOutputs.get(earlierId);
    }

    /**
     * Looks up the value stored for a variable at an explicit frame and iteration.
     *
     * @param str       variable name
     * @param str2      frame name
     * @param i         iteration number
     * @param frameIter parent frame of the identifier being looked up
     * @return whatever {@code nodeValueOutputs} holds for that identifier
     */
    private SDValue getValueAtIteration(String str, String str2, int i, FrameIter frameIter) {
        AbstractSession.VarId key = new AbstractSession.VarId(str, str2, i, frameIter);
        return this.nodeValueOutputs.get(key);
    }

    public ExecutionResult getOutputsHelperTensorArrayOps(DifferentialFunction differentialFunction, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Map<String, SDValue> map) {
        if (differentialFunction instanceof TensorArray) {
            AbstractSession.VarId varId = frameIter.toVarId(differentialFunction.outputVariable().name());
            if (this.nodeValueOutputs.containsKey(varId)) {
                return ExecutionResult.createValue(varId.getVariable(), this.nodeValueOutputs.get(varId));
            }
            Preconditions.checkState(!this.nodeValueOutputs.containsKey(varId), "TensorArray already exists for %s when executing TensorArrayV3", varId);
            ArrayList arrayList = new ArrayList();
            if (differentialFunction.args().length > 0) {
                INDArray arr = differentialFunction.arg(0).getArr();
                TensorArray tensorArray = (TensorArray) differentialFunction;
                long[] requiredShape = tensorArray.args().length > 1 ? tensorArray.requiredShape() : null;
                for (int i = 0; i < arr.getInt(0); i++) {
                    arrayList.add(null);
                }
            }
            SDValue create = SDValue.create((List<INDArray>) arrayList);
            putNodeValue(create, varId);
            return ExecutionResult.createValue(varId.getVariable(), create);
        }
        if (differentialFunction instanceof TensorArrayRead) {
            INDArray array = getArray(differentialFunction.arg(1), set, set2);
            Preconditions.checkState(array.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", array);
            int i2 = array.getInt(0);
            SDVariable arg = differentialFunction.arg(0);
            AbstractSession.VarId lookup = set == null ? null : lookup(arg.name(), set, false);
            if (lookup == null && set2 != null) {
                lookup = lookup(arg.name(), set2, false);
            }
            Preconditions.checkState(lookup != null, "Could not find input %s", arg.name());
            TensorArray tensorArray2 = TensorArray.getTensorArray(this.sameDiff, arg);
            List<INDArray> tensorArraysInSession = !this.nodeValueOutputs.containsKey(lookup) ? getTensorArraysInSession(TensorArray.getTensorArray(this.sameDiff, arg).getVar().name()) : getSdValue(lookup).getListValue();
            if (tensorArray2.args().length > 1) {
                long[] requiredShape2 = tensorArray2.requiredShape();
                for (int i3 = 0; i3 < tensorArraysInSession.size(); i3++) {
                    if (tensorArraysInSession.get(i3) != null && !Arrays.equals(requiredShape2, tensorArraysInSession.get(i3).shape())) {
                        throw new IllegalArgumentException("Element " + i3 + " of list " + lookup.getVariable() + " did not have correct shape of " + Arrays.toString(requiredShape2) + " was shape " + Arrays.toString(tensorArraysInSession.get(i3).shape()));
                    }
                }
            }
            Preconditions.checkState(tensorArraysInSession != null, "Could not find TensorList for %s", lookup);
            Preconditions.checkState(tensorArraysInSession.size() > i2, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", Integer.valueOf(i2), Integer.valueOf(tensorArraysInSession.size()), lookup);
            INDArray iNDArray = tensorArraysInSession.get(i2);
            log.trace("Reading item at index " + i2 + " for list " + lookup + " with value " + iNDArray + " with list of " + tensorArraysInSession);
            return ExecutionResult.createFrom(lookup.getVariable(), iNDArray);
        }
        if (differentialFunction instanceof TensorArrayWrite) {
            SDVariable arg2 = differentialFunction.arg(0);
            AbstractSession.VarId lookup2 = set == null ? null : lookup(arg2.name(), set, false);
            if (lookup2 == null && set2 != null) {
                lookup2 = lookup(arg2.name(), set2, false);
            }
            if (lookup2 == null && arg2.getVariableType() == VariableType.PLACEHOLDER) {
                lookup2 = new AbstractSession.VarId(arg2.name(), frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame());
                map.get(arg2.name());
            }
            Preconditions.checkState(lookup2 != null, "Could not find input %s", arg2.name());
            INDArray array2 = getArray(this.sameDiff.getVariable(differentialFunction.arg(1).name()), set, set2);
            Preconditions.checkState(array2.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", array2);
            int i4 = array2.getInt(0);
            String name = differentialFunction.arg(2).name();
            SDVariable variable = this.sameDiff.getVariable(name);
            INDArray array3 = getArray(variable, set, set2);
            Preconditions.checkState(array3 != null, "Could not find array for %s", name);
            TensorArray tensorArray3 = TensorArray.getTensorArray(this.sameDiff, arg2);
            AbstractSession.VarId varId2 = new AbstractSession.VarId(tensorArray3.outputVariable().name(), AbstractSession.OUTER_FRAME, 0, null);
            if (tensorArray3.args().length > 1) {
                long[] longVector = tensorArray3.arg(1).getArr().toLongVector();
                if (!Arrays.equals(array3.shape(), longVector)) {
                    throw new IllegalArgumentException("Unable to write array of shape " + Arrays.toString(array3.shape()) + " must be " + longVector + " for op " + differentialFunction.getOwnName() + " and tensor array " + tensorArray3.getOwnName());
                }
            }
            Preconditions.checkState(this.nodeValueOutputs.containsKey(varId2), "Tensor array does not exist for %s", varId2);
            SDValue sdValue = getSdValue(varId2);
            List<INDArray> listValue = sdValue.getListValue();
            if (i4 < 0 && listValue != null && !listValue.isEmpty()) {
                i4 += listValue.size() + 1;
            } else if (i4 < 0) {
                i4 = 0;
            }
            while (listValue.size() <= i4) {
                listValue.add(null);
            }
            setArrayAtIndex(listValue, i4, array3);
            log.trace("Setting item at index " + i4 + " for list " + varId2 + " with value " + array3 + " with whole list of after write " + listValue + " and value array " + array3);
            log.trace("Writing value " + variable + " to list " + varId2.getVariable() + " at iteration " + varId2.getIteration());
            this.arrayUseTracker.addDependency(sdValue, new ExecDoneDep());
            return ExecutionResult.createValue(differentialFunction.outputVariable().name(), sdValue);
        }
        if (differentialFunction instanceof TensorArraySize) {
            SDVariable arg3 = differentialFunction.arg(0);
            TensorArray tensorArray4 = TensorArray.getTensorArray(this.sameDiff, arg3);
            AbstractSession.VarId lookup3 = set == null ? null : lookup(arg3.name(), set, false);
            if (lookup3 == null && set2 != null) {
                lookup3 = lookup(arg3.name(), set2, false);
            }
            List<INDArray> listValue2 = getSdValue(lookup3).getListValue();
            return ExecutionResult.createFrom(tensorArray4.getVar().name(), this.mmgr.allocate(false, DataType.INT, new long[0]).assign(Integer.valueOf(listValue2 == null ? 0 : listValue2.size())));
        }
        if (differentialFunction instanceof TensorArrayConcat) {
            SDVariable arg4 = differentialFunction.arg(0);
            AbstractSession.VarId lookup4 = set == null ? null : lookup(arg4.name(), set, false);
            if (lookup4 == null && set2 != null) {
                lookup4 = lookup(arg4.name(), set2, false);
            }
            Concat concat = new Concat(0, (INDArray[]) ((List) getSdValue(lookup4).getListValue().stream().filter(iNDArray2 -> {
                return iNDArray2 != null;
            }).collect(Collectors.toList())).toArray(new INDArray[0]));
            INDArray allocate = this.mmgr.allocate(false, concat.calculateOutputShape().get(0));
            concat.setOutputArgument(0, allocate);
            Nd4j.exec(concat);
            return ExecutionResult.createFrom(lookup4.getVariable(), allocate);
        }
        if (differentialFunction instanceof TensorArrayGather) {
            SDVariable arg5 = differentialFunction.arg(0);
            AbstractSession.VarId lookup5 = set == null ? null : lookup(arg5.name(), set, false);
            if (lookup5 == null && set2 != null) {
                lookup5 = lookup(arg5.name(), set2, false);
            }
            List<INDArray> listValue3 = getSdValue(lookup5).getListValue();
            Preconditions.checkState(listValue3 != null, "Could not find TensorArray: %s", lookup5);
            String name2 = differentialFunction.arg(1).name();
            INDArray arr2 = this.sameDiff.getVariable(name2).getArr();
            Preconditions.checkState(arr2.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", arr2, name2);
            Preconditions.checkState(arr2.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", arr2.dataType(), name2);
            int[] intVector = arr2.toIntVector();
            log.trace("Gathering op " + differentialFunction.getOwnName() + " from indices " + Arrays.toString(intVector) + " named " + name2 + " from list " + lookup5.getVariable());
            if (intVector.length <= 0) {
                return ExecutionResult.createFrom(lookup5.getVariable(), Nd4j.zeros(differentialFunction.arg().dataType(), 0));
            }
            ArrayList arrayList2 = new ArrayList();
            if (intVector.length == 1 || (intVector.length > 0 && intVector[0] < 0)) {
                arrayList2.addAll(listValue3);
            } else {
                int length = intVector.length;
                for (int i5 = 0; i5 < length; i5++) {
                    int i6 = intVector[i5];
                    Preconditions.checkState(i6 >= 0, "Index for TensorArrayGather must be >= 0, got %s", i6);
                    if (listValue3.get(i6) != null) {
                        log.trace("Gathering op " + differentialFunction.getOwnName() + " at index " + i6 + " adding value " + listValue3.get(i6).toStringFull() + " from full list " + listValue3);
                        arrayList2.add(listValue3.get(i6));
                    }
                }
            }
            Stack stack = new Stack((INDArray[]) ((List) arrayList2.stream().filter(iNDArray3 -> {
                return iNDArray3 != null;
            }).collect(Collectors.toList())).toArray(new INDArray[0]), (INDArray) null, 0);
            INDArray allocate2 = this.mmgr.allocate(false, stack.calculateOutputShape().get(0));
            stack.setOutputArgument(0, allocate2);
            Nd4j.exec(stack);
            return ExecutionResult.createFrom(lookup5.getVariable(), allocate2);
        }
        if (differentialFunction instanceof TensorArrayScatter) {
            TensorArray tensorArray5 = TensorArray.getTensorArray(this.sameDiff, differentialFunction.arg(0));
            AbstractSession.VarId lookup6 = set == null ? null : lookup(tensorArray5.outputVariablesNames()[0], set, false);
            if (lookup6 == null && set2 != null) {
                lookup6 = lookup(tensorArray5.outputVariablesNames()[0], set2, false);
            }
            SDValue sdValue2 = getSdValue(lookup6);
            List<INDArray> listValue4 = sdValue2.getListValue();
            Preconditions.checkState(listValue4 != null, "Could not find TensorArray: %s", lookup6);
            String name3 = differentialFunction.arg(1).name();
            INDArray arr3 = this.sameDiff.getVariable(name3).getArr();
            Preconditions.checkState(arr3.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", arr3, name3);
            Preconditions.checkState(arr3.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", arr3.dataType(), name3);
            int[] intVector2 = arr3.toIntVector();
            String name4 = differentialFunction.arg(2).name();
            SDVariable variable2 = this.sameDiff.getVariable(name4);
            INDArray array4 = getArray(variable2, set, set2);
            while (listValue4.size() < intVector2.length) {
                listValue4.add(null);
            }
            if (intVector2.length == 1 && intVector2[0] == -1) {
                intVector2 = ArrayUtil.range(0, (int) array4.size(0));
            }
            for (int i7 = 0; i7 < intVector2.length; i7++) {
                if (array4.size(0) < intVector2[i7]) {
                    throw new IllegalArgumentException("Unable to obtain slice from values array named " + name4 + " with shape " + Arrays.toString(array4.shape()) + " at index " + intVector2[i7] + " at node named " + differentialFunction.getOwnName() + " with inputs " + Arrays.toString(differentialFunction.argNames()));
                }
            }
            for (int i8 = 0; i8 < intVector2.length; i8++) {
                if (intVector2[i8] >= array4.size(0)) {
                    throw new IllegalStateException("Unable to pull slice from value array " + variable2.name() + " of shape " + Arrays.toString(array4.shape()) + " index was" + intVector2[i8] + " all indices were " + Arrays.toString(intVector2));
                }
                INDArray dup = this.mmgr.dup(array4.slice(intVector2[i8]));
                if (tensorArray5.args().length > 1) {
                    long[] longVector2 = tensorArray5.arg(1).getArr().toLongVector();
                    if (!Arrays.equals(dup.shape(), longVector2)) {
                        throw new IllegalArgumentException("Unable to write array of shape " + Arrays.toString(dup.shape()) + " must be " + longVector2 + " for op " + differentialFunction.getOwnName() + " and tensor array " + tensorArray5.getOwnName());
                    }
                }
                SDValue create2 = SDValue.create(dup);
                int i9 = intVector2[i8];
                if (array4.rank() == 1 && dup.rank() > 0) {
                    dup = dup.reshape(new long[0]);
                }
                if (i9 >= listValue4.size()) {
                    while (listValue4.size() <= i9) {
                        listValue4.add(null);
                    }
                }
                log.trace("Scattering item at index " + i8 + " for list " + lookup6 + " with value " + dup + " from whole list of " + listValue4 + " from values array " + array4.toStringFull() + " named " + variable2.name());
                setArrayAtIndex(listValue4, i9, dup);
                this.arrayUseTracker.addDependency(create2, new ExecDoneDep());
            }
            return ExecutionResult.createValue(name4, sdValue2);
        }
        if (!(differentialFunction instanceof TensorArraySplit)) {
            if (!(differentialFunction instanceof TensorArrayRemove)) {
                throw new IllegalStateException("Execution support not yet implemented for: " + differentialFunction.getClass().getName());
            }
            SDVariable arg6 = differentialFunction.arg(0);
            SDVariable arg7 = differentialFunction.arg(1);
            List<INDArray> tensorArraysInSession2 = getTensorArraysInSession(arg6.name());
            if (tensorArraysInSession2 == null) {
                tensorArraysInSession2 = new ArrayList();
            } else if (tensorArraysInSession2 != null) {
                tensorArraysInSession2.remove(arg7.getArr(true).getInt(0));
            }
            AbstractSession.VarId lookup7 = set == null ? null : lookup(arg6.name(), set, false);
            if (lookup7 == null && set2 != null) {
                lookup7 = lookup(arg6.name(), set2, false);
            }
            while (this.sameDiff.getVariableOutputOp(arg6.name()) instanceof Enter) {
                arg6 = this.sameDiff.getVariableOutputOp(arg6.name()).arg();
                lookup7 = lookup7.getParentFrame().toVarId(arg6.name());
            }
            putNodeValue(SDValue.create(tensorArraysInSession2), lookup7);
            return ExecutionResult.createValue(lookup7.getVariable(), tensorArraysInSession2);
        }
        SDVariable arg8 = differentialFunction.arg(0);
        AbstractSession.VarId lookup8 = set == null ? null : lookup(arg8.name(), set, false);
        if (lookup8 == null && set2 != null) {
            lookup8 = lookup(arg8.name(), set2, false);
        }
        while (this.sameDiff.getVariableOutputOp(arg8.name()) instanceof Enter) {
            arg8 = this.sameDiff.getVariableOutputOp(arg8.name()).arg();
            lookup8 = lookup8.getParentFrame().toVarId(arg8.name());
        }
        SDValue sdValue3 = getSdValue(lookup8);
        List<INDArray> listValue5 = sdValue3.getListValue();
        Preconditions.checkState(listValue5 != null, "Could not find TensorArray: %s", lookup8);
        INDArray array5 = getArray(this.sameDiff.getVariable(differentialFunction.arg(1).name()), set, set2);
        String name5 = differentialFunction.arg(2).name();
        INDArray array6 = getArray(this.sameDiff.getVariable(name5), set, set2);
        Preconditions.checkState(array6.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", array6, name5);
        Preconditions.checkState(array6.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", array6.dataType(), name5);
        int[] intVector3 = array6.toIntVector();
        while (listValue5.size() <= intVector3.length) {
            listValue5.add(null);
        }
        INDArrayIndex[] iNDArrayIndexArr = (INDArrayIndex[]) ArrayUtil.nTimes(array5.rank(), NDArrayIndex.all(), INDArrayIndex.class);
        int i10 = 0;
        for (int i11 = 0; i11 < intVector3.length; i11++) {
            iNDArrayIndexArr[0] = NDArrayIndex.interval(i10, i10 + intVector3[i11]);
            INDArray dup2 = this.mmgr.dup(array5.get(iNDArrayIndexArr));
            SDValue create3 = SDValue.create(dup2);
            setArrayAtIndex(listValue5, i11, dup2);
            i10 += intVector3[i11];
            this.arrayUseTracker.addDependency(create3, new ExecDoneDep());
        }
        return ExecutionResult.createValue(name5, sdValue3);
    }

    /**
     * Collects every stored node output whose variable name equals {@code str}.
     *
     * <p>Each result is keyed by the pair (variable name, iteration) taken from the
     * {@code VarId} it is stored under in {@code nodeValueOutputs}. Only those two
     * components form the key, so two matching entries that share name and iteration
     * overwrite each other; the later one in iteration order of the backing map wins.
     *
     * @param str the variable name to match against {@code VarId.getVariable()}
     * @return a new mutable map of the matching values; empty when nothing matches
     */
    private Map<Pair<String, Integer>, SDValue> valuesFor(String str) {
        Map<Pair<String, Integer>, SDValue> matches = new HashMap<>();
        for (Map.Entry<AbstractSession.VarId, SDValue> entry : this.nodeValueOutputs.entrySet()) {
            AbstractSession.VarId varId = entry.getKey();
            if (varId.getVariable().equals(str)) {
                matches.put(Pair.of(varId.getVariable(), Integer.valueOf(varId.getIteration())), entry.getValue());
            }
        }
        return matches;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Returns the array held by the graph for a variable that is either a constant or of
     * type {@link VariableType#VARIABLE}.
     *
     * @param str name of the variable to fetch
     * @return the result of {@code sameDiff.getArrForVarName(str)}
     * @throws NullPointerException if no variable with that name is registered
     * @throws IllegalStateException if the variable is neither a constant nor a
     *         {@code VariableType.VARIABLE}
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public INDArray getConstantOrVariable(String str) {
        // Resolve once; the original looked the variable up twice and failed with a bare NPE when it was missing.
        SDVariable variable = this.sameDiff.getVariable(str);
        Preconditions.checkNotNull(variable, "No variable found with name \"%s\"", str);
        Preconditions.checkState(variable.isConstant() || variable.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant or a variable (VariableType.VARIABLE)", str);
        return this.sameDiff.getArrForVarName(str);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Fetches the op registered under {@code str} and prepares the {@link OpContext} it will be
     * executed with: input arrays are resolved, output arrays are allocated through
     * {@code mmgr}, and the context is cached in {@code opContexts} under the op name.
     *
     * <p>Ops of type {@code LoopCond}, {@code Enter}, {@code Exit}, {@code NextIteration},
     * {@code Merge}, {@code Switch}, {@code BaseTensorOp} and {@code Invoke} are returned
     * immediately with a {@code null} context; no arrays are resolved or allocated for them.
     *
     * <p>Each op argument is resolved as follows:
     * <ul>
     *   <li>constants and {@code VariableType.VARIABLE} variables: {@code SDVariable.getArr()};</li>
     *   <li>other non-placeholder variables: the session value found via
     *       {@code lookup(name, set, set2, true)}. A TENSOR value is used directly. A LIST value
     *       yields its first non-null element when the producing op is a {@code Switch} whose two
     *       argument names are equal, and an empty FLOAT array otherwise;</li>
     *   <li>placeholders: {@code map} first, then the tensor value from {@code map2}.</li>
     * </ul>
     *
     * <p>Output handling:
     * <ul>
     *   <li>{@code Identity}/{@code CreateView}: the first input array is set as the output;</li>
     *   <li>{@code Assign}: output 0 is input 0;</li>
     *   <li>other custom ops: one array per calculated output shape, converted to the output
     *       variable's data type when they differ;</li>
     *   <li>legacy {@link Op}s: a single output. For a two-argument {@link ReduceOp} that is not
     *       {@code REDUCE3}, argument 1 is read as the axis; for a two-argument {@link ScalarOp},
     *       argument 1 is set as the scalar.</li>
     * </ul>
     *
     * @param str       name of the op, used as key into {@code sameDiff.getOps()} and {@code opContexts}
     * @param frameIter not read by this implementation
     * @param set       first collection of variable ids searched when resolving inputs; may be null
     * @param set2      second collection of variable ids searched when resolving inputs; may be null
     * @param set3      names counted as additional inputs in the argument-count check; may be null
     * @param map       placeholder arrays keyed by variable name; may be null
     * @param set4      variable names; whether an output name is contained is forwarded as the
     *                  boolean argument of {@code mmgr.allocate}. Must not be null for ops that
     *                  allocate outputs
     * @param map2      fallback placeholder values keyed by variable name; may be null
     * @return the op paired with its context, or with {@code null} for the op types listed above
     * @throws IllegalArgumentException if a placeholder argument has no value in {@code map} or {@code map2}
     * @throws IllegalStateException    if an empty output shape produced a non-empty array, or a
     *                                  {@code Preconditions.checkState} consistency check fails
     * @throws NullPointerException     if the op, its function, or a resolved input array is missing
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public Pair<SameDiffOp, OpContext> getAndParameterizeOp(String str, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, Map<String, INDArray> map, Set<String> set4, Map<String, SDValue> map2) {
        SameDiffOp sameDiffOp = this.sameDiff.getOps().get(str);
        Preconditions.checkNotNull(sameDiffOp, "No op found with name \"%s\"", str);
        DifferentialFunction op = sameDiffOp.getOp();
        Preconditions.checkNotNull(op, "No differential function found with name \"%s\"", str);
        if ((op instanceof LoopCond) || (op instanceof Enter) || (op instanceof Exit) || (op instanceof NextIteration) || (op instanceof Merge) || (op instanceof Switch) || (op instanceof BaseTensorOp) || (op instanceof Invoke)) {
            // These op types get no context from this method.
            return new Pair<>(sameDiffOp, (OpContext) null);
        }

        String[] argNames = op.argNames();
        int argCount = argNames == null ? 0 : argNames.length;
        int opInputCount = set == null ? 0 : set.size();
        int extraLookupCount = set2 == null ? 0 : set2.size();
        int namedInputCount = set3 == null ? 0 : set3.size();
        // Only ops with at most one argument are validated. For multi-argument ops the previous code
        // merely filled a throw-away set with the names and checked nothing, so no check is made here.
        if (argCount <= 1 && argCount != opInputCount + namedInputCount + extraLookupCount) {
            Preconditions.checkState(argCount == opInputCount + namedInputCount, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", op.getClass().getSimpleName(), str, argNames, set, set3);
        }

        // Resolve one input array per argument name; stays null when the op has no arguments.
        INDArray[] inputArrays = null;
        if (argNames != null && argNames.length > 0) {
            inputArrays = new INDArray[argNames.length];
            for (int i = 0; i < argNames.length; i++) {
                String argName = argNames[i];
                SDVariable variable = this.sameDiff.getVariable(argName);
                if (variable.isConstant() || variable.getVariableType() == VariableType.VARIABLE) {
                    inputArrays[i] = variable.getArr();
                } else if (!variable.isPlaceHolder()) {
                    SDValue sdValue = getSdValue(lookup(argName, set, set2, true));
                    if (sdValue != null) {
                        switch (sdValue.getSdValueType()) {
                            case LIST:
                                DifferentialFunction producer = this.sameDiff.getVariableOutputOp(argName);
                                boolean switchOnSameInput = (producer instanceof Switch) && producer.argNames().length == 2 && producer.argNames()[0].equals(producer.argNames()[1]);
                                if (switchOnSameInput) {
                                    // Take the first non-null element; if all are null the checkNotNull below fails.
                                    List<INDArray> elements = sdValue.getListValue();
                                    for (int j = 0; j < elements.size(); j++) {
                                        if (elements.get(j) != null) {
                                            inputArrays[i] = elements.get(j);
                                            break;
                                        }
                                    }
                                } else {
                                    // Non-null stand-in so the checkNotNull below passes for list-valued inputs.
                                    inputArrays[i] = Nd4j.empty(DataType.FLOAT);
                                }
                                break;
                            case TENSOR:
                                inputArrays[i] = sdValue.getTensorValue();
                                break;
                            default:
                                // Other value types leave the slot null, as before.
                                break;
                        }
                    }
                } else if (map != null && map.containsKey(argName)) {
                    inputArrays[i] = map.get(argName);
                } else if (map2 != null && map2.containsKey(argName)) {
                    inputArrays[i] = map2.get(argName).getTensorValue();
                } else {
                    // Previously the variable name itself was passed to String.format as the format string.
                    throw new IllegalArgumentException("No array was provided for required placeholder variable \"" + argName + "\"");
                }
                Preconditions.checkNotNull(inputArrays[i], "Could not parameterize op %s: array %s (variable %s) is null", str, Integer.valueOf(i), variable.name());
            }
        }

        // Reuse the context cached for this op name, creating it on first use.
        OpContext opContext = this.opContexts.get(str);
        if (opContext == null) {
            opContext = Nd4j.getExecutioner().buildContext();
            this.opContexts.put(str, opContext);
        }

        if (op instanceof CustomOp) {
            DynamicCustomOp customOp = (DynamicCustomOp) op;
            if ((op instanceof Identity) || (op instanceof CreateView)) {
                // The first input array doubles as the output, so it must exist.
                Preconditions.checkNotNull(inputArrays, "Could not parameterize op %s: no input arrays available", str);
                opContext.setInputArrays(inputArrays);
                opContext.setOutputArrays(inputArrays[0]);
                return new Pair<>(sameDiffOp, opContext);
            }
            opContext.setArgs(inputArrays, customOp.iArgs(), customOp.dArgs(), customOp.tArgs(), customOp.bArgs());
            if (op instanceof Assign) {
                opContext.setOutputArray(0, opContext.getInputArray(0));
            } else {
                List<LongShapeDescriptor> outputShapes = customOp.calculateOutputShape(opContext);
                Preconditions.checkState(outputShapes != null && outputShapes.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
                String[] outputNames = op.outputVariablesNames();
                Preconditions.checkState(outputNames.length == outputShapes.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", op.opName(), Integer.valueOf(outputShapes.size()), Integer.valueOf(outputNames.length));
                for (int i = 0; i < outputShapes.size(); i++) {
                    LongShapeDescriptor shape = outputShapes.get(i);
                    // The declared data type of the output variable wins over the calculated one.
                    DataType declaredType = this.sameDiff.getVariable(outputNames[i]).dataType();
                    if (declaredType != shape.dataType()) {
                        shape = shape.asDataType(declaredType);
                    }
                    INDArray output = this.mmgr.allocate(set4.contains(outputNames[i]), shape);
                    if (shape.isEmpty() && !output.isEmpty()) {
                        throw new IllegalStateException("Output shape was empty, but created array was not.");
                    }
                    opContext.setOutputArray(i, output);
                }
            }
        } else if (op instanceof Op) {
            Op legacyOp = (Op) op;
            boolean secondArgIsAxis = false;
            boolean emptyAxis = false;
            if ((legacyOp instanceof ReduceOp) && ((ReduceOp) legacyOp).getOpType() != Op.Type.REDUCE3 && op.argNames().length == 2) {
                // Two-argument reduction: argument 1 carries the axis, not data.
                SDVariable axisVar = op.arg(1);
                Preconditions.checkState(axisVar.dataType().isIntType(), "Legacy op %s input 1 (axis) was expected to be an integer type, is %s", op.getClass(), axisVar.dataType());
                INDArray axis = getArray(axisVar, set, set2);
                Preconditions.checkState(axis != null, "Could not get axis argument for op %s: %s", op.getOwnName(), op.getClass());
                if (axis.isEmpty()) {
                    op.setDimensions(null);
                    emptyAxis = true;
                    ((BaseReduceOp) legacyOp).setEmptyReduce(true);
                } else {
                    op.setDimensions(Shape.normalizeAxis(inputArrays[0].rank(), axis.toIntVector()));
                    ((BaseReduceOp) legacyOp).setEmptyReduce(false);
                }
                secondArgIsAxis = true;
            } else if ((legacyOp instanceof ScalarOp) && op.argNames().length == 2) {
                INDArray scalar = getArray(op.arg(1), set, set2);
                Preconditions.checkState(scalar != null, "Could not get scalar argument for op %s: %s", op.getOwnName(), op.getClass());
                Preconditions.checkState(scalar.isScalar(), "Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", op.getOwnName(), op.getClass(), scalar);
                ((ScalarOp) legacyOp).setScalar(scalar);
            }
            if (inputArrays != null && inputArrays.length > 0) {
                opContext.setInputArray(0, inputArrays[0]);
                // The axis of a reduction is not passed on as a second input array.
                if (inputArrays.length == 2 && !secondArgIsAxis) {
                    opContext.setInputArray(1, inputArrays[1]);
                }
            }
            boolean outputListed = set4.contains(((BaseOp) legacyOp).outputVariablesNames()[0]);
            if (emptyAxis) {
                // With an empty axis the output takes the data type and shape of input 0.
                opContext.setOutputArray(0, this.mmgr.allocate(false, opContext.getInputArray(0).dataType(), opContext.getInputArray(0).shape()));
            } else {
                List<LongShapeDescriptor> outputShapes = ((BaseOp) legacyOp).calculateOutputShape(opContext);
                Preconditions.checkState(outputShapes != null && outputShapes.size() == 1, "Could not calculate output shape for op: %s", legacyOp.getClass());
                opContext.setOutputArray(0, this.mmgr.allocate(outputListed, outputShapes.get(0)));
            }
        }
        return new Pair<>(sameDiffOp, opContext);
    }

    /**
     * Resolves the array backing {@code sDVariable}.
     *
     * <p>Variables of type CONSTANT or VARIABLE are fetched through
     * {@link #getConstantOrVariable(String)}. Any other variable is located with
     * {@code lookup} across the two supplied collections and read from the session outputs.
     *
     * @param sDVariable  the variable whose array is wanted
     * @param collection  first collection of variable ids handed to {@code lookup}
     * @param collection2 second collection of variable ids handed to {@code lookup}
     * @return the array for the variable
     */
    protected INDArray getArray(SDVariable sDVariable, Collection<AbstractSession.VarId> collection, Collection<AbstractSession.VarId> collection2) {
        final String varName = sDVariable.name();
        final boolean heldByGraph = sDVariable.getVariableType() == VariableType.CONSTANT || sDVariable.getVariableType() == VariableType.VARIABLE;
        if (!heldByGraph) {
            // Session-computed value: it must be present in one of the supplied collections.
            AbstractSession.VarId varId = lookup(varName, collection, collection2, false);
            Preconditions.checkState(varId != null, "Could not find array for variable %s", sDVariable.name());
            return getTensorFromOutputs(varId);
        }
        return getConstantOrVariable(varName);
    }

    /**
     * @return the memory manager this session currently allocates arrays through
     */
    public SessionMemMgr getMmgr() {
        return mmgr;
    }

    /**
     * Replaces the memory manager held by this session.
     *
     * @param sessionMemMgr the memory manager to store; no validation is performed
     */
    public void setMmgr(final SessionMemMgr sessionMemMgr) {
        this.mmgr = sessionMemMgr;
    }

    /**
     * @return the dependency tracker this session registers {@link SDValue} dependencies with
     */
    public AbstractDependencyTracker<SDValue, Dep> getArrayUseTracker() {
        return arrayUseTracker;
    }

    /**
     * Replaces the dependency tracker held by this session.
     *
     * @param abstractDependencyTracker the tracker to store; no validation is performed
     */
    public void setArrayUseTracker(final AbstractDependencyTracker<SDValue, Dep> abstractDependencyTracker) {
        this.arrayUseTracker = abstractDependencyTracker;
    }

    /*
     * Raw-typed overload carrying the decompiler's bridge/synthetic markers. It only casts its
     * raw Set/List/Map parameters to their parameterized forms and forwards everything,
     * unchanged and in the same order, to getOutputs2, returning that result.
     * NOTE(review): getOutputs2 looks like the decompiler's renamed copy of the typed
     * getOutputs implementation - confirm before editing or removing this method.
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ ExecutionResult getOutputs(Pair<SameDiffOp, OpContext> pair, FrameIter frameIter, Set set, Set set2, Set set3, List list, At at, MultiDataSet multiDataSet, Set set4, Map map) {
        return getOutputs2(pair, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3, (List<Listener>) list, at, multiDataSet, (Set<String>) set4, (Map<String, SDValue>) map);
    }

    /*
     * Raw-typed overload carrying the decompiler's bridge/synthetic markers. It casts its raw
     * Set/Map parameters to their parameterized forms and forwards all arguments, in the same
     * order, to a getAndParameterizeOp call, returning that result.
     * NOTE(review): the casts are what select the parameterized overload; this signature has
     * the same erasure as that overload, so the pair presumably cannot be compiled together
     * as written - confirm how this decompiled source is meant to be built.
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ Pair<SameDiffOp, OpContext> getAndParameterizeOp(String str, FrameIter frameIter, Set set, Set set2, Set set3, Map<String, INDArray> map, Set set4, Map map2) {
        return getAndParameterizeOp(str, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3, map, (Set<String>) set4, (Map<String, SDValue>) map2);
    }
}
