package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.autodiff.samediff.internal.IDependeeGroup;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A value held during SameDiff execution. Each factory populates exactly one representation,
 * recorded in {@link #getSdValueType()}: a single tensor (TENSOR), a list of tensors (LIST), or a
 * string-keyed map of tensors (DICT).
 *
 * <p>Every instance receives a unique id from a shared counter when it is constructed.
 * {@link #equals(Object)} and {@link #hashCode()} compare only that id, so two instances wrapping
 * the same arrays are still distinct.
 */
public class SDValue implements IDependeeGroup<INDArray> {
    private SDValueType sdValueType;
    private INDArray tensorValue;
    private Map<String, INDArray> dictValue;
    private List<INDArray> listValue;

    /** Shared id source; getAndIncrement is atomic, so concurrent construction never reuses an id. */
    private static final AtomicLong counter = new AtomicLong(0);

    /** Unique id of this instance, assigned once at construction. */
    protected transient long id = counter.getAndIncrement();

    /** Instances are created only through the static {@code create}/{@code empty} factories. */
    private SDValue() {
    }

    /**
     * @return the unique id assigned to this instance at construction
     */
    @Override // org.nd4j.autodiff.samediff.internal.IDependeeGroup
    public long getId() {
        return this.id;
    }

    /**
     * Returns the arrays held by this value, exactly as reported by {@link #getListValue()}.
     *
     * <p>NOTE(review): for DICT values this returns {@code null}, because {@code getListValue()}
     * does not consult the dictionary. Confirm that IDependeeGroup consumers tolerate that.
     */
    @Override // org.nd4j.autodiff.samediff.internal.IDependeeGroup
    public Collection<INDArray> getCollection() {
        return getListValue();
    }

    /**
     * Creates an empty value of the requested kind.
     *
     * @param sDValueType kind of value to create
     * @param dataType data type for the placeholder tensor; only used when {@code sDValueType} is
     *     TENSOR
     * @return for LIST, a value wrapping an immutable empty list; for DICT, a value wrapping an
     *     immutable empty map; for TENSOR, a value wrapping {@code Nd4j.zeros(1)} cast to
     *     {@code dataType}
     * @throws IllegalArgumentException if {@code sDValueType} is none of LIST, DICT or TENSOR
     * @throws NullPointerException if {@code sDValueType} is null (switch on a null enum)
     */
    public static SDValue empty(SDValueType sDValueType, DataType dataType) {
        switch (sDValueType) {
            case LIST:
                return create(Collections.<INDArray>emptyList());
            case DICT:
                return create(Collections.<String, INDArray>emptyMap());
            case TENSOR:
                return create(Nd4j.zeros(1).castTo(dataType));
            default:
                throw new IllegalArgumentException("Unable to create empty value, unknown value type " + sDValueType);
        }
    }

    /**
     * Returns the tensor view of this value.
     *
     * @return the single element when this value holds a list of exactly one array; otherwise the
     *     stored tensor, which is {@code null} for LIST and DICT values
     */
    public INDArray getTensorValue() {
        return (this.listValue == null || this.listValue.size() != 1) ? this.tensorValue : this.listValue.get(0);
    }

    /**
     * Returns the list view of this value.
     *
     * @return a fixed-size single-element list wrapping the tensor when one is stored; otherwise
     *     the stored list, which is {@code null} for DICT values
     */
    public List<INDArray> getListValue() {
        return this.tensorValue != null ? Arrays.asList(this.tensorValue) : this.listValue;
    }

    /**
     * Wraps a single array as a TENSOR value.
     *
     * @param iNDArray the array to wrap (stored by reference)
     * @return a new TENSOR value
     */
    public static SDValue create(INDArray iNDArray) {
        SDValue sDValue = new SDValue();
        sDValue.tensorValue = iNDArray;
        sDValue.sdValueType = SDValueType.TENSOR;
        return sDValue;
    }

    /**
     * Wraps a collection of arrays as a LIST value.
     *
     * <p>A {@link List} (or {@code null}) is stored by reference, as {@link #create(List)} does.
     * Any other collection type, such as a {@code Set}, is copied into a new {@link ArrayList} in
     * iteration order. Previously a raw cast made this throw {@code ClassCastException}.
     *
     * @param collection the arrays to wrap; may be {@code null}
     * @return a new LIST value
     */
    public static SDValue create(Collection<INDArray> collection) {
        if (collection == null || collection instanceof List) {
            return create((List<INDArray>) collection);
        }
        return create(new ArrayList<INDArray>(collection));
    }

    /**
     * Wraps a list of arrays as a LIST value.
     *
     * @param list the arrays to wrap (stored by reference, not copied)
     * @return a new LIST value
     */
    public static SDValue create(List<INDArray> list) {
        SDValue sDValue = new SDValue();
        sDValue.listValue = list;
        sDValue.sdValueType = SDValueType.LIST;
        return sDValue;
    }

    /**
     * Wraps a name-to-array map as a DICT value.
     *
     * @param map the arrays to wrap (stored by reference, not copied)
     * @return a new DICT value
     */
    public static SDValue create(Map<String, INDArray> map) {
        SDValue sDValue = new SDValue();
        sDValue.dictValue = map;
        sDValue.sdValueType = SDValueType.DICT;
        return sDValue;
    }

    /**
     * Two values are equal only if they carry the same instance id.
     *
     * <p>Returns {@code false} for {@code null} and for non-{@code SDValue} arguments instead of
     * throwing, as required by the {@link Object#equals(Object)} contract.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof SDValue)) {
            return false;
        }
        return ((SDValue) obj).getId() == getId();
    }

    /** Hash of the instance id, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Long.hashCode(getId());
    }

    /**
     * Builds a debug description of the value.
     *
     * <ul>
     *   <li>Tensor: {@code --sdValueId-<id>--key--<type> --Array <arrayId>}.</li>
     *   <li>Non-empty list: the same header, then {@code -- List Size <n>}, then one
     *       {@code --Array} entry per element.</li>
     *   <li>Non-empty dict: the same header, then {@code -- Dict Size <n>}, then one {@code --Key}
     *       and {@code --Array} pair per entry.</li>
     *   <li>Anything else: the empty string.</li>
     * </ul>
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        INDArray tensor = getTensorValue();
        if (tensor != null) {
            appendHeader(sb).append(" --Array ").append(tensor.getId());
            return sb.toString();
        }
        List<INDArray> list = getListValue();
        if (list != null && !list.isEmpty()) {
            appendHeader(sb).append(" -- List Size ").append(list.size());
            for (INDArray arr : list) {
                appendArray(sb, arr);
            }
        } else if (this.dictValue != null && !this.dictValue.isEmpty()) {
            appendHeader(sb).append(" -- Dict Size ").append(this.dictValue.size());
            for (Map.Entry<String, INDArray> entry : this.dictValue.entrySet()) {
                sb.append(" --Key ").append(entry.getKey());
                appendArray(sb, entry.getValue());
            }
        }
        return sb.toString();
    }

    /** Appends the common "--sdValueId-<id>--key--<type>" prefix and returns the builder. */
    private StringBuilder appendHeader(StringBuilder sb) {
        return sb.append("--sdValueId-").append(getId()).append("--key--").append(getSdValueType());
    }

    /** Appends one array entry, tolerating null elements. */
    private static void appendArray(StringBuilder sb, INDArray arr) {
        if (arr == null) {
            sb.append(" --Array NULL ");
        } else {
            sb.append(" --Array ").append(arr.getId()).append(" --\t ");
        }
    }

    /**
     * @return which representation this value holds (TENSOR, LIST or DICT)
     */
    public SDValueType getSdValueType() {
        return this.sdValueType;
    }

    /**
     * @return the stored map for DICT values; {@code null} otherwise
     */
    public Map<String, INDArray> getDictValue() {
        return this.dictValue;
    }
}
