package org.nd4j.autodiff.samediff.config;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;

/* loaded from: input_file:org/nd4j/autodiff/samediff/config/ExecutionResult.class */
/**
 * Named results of an execution.
 *
 * <p>An instance holds array outputs ({@code outputs}), {@link SDValue} outputs
 * ({@code valueOutputs}), or both when built via {@link #builder()}. Either map may be null.
 * In {@code outputs}, a missing array is represented by {@link Optional#empty()}.
 *
 * <p>Index-based accessors ({@link #resultAt(int)}, {@link #valueWithKeyAtIndex(int, boolean)},
 * {@link #valueExistsAtIndex(int)}) rely on the maps' iteration order. The static factories
 * here use {@link LinkedHashMap}, which iterates in insertion order.
 */
public class ExecutionResult {
    /** Named array outputs; absent arrays are stored as {@link Optional#empty()}. May be null. */
    private Map<String, Optional<INDArray>> outputs;
    /** Named {@link SDValue} outputs. May be null. */
    private Map<String, SDValue> valueOutputs;

    /** Fluent builder for {@link ExecutionResult}; obtain one via {@link ExecutionResult#builder()}. */
    public static class ExecutionResultBuilder {
        private Map<String, Optional<INDArray>> outputs;
        private Map<String, SDValue> valueOutputs;

        ExecutionResultBuilder() {
        }

        /** Sets the array outputs; the map is stored by reference, not copied. */
        public ExecutionResultBuilder outputs(Map<String, Optional<INDArray>> map) {
            this.outputs = map;
            return this;
        }

        /** Sets the {@link SDValue} outputs; the map is stored by reference, not copied. */
        public ExecutionResultBuilder valueOutputs(Map<String, SDValue> map) {
            this.valueOutputs = map;
            return this;
        }

        /** Creates the result from the maps configured so far (either may be null). */
        public ExecutionResult build() {
            return new ExecutionResult(this.outputs, this.valueOutputs);
        }

        @Override
        public String toString() {
            return "ExecutionResult.ExecutionResultBuilder(outputs=" + this.outputs + ", valueOutputs=" + this.valueOutputs + ")";
        }
    }

    /**
     * Creates an array-output result by pairing names with arrays positionally.
     *
     * @param names  output names, in order
     * @param arrays output arrays, same length as {@code names}; null entries become
     *               {@link Optional#empty()}
     * @return a result whose {@code outputs} map preserves the given order
     */
    public static ExecutionResult createFrom(List<String> names, List<INDArray> arrays) {
        Preconditions.checkState(names.size() == arrays.size(), "Inputs and names must be equal size!");
        Map<String, Optional<INDArray>> packed = new LinkedHashMap<>();
        for (int i = 0; i < arrays.size(); i++) {
            packed.put(names.get(i), Optional.ofNullable(arrays.get(i)));
        }
        return builder().outputs(packed).build();
    }

    /** Creates a result holding a single named {@link SDValue}. */
    public static ExecutionResult createValue(String str, SDValue sDValue) {
        return builder().valueOutputs(Collections.singletonMap(str, sDValue)).build();
    }

    /**
     * Creates a result holding a single named list value built by {@link SDValue#create(List)}.
     * The raw {@code list} is cast unchecked to {@code List<INDArray>}, so its elements are
     * assumed to be {@link INDArray}s.
     */
    @SuppressWarnings("unchecked")
    public static ExecutionResult createValue(String str, List list) {
        return builder().valueOutputs(Collections.singletonMap(str, SDValue.create((List<INDArray>) list))).build();
    }

    /** Creates an array-output result with a single named array (null becomes empty). */
    public static ExecutionResult createFrom(String str, INDArray iNDArray) {
        return createFrom(Arrays.asList(str), Arrays.asList(iNDArray));
    }

    /**
     * Creates an array-output result from a function's output variable names and the output
     * arrays held by the op context, paired positionally.
     */
    public static ExecutionResult createFrom(DifferentialFunction differentialFunction, OpContext opContext) {
        return createFrom((List<String>) Arrays.asList(differentialFunction.outputVariablesNames()),
                (INDArray[]) opContext.getOutputArrays().toArray(new INDArray[opContext.getOutputArrays().size()]));
    }

    /**
     * Creates an array-output result by pairing names with arrays positionally.
     *
     * @param names  output names, in order
     * @param arrays output arrays, same length as {@code names}; null entries become
     *               {@link Optional#empty()}
     * @return a result whose {@code outputs} map preserves the given order
     */
    public static ExecutionResult createFrom(List<String> names, INDArray[] arrays) {
        Preconditions.checkState(names.size() == arrays.length, "Inputs and names must be equal size!");
        Map<String, Optional<INDArray>> packed = new LinkedHashMap<>();
        for (int i = 0; i < arrays.length; i++) {
            packed.put(names.get(i), Optional.ofNullable(arrays[i]));
        }
        return builder().outputs(packed).build();
    }

    /**
     * Converts the held results to an array.
     *
     * <p>If {@code valueOutputs} is set, the tensor value of every non-null {@link SDValue} is
     * emitted in map order. {@code names} is ignored in this case. The returned array is sized to
     * the map, so skipped null entries leave trailing null slots.
     *
     * <p>Otherwise, each name in {@code names} is looked up in {@code outputs}. A missing name or
     * an empty output yields a null element.
     *
     * @param names output names to extract (only used for array outputs)
     * @return the extracted arrays
     * @throws IllegalStateException if neither map is set
     */
    public INDArray[] outputsToArray(List<String> names) {
        if (this.valueOutputs != null) {
            INDArray[] fromValues = new INDArray[this.valueOutputs.size()];
            int next = 0;
            for (SDValue value : this.valueOutputs.values()) {
                if (value != null) {
                    fromValues[next++] = value.getTensorValue();
                }
            }
            return fromValues;
        }
        if (this.outputs == null) {
            throw new IllegalStateException("No outputs to be converted.");
        }
        INDArray[] ret = new INDArray[names.size()];
        for (int i = 0; i < names.size(); i++) {
            // Explicit check instead of catching NPE; also handles Optional.empty(), which
            // previously threw NoSuchElementException.
            ret[i] = unwrap(this.outputs.get(names.get(i)));
        }
        return ret;
    }

    /** Returns true if {@code valueOutputs} is set (possibly empty). */
    public boolean hasValues() {
        return this.valueOutputs != null;
    }

    /** Returns true if the array-output map {@code outputs} is set (possibly empty). */
    public boolean hasSingle() {
        return this.outputs != null;
    }

    /**
     * Returns the number of results: the size of a non-empty {@code outputs} map, otherwise the
     * size of {@code valueOutputs}, or 0 if neither holds anything.
     */
    public int numResults() {
        if (this.outputs != null && !this.outputs.isEmpty()) {
            return this.outputs.size();
        }
        if (this.valueOutputs == null || this.valueOutputs.isEmpty()) {
            return 0;
        }
        return this.valueOutputs.size();
    }

    /**
     * Returns whether a non-null result exists at position {@code i}.
     *
     * <p>Array outputs are checked first, when set. For value outputs, TENSOR values require a
     * non-null tensor and LIST values a non-null list. Any other value type yields false.
     */
    public boolean valueExistsAtIndex(int i) {
        if (this.outputs != null) {
            return resultAt(i) != null;
        }
        if (this.valueOutputs == null) {
            return false;
        }
        SDValue value = valueWithKey(valueAtIndex(i));
        if (value == null) {
            return false;
        }
        switch (value.getSdValueType()) {
            case TENSOR:
                return value.getTensorValue() != null;
            case LIST:
                return value.getListValue() != null;
            default:
                return false;
        }
    }

    /** Returns true if neither {@code outputs} nor {@code valueOutputs} is set. */
    public boolean isNull() {
        return this.valueOutputs == null && this.outputs == null;
    }

    /**
     * Returns the array at position {@code i}, taken from value outputs when present (see
     * {@link #valueWithKeyAtIndex(int, boolean)}), otherwise from array outputs.
     *
     * @param i position in map iteration order
     * @param z passed through to {@link #valueWithKeyAtIndex(int, boolean)}
     * @return the array, or null if absent
     */
    public INDArray resultOrValueAt(int i, boolean z) {
        if (!hasValues()) {
            return resultAt(i);
        }
        SDValue value = valueWithKeyAtIndex(i, z);
        return value == null ? null : value.getTensorValue();
    }

    /**
     * Returns the key at position {@code i} of {@code valueOutputs} if set, else of
     * {@code outputs}; null if {@code i} is out of range.
     */
    private String valueAtIndex(int i) {
        return keyAtIndex(this.valueOutputs != null ? this.valueOutputs.keySet() : this.outputs.keySet(), i);
    }

    /** Returns the {@code i}-th element of {@code keys} in iteration order, or null if out of range. */
    private static String keyAtIndex(Iterable<String> keys, int i) {
        int pos = 0;
        for (String key : keys) {
            if (pos == i) {
                return key;
            }
            pos++;
        }
        return null;
    }

    /** Returns the array held by {@code value}, or null if the Optional is null or empty. */
    private static INDArray unwrap(Optional<INDArray> value) {
        return value == null ? null : value.orElse(null);
    }

    /**
     * Returns the {@link SDValue} at position {@code i} of {@code valueOutputs}.
     *
     * @param i            position in map iteration order
     * @param emptyForList when true, a LIST-typed value is replaced by an empty FLOAT tensor value
     * @return the value, or null if {@code valueOutputs} is unset, the index is out of range, or
     *         the stored value is null
     */
    public SDValue valueWithKeyAtIndex(int i, boolean emptyForList) {
        if (this.valueOutputs == null) {
            return null;
        }
        SDValue value = this.valueOutputs.get(valueAtIndex(i));
        if (value != null && value.getSdValueType() == SDValueType.LIST && emptyForList) {
            return SDValue.create(Nd4j.empty(DataType.FLOAT));
        }
        return value;
    }

    /** Returns the {@link SDValue} stored under {@code str}, or null if absent or no value outputs are set. */
    public SDValue valueWithKey(String str) {
        if (this.valueOutputs == null) {
            return null;
        }
        return this.valueOutputs.get(str);
    }

    /**
     * Returns the array output at position {@code i} of {@code outputs}.
     *
     * @return the array, or null if {@code outputs} is unset, the index is out of range, or the
     *         output is empty
     */
    public INDArray resultAt(int i) {
        if (this.outputs == null) {
            return null;
        }
        // Index the outputs map's own keys; valueAtIndex() would prefer valueOutputs' keys
        // when both maps are set.
        return unwrap(this.outputs.get(keyAtIndex(this.outputs.keySet(), i)));
    }

    /**
     * Unwraps each Optional in {@code map}, preserving iteration order. Empty (or null)
     * Optionals become null entries, mirroring {@link #pack(Map)}.
     */
    public static Map<String, INDArray> unpack(Map<String, Optional<INDArray>> map) {
        Map<String, INDArray> unpacked = new LinkedHashMap<>();
        for (Map.Entry<String, Optional<INDArray>> entry : map.entrySet()) {
            unpacked.put(entry.getKey(), unwrap(entry.getValue()));
        }
        return unpacked;
    }

    /**
     * Wraps each array in {@code map} in an Optional, preserving iteration order. Null arrays
     * become {@link Optional#empty()}.
     */
    public static Map<String, Optional<INDArray>> pack(Map<String, INDArray> map) {
        Map<String, Optional<INDArray>> packed = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            INDArray arr = entry.getValue();
            // Previously a null array threw NPE here before ofNullable could apply.
            // NOTE(review): get() with no indices is kept as-is for non-null arrays; presumably
            // it yields a full view of the array — confirm whether plain `arr` was intended.
            Optional<INDArray> wrapped = arr == null ? Optional.empty() : Optional.ofNullable(arr.get(new INDArrayIndex[0]));
            packed.put(entry.getKey(), wrapped);
        }
        return packed;
    }

    ExecutionResult(Map<String, Optional<INDArray>> map, Map<String, SDValue> map2) {
        this.outputs = map;
        this.valueOutputs = map2;
    }

    /** Returns a new, empty builder. */
    public static ExecutionResultBuilder builder() {
        return new ExecutionResultBuilder();
    }

    public Map<String, Optional<INDArray>> getOutputs() {
        return this.outputs;
    }

    public Map<String, SDValue> getValueOutputs() {
        return this.valueOutputs;
    }

    public void setOutputs(Map<String, Optional<INDArray>> map) {
        this.outputs = map;
    }

    public void setValueOutputs(Map<String, SDValue> map) {
        this.valueOutputs = map;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ExecutionResult)) {
            return false;
        }
        ExecutionResult other = (ExecutionResult) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        Map<String, Optional<INDArray>> myOutputs = getOutputs();
        Map<String, Optional<INDArray>> otherOutputs = other.getOutputs();
        if (myOutputs == null ? otherOutputs != null : !myOutputs.equals(otherOutputs)) {
            return false;
        }
        Map<String, SDValue> myValues = getValueOutputs();
        Map<String, SDValue> otherValues = other.getValueOutputs();
        return myValues == null ? otherValues == null : myValues.equals(otherValues);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ExecutionResult;
    }

    @Override
    public int hashCode() {
        Map<String, Optional<INDArray>> myOutputs = getOutputs();
        int result = 1;
        result = result * 59 + (myOutputs == null ? 43 : myOutputs.hashCode());
        Map<String, SDValue> myValues = getValueOutputs();
        result = result * 59 + (myValues == null ? 43 : myValues.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "ExecutionResult(outputs=" + getOutputs() + ", valueOutputs=" + getValueOutputs() + ")";
    }
}
