package org.nd4j.linalg.api.ops.custom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

/**
 * Custom op that executes a named sub-graph ("function") of the owning {@link SameDiff} instance.
 *
 * <p>The sub-graph is looked up with {@code getFunction(functionName)} on the parent graph. This
 * op's inputs are fed positionally into the sub-graph inputs named by
 * {@link #getSubGraphInputVarNames()}. The sub-graph variables named by
 * {@link #getSubGraphOutputVarNames()} become this op's outputs, in the same order.
 */
public class Invoke extends DynamicCustomOp {
    /** Name of the sub-graph function to execute, resolved via the parent graph's {@code getFunction}. */
    private String functionName;
    /**
     * Input names. When {@link #subGraphInputVarNames} is not set they are used as the sub-graph
     * input names. If null at execution time, the op's argument names are used instead.
     */
    private String[] inputVarNames;
    /** Optional names for this op's output variables in the parent graph. */
    private String[] outputVarNames;
    /** Sub-graph variables that receive this op's inputs, positionally. */
    private String[] subGraphInputVarNames;
    /** Sub-graph variables whose values become this op's outputs, positionally. */
    private String[] subGraphOutputVarNames;

    /**
     * Configuration for an {@link Invoke} op: the function to call, the outer input variables,
     * and the name mappings between the parent graph and the sub-graph. Create instances with
     * {@link #builder()}.
     */
    public static class InvokeParams {
        private String functionName;
        private SDVariable[] inputs;
        private String[] inputVarNames;
        private String[] outputVarNames;
        private String[] subGraphInputVarNames;
        private String[] subGraphOutputVarNames;

        /** Fluent builder for {@link InvokeParams}; every field is optional and defaults to null. */
        public static class InvokeParamsBuilder {
            private String functionName;
            private SDVariable[] inputs;
            private String[] inputVarNames;
            private String[] outputVarNames;
            private String[] subGraphInputVarNames;
            private String[] subGraphOutputVarNames;

            InvokeParamsBuilder() {
            }

            /** Sets the name of the sub-graph function to invoke. */
            public InvokeParamsBuilder functionName(String str) {
                this.functionName = str;
                return this;
            }

            /** Sets the parent-graph variables passed as inputs to the op. */
            public InvokeParamsBuilder inputs(SDVariable[] sDVariableArr) {
                this.inputs = sDVariableArr;
                return this;
            }

            /** Sets the input variable names. */
            public InvokeParamsBuilder inputVarNames(String[] strArr) {
                this.inputVarNames = strArr;
                return this;
            }

            /** Sets the names the op's output variables should have in the parent graph. */
            public InvokeParamsBuilder outputVarNames(String[] strArr) {
                this.outputVarNames = strArr;
                return this;
            }

            /** Sets the sub-graph variable names that receive the inputs, positionally. */
            public InvokeParamsBuilder subGraphInputVarNames(String[] strArr) {
                this.subGraphInputVarNames = strArr;
                return this;
            }

            /** Sets the sub-graph variable names whose values become the op's outputs, positionally. */
            public InvokeParamsBuilder subGraphOutputVarNames(String[] strArr) {
                this.subGraphOutputVarNames = strArr;
                return this;
            }

            /** Creates an {@link InvokeParams} holding the builder's current values (arrays are not copied). */
            public InvokeParams build() {
                return new InvokeParams(this.functionName, this.inputs, this.inputVarNames, this.outputVarNames, this.subGraphInputVarNames, this.subGraphOutputVarNames);
            }

            @Override
            public String toString() {
                return "Invoke.InvokeParams.InvokeParamsBuilder(functionName=" + this.functionName + ", inputs=" + Arrays.deepToString(this.inputs) + ", inputVarNames=" + Arrays.deepToString(this.inputVarNames) + ", outputVarNames=" + Arrays.deepToString(this.outputVarNames) + ", subGraphInputVarNames=" + Arrays.deepToString(this.subGraphInputVarNames) + ", subGraphOutputVarNames=" + Arrays.deepToString(this.subGraphOutputVarNames) + ")";
            }
        }

        InvokeParams(String str, SDVariable[] sDVariableArr, String[] strArr, String[] strArr2, String[] strArr3, String[] strArr4) {
            this.functionName = str;
            this.inputs = sDVariableArr;
            this.inputVarNames = strArr;
            this.outputVarNames = strArr2;
            this.subGraphInputVarNames = strArr3;
            this.subGraphOutputVarNames = strArr4;
        }

        /** Returns a new, empty builder. */
        public static InvokeParamsBuilder builder() {
            return new InvokeParamsBuilder();
        }

        public String getFunctionName() {
            return this.functionName;
        }

        public SDVariable[] getInputs() {
            return this.inputs;
        }

        public String[] getInputVarNames() {
            return this.inputVarNames;
        }

        public String[] getOutputVarNames() {
            return this.outputVarNames;
        }

        public String[] getSubGraphInputVarNames() {
            return this.subGraphInputVarNames;
        }

        public String[] getSubGraphOutputVarNames() {
            return this.subGraphOutputVarNames;
        }

        public void setFunctionName(String str) {
            this.functionName = str;
        }

        public void setInputs(SDVariable[] sDVariableArr) {
            this.inputs = sDVariableArr;
        }

        public void setInputVarNames(String[] strArr) {
            this.inputVarNames = strArr;
        }

        public void setOutputVarNames(String[] strArr) {
            this.outputVarNames = strArr;
        }

        public void setSubGraphInputVarNames(String[] strArr) {
            this.subGraphInputVarNames = strArr;
        }

        public void setSubGraphOutputVarNames(String[] strArr) {
            this.subGraphOutputVarNames = strArr;
        }

        /** Field-wise equality; array fields are compared element-wise with {@link Arrays#deepEquals}. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof InvokeParams)) {
                return false;
            }
            InvokeParams invokeParams = (InvokeParams) obj;
            // Symmetric check so a subclass that narrows equality is not equal to a plain instance.
            if (!invokeParams.canEqual(this)) {
                return false;
            }
            String functionName = getFunctionName();
            String functionName2 = invokeParams.getFunctionName();
            if (functionName == null) {
                if (functionName2 != null) {
                    return false;
                }
            } else if (!functionName.equals(functionName2)) {
                return false;
            }
            return Arrays.deepEquals(getInputs(), invokeParams.getInputs()) && Arrays.deepEquals(getInputVarNames(), invokeParams.getInputVarNames()) && Arrays.deepEquals(getOutputVarNames(), invokeParams.getOutputVarNames()) && Arrays.deepEquals(getSubGraphInputVarNames(), invokeParams.getSubGraphInputVarNames()) && Arrays.deepEquals(getSubGraphOutputVarNames(), invokeParams.getSubGraphOutputVarNames());
        }

        /** Hook for {@link #equals(Object)}; subclasses that add state should override it. */
        protected boolean canEqual(Object obj) {
            return obj instanceof InvokeParams;
        }

        /** Hash consistent with {@link #equals(Object)}; array fields use {@link Arrays#deepHashCode}. */
        @Override
        public int hashCode() {
            String functionName = getFunctionName();
            return (((((((((((1 * 59) + (functionName == null ? 43 : functionName.hashCode())) * 59) + Arrays.deepHashCode(getInputs())) * 59) + Arrays.deepHashCode(getInputVarNames())) * 59) + Arrays.deepHashCode(getOutputVarNames())) * 59) + Arrays.deepHashCode(getSubGraphInputVarNames())) * 59) + Arrays.deepHashCode(getSubGraphOutputVarNames());
        }

        @Override
        public String toString() {
            return "Invoke.InvokeParams(functionName=" + getFunctionName() + ", inputs=" + Arrays.deepToString(getInputs()) + ", inputVarNames=" + Arrays.deepToString(getInputVarNames()) + ", outputVarNames=" + Arrays.deepToString(getOutputVarNames()) + ", subGraphInputVarNames=" + Arrays.deepToString(getSubGraphInputVarNames()) + ", subGraphOutputVarNames=" + Arrays.deepToString(getSubGraphOutputVarNames()) + ")";
        }
    }

    /** No-arg constructor; leaves all configuration unset. */
    public Invoke() {
    }

    /**
     * Creates an invoke op in {@code sameDiff} using the inputs and name mappings from
     * {@code invokeParams}. The name arrays are stored by reference, not copied.
     *
     * @param sameDiff     the parent graph that owns this op and the invoked function
     * @param invokeParams function name, inputs and name mappings
     */
    public Invoke(SameDiff sameDiff, InvokeParams invokeParams) {
        super(sameDiff, invokeParams.getInputs());
        this.sameDiff = sameDiff;
        this.outputVarNames = invokeParams.getOutputVarNames();
        this.functionName = invokeParams.getFunctionName();
        this.inputVarNames = invokeParams.getInputVarNames();
        this.subGraphInputVarNames = invokeParams.getSubGraphInputVarNames();
        this.subGraphOutputVarNames = invokeParams.getSubGraphOutputVarNames();
    }

    /**
     * Executes the sub-graph referenced by an {@link Invoke} op.
     *
     * <p>The i-th argument of the op is fed into the i-th sub-graph input name. Those names are
     * {@code subGraphInputVarNames} if set, otherwise {@code inputVarNames}, otherwise the op's
     * argument names. When {@code valuePlaceHolders} is non-empty, only value placeholders are
     * used. The sub-graph then runs through {@code outputValues}, and the results are keyed by this
     * op's output variable names. Otherwise the array placeholders are used with {@code output}.
     *
     * @param differentialFunction the op to execute; must be an {@link Invoke}
     * @param placeHolders         array inputs, looked up by the op's argument names
     * @param valuePlaceHolders    value inputs, looked up by the op's argument names
     * @return the sub-graph outputs wrapped in an {@link ExecutionResult}
     * @throws IllegalArgumentException if the op's function cannot be found in its graph
     */
    public static ExecutionResult doInvoke(DifferentialFunction differentialFunction, Map<String, INDArray> placeHolders, Map<String, SDValue> valuePlaceHolders) {
        Invoke invoke = (Invoke) differentialFunction;
        SameDiff subGraph = differentialFunction.getSameDiff().getFunction(invoke.getFunctionName());
        if (subGraph == null) {
            throw new IllegalArgumentException("Unable to invoke function. No function of " + invoke.getFunctionName() + " found!");
        }
        subGraph.setEnableCache(false);

        SDVariable[] args = differentialFunction.args();
        SDVariable[] outputVariables = differentialFunction.outputVariables();
        // Computed once rather than on every loop iteration below.
        String[] argNames = differentialFunction.argNames();

        String[] inputVarNames = invoke.getInputVarNames();
        if (inputVarNames == null) {
            inputVarNames = new String[args.length];
            for (int i = 0; i < args.length; i++) {
                inputVarNames[i] = args[i].name();
            }
        }
        // Resolve the fallback only after defaulting inputVarNames. Previously it was read first,
        // which left it null (and caused an NPE) when neither name array was configured.
        String[] subGraphInputNames = invoke.getSubGraphInputVarNames() != null ? invoke.getSubGraphInputVarNames() : inputVarNames;

        String[] outputVarNames = invoke.getOutputVarNames();
        if (outputVarNames == null) {
            outputVarNames = new String[outputVariables.length];
            for (int i = 0; i < outputVariables.length; i++) {
                outputVarNames[i] = outputVariables[i].name();
            }
        }
        String[] subGraphOutputNames = invoke.getSubGraphOutputVarNames() != null ? invoke.getSubGraphOutputVarNames() : outputVarNames;
        List<String> requestedOutputs = Arrays.asList(subGraphOutputNames);

        if (!valuePlaceHolders.isEmpty()) {
            Map<String, SDValue> valueInputs = new LinkedHashMap<>();
            for (int i = 0; i < inputVarNames.length; i++) {
                valueInputs.put(subGraphInputNames[i], valuePlaceHolders.get(argNames[i]));
            }
            Map<String, SDValue> subGraphResults = subGraph.outputValues(valueInputs, requestedOutputs);
            // Re-key the sub-graph results by this op's output variable names.
            Map<String, SDValue> valueOutputs = new LinkedHashMap<>();
            for (int i = 0; i < outputVarNames.length; i++) {
                valueOutputs.put(outputVariables[i].name(), subGraphResults.get(subGraphOutputNames[i]));
            }
            return ExecutionResult.builder().valueOutputs(valueOutputs).build();
        }

        Map<String, INDArray> arrayInputs = new LinkedHashMap<>();
        for (int i = 0; i < inputVarNames.length; i++) {
            arrayInputs.put(subGraphInputNames[i], placeHolders.get(argNames[i]));
        }
        Map<String, INDArray> output = subGraph.output(arrayInputs, requestedOutputs);
        // NOTE(review): unlike the value branch above, these results stay keyed by the sub-graph
        // output names rather than this op's output variable names. Confirm which the executor expects.
        return ExecutionResult.builder().outputs(ExecutionResult.pack(output)).build();
    }

    /**
     * Returns this op's output variables, creating them in the parent graph on first call.
     *
     * <p>For each entry in {@link #getSubGraphOutputVarNames()}, one variable is created and marked
     * {@link VariableType#ARRAY}. It is named {@code <subGraphName>_<functionName>} and takes the
     * sub-graph variable's data type. Non-constant sub-graph variables with a non-null shape pass
     * that shape on. Constants and shapeless variables get an empty shape. When
     * {@link #getOutputVarNames()} is set, the created variables are then renamed to those names.
     *
     * @return the output variables (cached after the first successful call)
     * @throws IllegalArgumentException if the function does not exist, or if output variable names
     *                                  are given but their count differs from the sub-graph outputs
     * @throws IllegalStateException    if sub-graph output names are missing, or if a named sub-graph
     *                                  variable is absent or has an unsupported variable type
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables() {
        if (this.outputVariables != null) {
            return this.outputVariables;
        }
        SameDiff function = this.sameDiff.getFunction(this.functionName);
        if (function == null) {
            throw new IllegalArgumentException("Unable to determine output data types for variables. No function of " + this.functionName + " found!");
        }
        if (this.subGraphOutputVarNames == null) {
            throw new IllegalStateException("Invalid InvokeConfiguration found. Please specify sub graph output names.");
        }
        // Validate before creating any variables, so a bad configuration leaves the parent graph
        // untouched. Previously this check followed the outputVariables assignment and never fired.
        if (this.outputVarNames != null && this.outputVarNames.length != this.subGraphOutputVarNames.length) {
            throw new IllegalArgumentException("Invalid configuration for output variable names. Must be equal to the number of outputs.");
        }

        SDVariable[] created = new SDVariable[this.subGraphOutputVarNames.length];
        for (int i = 0; i < created.length; i++) {
            created[i] = createOutputVariable(function, this.subGraphOutputVarNames[i]);
        }
        this.outputVariables = created;

        if (this.outputVarNames != null) {
            for (int i = 0; i < created.length; i++) {
                if (!created[i].name().equals(this.outputVarNames[i])) {
                    this.sameDiff.updateVariableNameAndReference(created[i], this.outputVarNames[i], true);
                }
            }
        }
        addOutputsToOp();
        return created;
    }

    /** Creates the parent-graph ARRAY variable that mirrors the sub-graph variable {@code subGraphName}. */
    private SDVariable createOutputVariable(SameDiff function, String subGraphName) {
        SDVariable subGraphVar = function.getVariable(subGraphName);
        if (subGraphVar == null) {
            throw new IllegalStateException("No variable found in sub graph named " + subGraphName);
        }
        String outerName = subGraphName + "_" + this.functionName;
        SDVariable outputVar;
        switch (subGraphVar.getVariableType()) {
            case VARIABLE:
            case ARRAY:
            case PLACEHOLDER:
            case SEQUENCE:
                if (subGraphVar.getShape() != null) {
                    outputVar = this.sameDiff.var(outerName, subGraphVar.dataType(), subGraphVar.getShape());
                } else {
                    outputVar = this.sameDiff.var(outerName, subGraphVar.dataType(), new int[0]);
                }
                break;
            case CONSTANT:
                outputVar = this.sameDiff.var(outerName, subGraphVar.dataType(), new int[0]);
                break;
            default:
                // Fail loudly instead of leaving a null slot in the output array.
                throw new IllegalStateException("Unsupported variable type " + subGraphVar.getVariableType() + " for sub graph variable " + subGraphName);
        }
        outputVar.setVariableType(VariableType.ARRAY);
        return outputVar;
    }

    /**
     * Returns the number of outputs: the sub-graph output count if configured, else the output
     * variable name count, else 1.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int getNumOutputs() {
        if (this.subGraphOutputVarNames != null) {
            return this.subGraphOutputVarNames.length;
        }
        if (this.outputVarNames != null) {
            return this.outputVarNames.length;
        }
        return 1;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "invoke";
    }

    /** Delegates to the superclass; no invoke-specific argument handling. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public void configureFromArguments() {
        super.configureFromArguments();
    }

    /** Delegates to the superclass; no invoke-specific configuration. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void configureWithSameDiff(SameDiff sameDiff) {
        super.configureWithSameDiff(sameDiff);
    }

    /**
     * Reports {@link DataType#FLOAT} for every output, ignoring the input types.
     *
     * @param inputDataTypes input data types (unused)
     * @return a mutable list of {@link #getNumOutputs()} FLOAT entries
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        // NOTE(review): this does not reflect the sub-graph's real output types, and
        // calculateOutputShape(OpContext) reports DOUBLE. Confirm which placeholder type is intended.
        int numOutputs = getNumOutputs();
        List<DataType> dataTypes = new ArrayList<>(numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            dataTypes.add(DataType.FLOAT);
        }
        return dataTypes;
    }

    /** Returns an empty list; no output shapes are computed statically for this op. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape() {
        return Collections.emptyList();
    }

    /**
     * Returns one shape-{@code [1]} DOUBLE descriptor per output. It appears to be a placeholder
     * rather than the sub-graph's real output shapes (TODO confirm).
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
        int numOutputs = getNumOutputs();
        List<LongShapeDescriptor> shapes = new ArrayList<>(numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            shapes.add(LongShapeDescriptor.fromShape(new int[]{1}, DataType.DOUBLE));
        }
        return shapes;
    }

    /** Returns the name of the invoked sub-graph function. */
    public String getFunctionName() {
        return this.functionName;
    }

    /** Returns the configured input names (may be null). */
    public String[] getInputVarNames() {
        return this.inputVarNames;
    }

    /** Returns the configured parent-graph output variable names (may be null). */
    public String[] getOutputVarNames() {
        return this.outputVarNames;
    }

    /** Returns the sub-graph input variable names (may be null). */
    public String[] getSubGraphInputVarNames() {
        return this.subGraphInputVarNames;
    }

    /** Returns the sub-graph output variable names (may be null). */
    public String[] getSubGraphOutputVarNames() {
        return this.subGraphOutputVarNames;
    }
}
