package org.nd4j.autodiff.functions;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/autodiff/functions/DifferentialFunction.class */
public abstract class DifferentialFunction {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(DifferentialFunction.class);

    /**
     * SameDiff instance this function is attached to. May be {@code null}, in which case
     * {@link #args()} returns {@code null} and {@link #setInstanceId()} falls back to a random UUID.
     */
    @JsonIgnore
    protected SameDiff sameDiff;

    /**
     * In-place flag supplied through the constructors / {@link #setInPlace(boolean)}.
     * NOTE(review): presumably marks the op as operating in place - confirm against subclasses.
     */
    @JsonIgnore
    protected boolean inPlace;

    /** Optional scalar array; participates in {@link #equals(Object)} and {@link #hashCode()}. */
    @JsonIgnore
    protected INDArray scalarValue;

    /** Optional dimensions array; participates in {@link #equals(Object)} and {@link #hashCode()}. */
    @JsonIgnore
    protected int[] dimensions;

    /** Extra arguments supplied through the constructors or {@link #setExtraArgs(Object[])}. */
    @JsonIgnore
    protected Object[] extraArgs;

    /** Name identifying this function instance; generated by {@link #setInstanceId()} when still {@code null}. */
    @JsonIgnore
    protected String ownName;

    /** Set to {@code true} by {@link #setInstanceId()} when {@link #ownName} was generated rather than supplied. */
    @JsonIgnore
    protected boolean ownNameSetWithDefault;

    /**
     * Creates a function and assigns it a generated own name.
     * Equivalent to {@code new DifferentialFunction(true)}.
     */
    public DifferentialFunction() {
        this(true);
    }

    /**
     * Creates a function, optionally assigning it a generated own name.
     *
     * @param z when {@code true}, {@link #setInstanceId()} is invoked during construction
     */
    public DifferentialFunction(boolean z) {
        this.ownNameSetWithDefault = false;
        if (!z) {
            return;
        }
        setInstanceId();
    }

    /**
     * Creates a function attached to {@code sameDiff}, assigns its own name and then
     * initialises it through {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}.
     *
     * @param sameDiff SameDiff instance to attach to
     * @param nodeDef  TensorFlow node passed on to the initialiser
     * @param map      TensorFlow attributes passed on to the initialiser
     * @param graphDef TensorFlow graph passed on to the initialiser
     */
    public DifferentialFunction(SameDiff sameDiff, NodeDef nodeDef, Map<String, AttrValue> map, GraphDef graphDef) {
        this.sameDiff = sameDiff;
        this.ownNameSetWithDefault = false;
        // The name must be assigned before the import hook runs.
        setInstanceId();
        initFromTensorFlow(nodeDef, sameDiff, map, graphDef);
    }

    /**
     * Creates a function attached to {@code sameDiff}, assigns its own name and then
     * initialises it through {@link #initFromOnnx(Onnx.NodeProto, SameDiff, Map, Onnx.GraphProto)}.
     *
     * @param sameDiff   SameDiff instance to attach to
     * @param nodeProto  ONNX node passed on to the initialiser
     * @param map        ONNX attributes passed on to the initialiser
     * @param graphProto ONNX graph passed on to the initialiser
     */
    public DifferentialFunction(SameDiff sameDiff, Onnx.NodeProto nodeProto, Map<String, Onnx.AttributeProto> map, Onnx.GraphProto graphProto) {
        this.sameDiff = sameDiff;
        this.ownNameSetWithDefault = false;
        // The name must be assigned before the import hook runs.
        setInstanceId();
        initFromOnnx(nodeProto, sameDiff, map, graphProto);
    }

    /**
     * Attribute adapters for this function. The base implementation has none.
     *
     * @return an immutable empty map
     */
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        final Map<String, Map<String, AttributeAdapter>> none = Collections.emptyMap();
        return none;
    }

    /**
     * Property mappings for this function. The base implementation has none.
     *
     * @return an immutable empty map
     */
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        final Map<String, Map<String, PropertyMapping>> none = Collections.emptyMap();
        return none;
    }

    /**
     * Reads the current value of every field registered for this function in
     * {@link DifferentialFunctionClassHolder}.
     *
     * @return field name to current value, in the iteration order of the registered field map
     * @throws NullPointerException if no field map is registered for this class
     * @throws RuntimeException     if a registered field cannot be read reflectively
     */
    public Map<String, Object> propertiesForFunction() {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        // Fail before doing any work when the op is not registered.
        Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());
        Map<String, Object> properties = new LinkedHashMap<String, Object>();
        for (Map.Entry<String, Field> entry : fields.entrySet()) {
            try {
                properties.put(entry.getKey(), entry.getValue().get(this));
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Unable to get property for field: " + entry.getKey(), e);
            }
        }
        return properties;
    }

    /**
     * Hook taking a SameDiff instance. The base implementation does nothing.
     *
     * @param sameDiff SameDiff instance offered to the function (unused here)
     */
    public void configureWithSameDiff(SameDiff sameDiff) {
        // Intentionally empty.
    }

    /**
     * Applies the given property values to the matching registered fields of this function.
     * Property names without a registered field are logged at WARN level and skipped.
     *
     * @param map property name to value
     * @throws NullPointerException if no field map is registered for this class
     * @throws RuntimeException     if a value cannot be assigned (see {@link #setValueFor(Field, Object)})
     */
    public void setPropertiesForFunction(Map<String, Object> map) {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        // Same guard as propertiesForFunction(): report a missing registration clearly instead of a bare NPE.
        Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());
        for (Map.Entry<String, Object> property : map.entrySet()) {
            String name = property.getKey();
            Field field = fields.get(name);
            if (field == null) {
                log.warn("No fields found for property name {} for class {}", name, getClass().getName());
                continue;
            }
            setValueFor(field, property.getValue());
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up a Boolean property.
     *
     * @param str property name
     * @param map property map
     * @return the mapped value cast to {@link Boolean}, or {@code null} when the key is absent
     */
    public Boolean getBooleanFromProperty(String str, Map<String, Object> map) {
        return map.containsKey(str) ? (Boolean) map.get(str) : null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up a String property.
     *
     * @param str property name
     * @param map property map
     * @return the mapped value cast to {@link String}, or {@code null} when the key is absent
     */
    public String getStringFromProperty(String str, Map<String, Object> map) {
        return map.containsKey(str) ? (String) map.get(str) : null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up a numeric property and narrows it to an {@link Integer}.
     *
     * @param str property name
     * @param map property map
     * @return the mapped number's {@code intValue()}, or {@code null} when the key is absent
     *         or mapped to {@code null}
     * @throws ClassCastException if the mapped value is not a {@link Number}
     */
    public Integer getIntValueFromProperty(String str, Map<String, Object> map) {
        Object value = map.get(str);
        if (value == null) {
            // Absent key or explicit null value: both mean "no value", matching the Boolean/String getters.
            return null;
        }
        return Integer.valueOf(((Number) value).intValue());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up a numeric property and converts it to a {@link Long}.
     *
     * @param str property name
     * @param map property map
     * @return the mapped number's {@code longValue()}, or {@code null} when the key is absent
     *         or mapped to {@code null}
     * @throws ClassCastException if the mapped value is not a {@link Number}
     */
    public Long getLongValueFromProperty(String str, Map<String, Object> map) {
        Object value = map.get(str);
        if (value == null) {
            // Absent key or explicit null value: both mean "no value", matching the Boolean/String getters.
            return null;
        }
        return Long.valueOf(((Number) value).longValue());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up a numeric property and converts it to a {@link Double}.
     *
     * @param str property name
     * @param map property map
     * @return the mapped number's {@code doubleValue()}, or {@code null} when the key is absent
     *         or mapped to {@code null}
     * @throws ClassCastException if the mapped value is not a {@link Number}
     */
    public Double getDoubleValueFromProperty(String str, Map<String, Object> map) {
        Object value = map.get(str);
        if (value == null) {
            // Absent key or explicit null value: both mean "no value", matching the Boolean/String getters.
            return null;
        }
        return Double.valueOf(((Number) value).doubleValue());
    }

    /**
     * Reads the value of {@code field} on this instance.
     *
     * @param field field to read
     * @return the field's current value, or {@code null} if the field is not accessible
     *         (the failure is logged at ERROR level)
     */
    public Object getValue(Field field) {
        try {
            return field.get(this);
        } catch (IllegalAccessException e) {
            // Best-effort accessor: keep returning null, but say which field failed.
            log.error("Unable to read field {} on {}", field.getName(), getClass().getName(), e);
            return null;
        }
    }

    /**
     * Assigns {@code obj} to {@code field}, converting the value to the field's type where a
     * conversion is known (see {@link #coerceValueForField(Class, Object)}).
     * <p>
     * When {@link #isConfigProperties()} is {@code true} the field is set on the configuration
     * object held in the field named by {@link #configFieldName()} (default {@code "config"}),
     * which is created on demand; otherwise the field is set on this instance. The field is set
     * exactly once, on that single target.
     *
     * @param field field to assign
     * @param obj   new value; may be {@code null} only for non-primitive fields
     * @throws ND4JIllegalStateException if {@code obj} is {@code null} and the field is primitive
     * @throws IllegalStateException     if the configuration field cannot be found, or a String
     *                                   does not name a constant of an enum-typed field
     * @throws RuntimeException          if conversion or reflective assignment fails
     */
    public void setValueFor(Field field, Object obj) {
        Class<?> fieldType = field.getType();
        if (obj == null && fieldType.isPrimitive()) {
            throw new ND4JIllegalStateException("Unable to set primitive field " + field + " of type " + fieldType + " using null value!");
        }
        if (obj != null) {
            obj = ensureProperType(field, obj);
        }
        // Pick the single object the field lives on: the config holder or this function.
        Object target = isConfigProperties() ? resolveConfigObject() : this;
        try {
            field.set(target, coerceValueForField(fieldType, obj));
        } catch (Exception e) {
            throw new RuntimeException("Error setting property for function " + getClass().getName(), e);
        }
    }

    /**
     * Finds the configuration object of this function, creating and storing a new instance
     * through the no-arg constructor when the configuration field is still {@code null}.
     */
    private Object resolveConfigObject() {
        String configFieldName = configFieldName();
        if (configFieldName == null) {
            configFieldName = "config";
        }
        Field configField = null;
        // Walk from the concrete class up through its superclasses until the field is found.
        for (Class<?> cls = getClass(); configField == null && cls != null; cls = cls.getSuperclass()) {
            try {
                configField = cls.getDeclaredField(configFieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep looking in the superclass.
            }
        }
        if (configField == null) {
            throw new IllegalStateException("Could not find field \"" + configFieldName + "\" for class " + getClass().getName());
        }
        try {
            configField.setAccessible(true);
            Object config = configField.get(this);
            if (config == null) {
                Class<?> type = configField.getType();
                try {
                    config = type.newInstance();
                } catch (InstantiationException e) {
                    throw new RuntimeException("Error creating new instance of configuration object type " + type.getName(), e);
                }
                configField.set(this, config);
            }
            return config;
        } catch (IllegalAccessException e) {
            throw new RuntimeException("Error accessing configuration field \"" + configFieldName + "\" on class " + getClass().getName(), e);
        }
    }

    /**
     * Converts {@code value} to something assignable to a field of type {@code fieldType} for
     * the numeric, char, boolean and enum cases handled here; other values are returned unchanged.
     */
    private static Object coerceValueForField(Class<?> fieldType, Object value) {
        if (fieldType == Float.TYPE && value instanceof Double) {
            value = Float.valueOf(((Double) value).floatValue());
        }
        if (fieldType == Character.TYPE && value instanceof Integer) {
            value = Character.valueOf((char) ((Integer) value).intValue());
        }
        if (fieldType == Character.TYPE && value instanceof Long) {
            value = Character.valueOf((char) ((Long) value).intValue());
        }
        if (fieldType == Integer.TYPE && value instanceof Long) {
            value = Integer.valueOf(((Long) value).intValue());
        }
        if (fieldType.equals(Integer.class) && value instanceof Long) {
            value = Integer.valueOf(((Long) value).intValue());
        }
        if (fieldType.equals(Long.class) && value instanceof Integer) {
            value = Long.valueOf(((Integer) value).longValue());
        }
        if (fieldType.equals(Double.class) && value instanceof Long) {
            value = Double.valueOf(((Long) value).doubleValue());
        }
        // Numbers map to booleans only for boolean fields; Boolean (or null) values pass through untouched.
        boolean booleanField = fieldType.equals(Boolean.class) || fieldType.equals(Boolean.TYPE);
        if (booleanField && value instanceof Number) {
            value = Boolean.valueOf(((Number) value).doubleValue() > 0.0d);
        }
        if (fieldType.equals(DataType.class) && value instanceof Double) {
            value = DataType.values()[((Double) value).intValue()];
        }
        // Integral values index into the constants of enum-typed fields only.
        if (fieldType.isEnum() && (value instanceof Long || value instanceof Integer)) {
            value = fieldType.getEnumConstants()[((Number) value).intValue()];
        }
        return value;
    }

    /**
     * Adapts {@code obj} to the declared type of {@code field} for two cases: a String naming
     * an enum constant (case-insensitive match on {@code toString()}), and a single value
     * destined for a one-element int/long/double/float array (primitive or boxed).
     * Any other value is returned unchanged.
     *
     * @param field target field
     * @param obj   non-null candidate value
     * @return the adapted value, or {@code obj} itself when no adaptation applies
     * @throws IllegalStateException if a String matches no constant of an enum-typed field
     * @throws ClassCastException    if an array-typed field receives a value that is neither a
     *                               {@link Number} nor the array's element type
     */
    private Object ensureProperType(Field field, Object obj) {
        final Class<?> expected = field.getType();
        final Class<?> actual = obj.getClass();
        if (expected.equals(actual)) {
            return obj;
        }

        if (expected.isEnum()) {
            if (!actual.equals(String.class)) {
                return obj;
            }
            for (Object constant : expected.getEnumConstants()) {
                if (constant.toString().equalsIgnoreCase((String) obj)) {
                    return constant;
                }
            }
            throw new IllegalStateException("Could not find enum constant value for value \"" + obj + "\" for enum class " + expected.getName());
        }

        // Numbers are converted to the element type; anything else must already be that type.
        final boolean numeric = obj instanceof Number;
        if (expected.equals(int[].class)) {
            int element = numeric ? ((Number) obj).intValue() : ((Integer) obj).intValue();
            return new int[]{element};
        }
        if (expected.equals(Integer[].class)) {
            Integer element = numeric ? Integer.valueOf(((Number) obj).intValue()) : (Integer) obj;
            return new Integer[]{element};
        }
        if (expected.equals(long[].class)) {
            long element = numeric ? ((Number) obj).longValue() : ((Long) obj).longValue();
            return new long[]{element};
        }
        if (expected.equals(Long[].class)) {
            Long element = numeric ? Long.valueOf(((Number) obj).longValue()) : (Long) obj;
            return new Long[]{element};
        }
        if (expected.equals(double[].class)) {
            double element = numeric ? ((Number) obj).doubleValue() : ((Double) obj).doubleValue();
            return new double[]{element};
        }
        if (expected.equals(Double[].class)) {
            Double element = numeric ? Double.valueOf(((Number) obj).doubleValue()) : (Double) obj;
            return new Double[]{element};
        }
        if (expected.equals(float[].class)) {
            float element = numeric ? ((Number) obj).floatValue() : ((Float) obj).floatValue();
            return new float[]{element};
        }
        if (expected.equals(Float[].class)) {
            Float element = numeric ? Float.valueOf(((Number) obj).floatValue()) : (Float) obj;
            return new Float[]{element};
        }
        return obj;
    }

    /**
     * Whether properties are stored on a separate configuration object rather than on this
     * instance (consulted by {@link #setValueFor(Field, Object)}).
     *
     * @return {@code false} in the base implementation
     */
    public boolean isConfigProperties() {
        final boolean usesConfigObject = false;
        return usesConfigObject;
    }

    /**
     * Name of the field holding the configuration object.
     *
     * @return {@code null} in the base implementation; {@link #setValueFor(Field, Object)}
     *         then falls back to {@code "config"}
     */
    public String configFieldName() {
        final String noExplicitName = null;
        return noExplicitName;
    }

    /**
     * Creates a function attached to {@code sameDiff} with the given in-place flag and extra
     * arguments, assigning its own name.
     *
     * @param sameDiff SameDiff instance to attach to
     * @param z        in-place flag
     * @param objArr   extra arguments, stored as-is
     */
    public DifferentialFunction(SameDiff sameDiff, boolean z, Object[] objArr) {
        this.ownNameSetWithDefault = false;
        this.inPlace = z;
        this.sameDiff = sameDiff;
        // Name assignment happens before the extra arguments are stored.
        setInstanceId();
        this.extraArgs = objArr;
    }

    /**
     * Creates a function attached to {@code sameDiff} with the given extra arguments and the
     * in-place flag left {@code false}.
     *
     * @param sameDiff SameDiff instance to attach to
     * @param objArr   extra arguments, stored as-is
     */
    public DifferentialFunction(SameDiff sameDiff, Object[] objArr) {
        this(sameDiff, false, objArr);
    }

    /**
     * Creates a function with the given input variables and the in-place flag set to {@code false}.
     *
     * @param sameDiff      SameDiff instance to attach to
     * @param sDVariableArr input variables
     */
    public DifferentialFunction(SameDiff sameDiff, SDVariable[] sDVariableArr) {
        this(sameDiff, false, sDVariableArr);
    }

    /**
     * Creates a function attached to {@code sameDiff}, assigns its own name and, when both
     * {@code sameDiff} and {@code sDVariableArr} are non-null, registers the variables as this
     * function's arguments.
     *
     * @param sameDiff      SameDiff instance to attach to; may be {@code null}
     * @param z             in-place flag
     * @param sDVariableArr input variables; may be {@code null}
     */
    public DifferentialFunction(SameDiff sameDiff, boolean z, SDVariable[] sDVariableArr) {
        this.ownNameSetWithDefault = false;
        this.inPlace = z;
        this.sameDiff = sameDiff;
        setInstanceId();
        if (sameDiff != null && sDVariableArr != null) {
            sameDiff.addArgsFor(sDVariableArr, this);
        }
    }

    /**
     * Replaces the argument at index {@code i} by delegating to the attached SameDiff instance.
     * Does nothing when no SameDiff instance is attached.
     *
     * @param i          argument index
     * @param sDVariable replacement variable
     */
    public void replaceArg(int i, SDVariable sDVariable) {
        if (sameDiff == null) {
            return;
        }
        sameDiff.replaceArgFor(i, sDVariable, this);
    }

    /**
     * Output variables of this function, using the own name as base name and falling back to
     * {@link #opName()} when no own name is set.
     *
     * @return the result of {@link #outputVariables(String)} for that base name
     */
    public SDVariable[] outputVariables() {
        String baseName = getOwnName();
        if (baseName == null) {
            baseName = opName();
        }
        return outputVariables(baseName);
    }

    /**
     * First output variable of this function.
     *
     * @return element 0 of {@link #outputVariables()}
     */
    public SDVariable outputVariable() {
        final SDVariable[] all = outputVariables();
        return all[0];
    }

    /**
     * Output variables as a list.
     *
     * @return a fixed-size list view over {@link #outputVariables()}, or {@code null} when
     *         that method returns {@code null}
     */
    public List<SDVariable> outputs() {
        final SDVariable[] all = outputVariables();
        return all != null ? Arrays.asList(all) : null;
    }

    /**
     * Names of the output variables, in order.
     *
     * @return one entry per output variable; a {@code null} variable yields the empty string
     */
    public String[] outputVariablesNames() {
        final SDVariable[] all = outputVariables();
        final String[] names = new String[all.length];
        int next = 0;
        for (SDVariable variable : all) {
            names[next++] = (variable != null) ? variable.name() : "";
        }
        return names;
    }

    /**
     * Output variables of this function for the given base name.
     * {@link #outputVariables()} passes the own name, or {@link #opName()} when no own name is set.
     *
     * @param str base name
     * @return the output variables
     */
    public abstract SDVariable[] outputVariables(String str);

    /**
     * Subclass hook invoked by {@link #diff(List)}. Must not return {@code null}; element
     * {@code i} of the result is paired with argument {@code i} of this function there.
     * NOTE(review): {@code list} is presumably the gradients w.r.t. the outputs - confirm.
     *
     * @param list input passed through unchanged from {@link #diff(List)}
     * @return one variable per argument
     */
    public abstract List<SDVariable> doDiff(List<SDVariable> list);

    /**
     * Input variables of this function as recorded by the attached SameDiff instance.
     *
     * @return the input variables, or {@code null} when no SameDiff instance is attached
     */
    public SDVariable[] args() {
        return sameDiff == null ? null : sameDiff.getInputVariablesForOp(this);
    }

    /**
     * Input variable at index {@code i}.
     *
     * @param i zero-based argument index
     * @return the argument at that index
     * @throws NullPointerException     if {@link #args()} returns {@code null}
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, numArgs)}
     */
    public SDVariable arg(int i) {
        final SDVariable[] inputs = args();
        Preconditions.checkNotNull(inputs, "Arguments are null for function %s", getOwnName());
        final boolean inRange = !(i < 0 || i >= inputs.length);
        Preconditions.checkArgument(inRange, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", inputs.length, i);
        return inputs[i];
    }

    /**
     * Names of the input variables, in order.
     *
     * @return one name per element of {@link #args()}
     */
    public String[] argNames() {
        final SDVariable[] inputs = args();
        final String[] names = new String[inputs.length];
        int next = 0;
        for (SDVariable input : inputs) {
            names[next++] = input.name();
        }
        return names;
    }

    /**
     * First input variable, if any.
     *
     * @return element 0 of {@link #args()}, or {@code null} when there are no arguments
     */
    public SDVariable arg() {
        // Look the arguments up once so the check and the access see the same array.
        SDVariable[] inputs = args();
        if (inputs == null || inputs.length == 0) {
            return null;
        }
        return inputs[0];
    }

    /**
     * Runs {@link #doDiff(List)} and records the resulting gradients on the attached SameDiff
     * instance, pairing result {@code i} with argument {@code i}.
     * <ul>
     *   <li>If the argument already has a gradient, the new one is added to it, the sum
     *       replaces entry {@code i} of the (copied) result list and is registered as the
     *       argument's gradient.</li>
     *   <li>Otherwise, if a variable named {@code <arg>-grad} exists, {@code add} is called on
     *       it with the new gradient; else the new gradient is renamed to {@code <arg>-grad}
     *       and registered as the argument's gradient.</li>
     * </ul>
     *
     * @param list input forwarded to {@link #doDiff(List)}
     * @return the gradients; a copy of the {@code doDiff} result when any entry was replaced
     * @throws IllegalStateException if {@code doDiff} returns {@code null}
     */
    public List<SDVariable> diff(List<SDVariable> list) {
        List<SDVariable> grads = doDiff(list);
        if (grads == null) {
            throw new IllegalStateException("Error executing diff operation: doDiff returned null for op: " + opName());
        }
        final SDVariable[] inputs = args();
        boolean copied = false;
        for (int i = 0; i < grads.size(); i++) {
            final SDVariable input = inputs[i];
            SDVariable existing = null;
            if (input.hasGradient()) {
                existing = input.getGradient();
            }

            if (existing == null) {
                final SDVariable computed = grads.get(i);
                final String gradName = input.name() + "-grad";
                if (!this.sameDiff.hasVariable(gradName)) {
                    this.sameDiff.updateVariableNameAndReference(computed, gradName);
                    this.sameDiff.setGradientForVariableName(input.name(), computed);
                } else {
                    // NOTE(review): the result of add(...) is discarded here - confirm this is intended.
                    this.sameDiff.getVariable(gradName).add(computed);
                }
                continue;
            }

            if (!copied) {
                // Work on a copy so the list returned by doDiff is never modified.
                grads = new ArrayList<SDVariable>(grads);
                copied = true;
            }
            final SDVariable summed = input.getSameDiff().math.add(existing, grads.get(i));
            grads.set(i, summed);
            this.sameDiff.setGradientForVariableName(input.name(), summed);
        }
        return grads;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Assigns a generated own name when none is set yet and marks it as generated.
     * With a SameDiff instance attached, the name comes from {@code getOpName(opName())} and
     * this function is registered under it; without one, a random UUID string is used.
     * Does nothing when an own name is already present.
     */
    public void setInstanceId() {
        if (ownName != null) {
            return;
        }
        ownNameSetWithDefault = true;
        if (sameDiff != null) {
            ownName = sameDiff.getOpName(opName());
            sameDiff.putOpForId(ownName, this);
        } else {
            ownName = UUID.randomUUID().toString();
        }
    }

    /**
     * Name of the operation. Not supported by the base class.
     *
     * @throws UnsupportedOperationException always, unless overridden
     */
    public String opName() {
        throw new UnsupportedOperationException();
    }

    /**
     * Type of the operation. Not supported by the base class.
     *
     * @throws UnsupportedOperationException always, unless overridden
     */
    public Op.Type opType() {
        throw new UnsupportedOperationException();
    }

    /**
     * Number of the operation. Not supported by the base class.
     *
     * @throws UnsupportedOperationException always, unless overridden
     */
    public int opNum() {
        throw new UnsupportedOperationException();
    }

    /**
     * Input array at index {@code i}. Not supported by the base class.
     *
     * @param i input index
     * @throws UnsupportedOperationException always, unless overridden
     */
    @JsonIgnore
    public INDArray getInputArgument(int i) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Initialises this function from a TensorFlow node. Invoked by the TensorFlow import
     * constructor after the own name has been assigned.
     *
     * @param nodeDef  TensorFlow node
     * @param sameDiff SameDiff instance the function is attached to
     * @param map      TensorFlow attributes (NOTE(review): presumably those of {@code nodeDef} - confirm)
     * @param graphDef TensorFlow graph
     */
    public abstract void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef);

    /**
     * Initialises this function from an ONNX node. Invoked by the ONNX import constructor
     * after the own name has been assigned.
     *
     * @param nodeProto  ONNX node
     * @param sameDiff   SameDiff instance the function is attached to
     * @param map        ONNX attributes (NOTE(review): presumably those of {@code nodeProto} - confirm)
     * @param graphProto ONNX graph
     */
    public abstract void initFromOnnx(Onnx.NodeProto nodeProto, SameDiff sameDiff, Map<String, Onnx.AttributeProto> map, Onnx.GraphProto graphProto);

    /**
     * First ("left") input variable.
     *
     * @return element 0 of {@link #args()}
     * @throws ND4JIllegalStateException if there are no arguments
     */
    public SDVariable larg() {
        SDVariable[] args = args();
        if (args == null || args.length == 0) {
            throw new ND4JIllegalStateException("No arguments found.");
        }
        // Index the array that was just validated instead of looking the arguments up again.
        return args[0];
    }

    /**
     * Second ("right") input variable of a two-argument function.
     *
     * @return element 1 of {@link #args()}
     * @throws ND4JIllegalStateException unless the function has exactly two arguments
     */
    public SDVariable rarg() {
        final SDVariable[] inputs = args();
        final boolean exactlyTwo = inputs != null && inputs.length == 2;
        if (!exactlyTwo) {
            throw new ND4JIllegalStateException("In order to use this function, the number of arguments for this function must be 2.");
        }
        return inputs[1];
    }

    /**
     * Duplicates this function by delegating to
     * {@code FlatBuffersMapper.cloneViaSerialize} with the attached SameDiff instance.
     *
     * @return the clone produced by the mapper
     */
    public DifferentialFunction dup() {
        final DifferentialFunction copy = FlatBuffersMapper.cloneViaSerialize(sameDiff, this);
        return copy;
    }

    /**
     * Output shapes of this function. Not supported by the base class.
     *
     * @return never returns normally in the base implementation
     * @throws ND4JIllegalStateException always, unless overridden
     */
    public List<LongShapeDescriptor> calculateOutputShape() {
        // Separator space added before "did not override", matching the OpContext overload's message.
        throw new ND4JIllegalStateException("Op type of " + getClass().getName() + " did not override calculateOutputShape() method leaked out for [" + opName() + "]");
    }

    /**
     * Output shapes of this function for the given op context. Not supported by the base class.
     *
     * @param opContext op context (unused here)
     * @return never returns normally in the base implementation
     * @throws ND4JIllegalStateException always, unless overridden
     */
    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
        final String className = getClass().getName();
        final String message = "Op type of " + className + " did not override calculateOutputShape(OpContext) method leaked out for [" + opName() + "]";
        throw new ND4JIllegalStateException(message);
    }

    /**
     * Output data types of this function for the given input data types. Not supported by the
     * base class.
     *
     * @param list input data types (unused here)
     * @return never returns normally in the base implementation
     * @throws UnsupportedOperationException always, unless overridden
     */
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        final String className = getClass().getName();
        final String message = "Op type of " + className + " and name " + toString() + " did not override  calculateOutputDataTypes()! This function has not been implemented for " + className;
        throw new UnsupportedOperationException(message);
    }

    /**
     * Two functions are equal when they are of the same runtime class and agree on
     * {@code inPlace}, {@code scalarValue}, {@code dimensions} and {@code ownName}.
     *
     * @param obj object to compare with
     * @return {@code true} if equal under the rule above
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        final DifferentialFunction other = (DifferentialFunction) obj;
        if (inPlace != other.inPlace) {
            return false;
        }
        final boolean scalarsMatch = (scalarValue == null)
                ? other.scalarValue == null
                : scalarValue.equals(other.scalarValue);
        if (!scalarsMatch) {
            return false;
        }
        if (!Arrays.equals(dimensions, other.dimensions)) {
            return false;
        }
        if (ownName == null) {
            return other.ownName == null;
        }
        return ownName.equals(other.ownName);
    }

    /**
     * Hash code over the same fields as {@link #equals(Object)}: {@code inPlace},
     * {@code scalarValue}, {@code dimensions} and {@code ownName}, combined with multiplier 31
     * starting from seed 31.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        int hash = 31;
        hash = 31 * hash + (inPlace ? 1 : 0);
        hash = 31 * hash + (scalarValue == null ? 0 : scalarValue.hashCode());
        hash = 31 * hash + Arrays.hashCode(dimensions);
        hash = 31 * hash + (ownName == null ? 0 : ownName.hashCode());
        return hash;
    }

    /**
     * ONNX names of this function. The base implementation wraps {@link #onnxName()}.
     *
     * @return a new single-element array
     */
    public String[] onnxNames() {
        final String[] names = new String[1];
        names[0] = onnxName();
        return names;
    }

    /**
     * TensorFlow names of this function. The base implementation wraps {@link #tensorflowName()}.
     *
     * @return a new single-element array
     */
    public String[] tensorflowNames() {
        final String[] names = new String[1];
        names[0] = tensorflowName();
        return names;
    }

    /**
     * ONNX name of this function; {@link #onnxNames()} wraps it in a single-element array.
     *
     * @return the ONNX name
     */
    public abstract String onnxName();

    /**
     * TensorFlow name of this function; {@link #tensorflowNames()} wraps it in a single-element array.
     *
     * @return the TensorFlow name
     */
    public abstract String tensorflowName();

    /**
     * Number of outputs of this function.
     * NOTE(review): -1 presumably means "not known / not fixed" - confirm against callers.
     *
     * @return {@code -1} in the base implementation
     */
    public int getNumOutputs() {
        final int baseValue = -1;
        return baseValue;
    }

    /**
     * Subclass hook with no arguments and no return value.
     * NOTE(review): presumably releases array references held by the op - confirm against implementations.
     */
    public abstract void clearArrays();

    /**
     * @return the extra arguments array itself (not a copy); may be {@code null}
     */
    public Object[] getExtraArgs() {
        return extraArgs;
    }

    /**
     * @param objArr new extra arguments array, stored as-is
     */
    public void setExtraArgs(Object[] objArr) {
        extraArgs = objArr;
    }

    /**
     * Debug representation listing all state through the corresponding getters.
     *
     * @return {@code DifferentialFunction(sameDiff=..., inPlace=..., ...)}
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("DifferentialFunction(sameDiff=");
        text.append(getSameDiff());
        text.append(", inPlace=").append(isInPlace());
        text.append(", scalarValue=").append(getScalarValue());
        text.append(", dimensions=").append(Arrays.toString(getDimensions()));
        text.append(", extraArgs=").append(Arrays.deepToString(getExtraArgs()));
        text.append(", ownName=").append(getOwnName());
        text.append(", ownNameSetWithDefault=").append(isOwnNameSetWithDefault());
        text.append(")");
        return text.toString();
    }

    /** @return the attached SameDiff instance; may be {@code null} */
    public SameDiff getSameDiff() {
        return sameDiff;
    }

    /** @param sameDiff SameDiff instance to attach; only the reference is stored */
    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /** @return the in-place flag */
    public boolean isInPlace() {
        return inPlace;
    }

    /** @param z new in-place flag */
    public void setInPlace(boolean z) {
        inPlace = z;
    }

    /** @return the scalar value array; may be {@code null} */
    public INDArray getScalarValue() {
        return scalarValue;
    }

    /** @param iNDArray new scalar value array, stored as-is */
    public void setScalarValue(INDArray iNDArray) {
        scalarValue = iNDArray;
    }

    /** @return the dimensions array itself (not a copy); may be {@code null} */
    public int[] getDimensions() {
        return dimensions;
    }

    /** @param iArr new dimensions array, stored as-is */
    public void setDimensions(int[] iArr) {
        dimensions = iArr;
    }

    /** @return the own name; may be {@code null} */
    public String getOwnName() {
        return ownName;
    }

    /** @param str new own name; the generated-name flag is left untouched */
    public void setOwnName(String str) {
        ownName = str;
    }

    /** @return {@code true} when the own name was generated by {@link #setInstanceId()} */
    public boolean isOwnNameSetWithDefault() {
        return ownNameSetWithDefault;
    }

    /** @param z new value of the generated-name flag */
    public void setOwnNameSetWithDefault(boolean z) {
        ownNameSetWithDefault = z;
    }
}
