package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/shape/CreateView.class */
/**
 * Custom op {@code create_view}: builds a view of an input array from a list of encoded
 * index descriptors.
 *
 * <p>Each descriptor is a small 1-D integer array whose element 0 is one of the type tags
 * {@link #POINT_TYPE}, {@link #INTERVAL_TYPE}, {@link #ALL_TYPE} or {@link #NEW_AXIS}.
 * The static {@code create*} factories build descriptors, and {@link #fromIndexArr(INDArray)}
 * decodes one back into an {@link INDArrayIndex}. The layouts they write/read are:
 * <ul>
 *   <li>point: {@code [POINT_TYPE, 1, 1, index, DEFAULT_INCLUSIVE]}</li>
 *   <li>interval: {@code [INTERVAL_TYPE, 2, 1, begin, end, stride, inclusive]}</li>
 *   <li>all: {@code [ALL_TYPE, 0, 1, DEFAULT_INCLUSIVE]}</li>
 *   <li>new axis: {@code [NEW_AXIS, 1, 10, DEFAULT_INCLUSIVE]}</li>
 * </ul>
 * NOTE(review): the meaning of the header elements at positions 1-2 is consumed by the
 * native op and is not visible from this class.
 */
public class CreateView extends DynamicCustomOp {
    /** Descriptor type tag for a single-point index. */
    public static final int POINT_TYPE = 0;
    /** Descriptor type tag for a begin/end/stride interval index. */
    public static final int INTERVAL_TYPE = 1;
    /** Descriptor type tag selecting an entire dimension. */
    public static final int ALL_TYPE = 2;
    /** Descriptor type tag inserting a new axis of size one. */
    public static final int NEW_AXIS = 3;
    /** Inclusive flag written by the point, all and new-axis factories. */
    public static final int DEFAULT_INCLUSIVE = 1;

    /** Creates an unconfigured op (no inputs). */
    public CreateView() {
    }

    /**
     * Creates the op for eager execution.
     *
     * @param inputs the source array followed by its index descriptors
     */
    public CreateView(INDArray[] inputs) {
        // The cast disambiguates the (inputs, outputs) super constructor.
        super(inputs, (INDArray[]) null);
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff the owning graph
     * @param args     the source variable followed by its index descriptor variables
     */
    public CreateView(SameDiff sameDiff, SDVariable[] args) {
        super(sameDiff, args);
    }

    /**
     * Creates the op inside a SameDiff graph from a source variable and its descriptors.
     *
     * @param sameDiff the owning graph
     * @param input    the variable to view
     * @param indices  index descriptor variables; {@code null} is treated as none
     */
    public CreateView(SameDiff sameDiff, SDVariable input, SDVariable[] indices) {
        this(sameDiff, prepend(input, indices));
    }

    /**
     * Creates the op for eager execution from a source array and its descriptors.
     *
     * @param input   the array to view
     * @param indices index descriptor arrays; {@code null} is treated as none
     */
    public CreateView(INDArray input, INDArray[] indices) {
        this(prepend(input, indices));
    }

    /** Returns a new array holding {@code head} followed by every element of {@code tail}. */
    private static SDVariable[] prepend(SDVariable head, SDVariable[] tail) {
        int tailLength = tail == null ? 0 : tail.length;
        SDVariable[] combined = new SDVariable[tailLength + 1];
        combined[0] = head;
        if (tailLength > 0) {
            System.arraycopy(tail, 0, combined, 1, tailLength);
        }
        return combined;
    }

    /** Returns a new array holding {@code head} followed by every element of {@code tail}. */
    private static INDArray[] prepend(INDArray head, INDArray[] tail) {
        int tailLength = tail == null ? 0 : tail.length;
        INDArray[] combined = new INDArray[tailLength + 1];
        combined[0] = head;
        if (tailLength > 0) {
            System.arraycopy(tail, 0, combined, 1, tailLength);
        }
        return combined;
    }

    /**
     * The single output has the same data type as the first input.
     *
     * @throws IllegalArgumentException if no input data types are supplied
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        if (dataTypes == null || dataTypes.isEmpty()) {
            throw new IllegalArgumentException("create_view expects at least one input data type");
        }
        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public int getNumOutputs() {
        return 1;
    }

    @Override
    public String opName() {
        return "create_view";
    }

    /** Wraps an int constant as a length-1 INT64 vector in the graph. */
    private static SDVariable int64Vector(SameDiff sameDiff, int value) {
        return sameDiff.constant(value).reshape(1).castTo(DataType.INT64);
    }

    /** Reshapes {@code variable} to length 1 and casts it to INT64. */
    private static SDVariable asInt64Vector(SDVariable variable) {
        return variable.reshape(1).castTo(DataType.INT64);
    }

    /** Creates an unnamed point descriptor for a fixed index. */
    public static SDVariable createPoint(SameDiff sameDiff, long index) {
        return createPoint(sameDiff, (String) null, index);
    }

    /** Creates an unnamed point descriptor whose index is computed in the graph. */
    public static SDVariable createPoint(SameDiff sameDiff, SDVariable index) {
        return createPoint(sameDiff, (String) null, index);
    }

    /**
     * Creates a point descriptor for a fixed index.
     *
     * @param sameDiff the owning graph
     * @param name     variable name, or {@code null} for a generated one
     * @param index    the point index
     * @return a variable holding {@code [POINT_TYPE, 1, 1, index, DEFAULT_INCLUSIVE]}
     */
    public static SDVariable createPoint(SameDiff sameDiff, String name, long index) {
        return sameDiff.var(name, Nd4j.createFromArray(POINT_TYPE, 1, 1, index, DEFAULT_INCLUSIVE));
    }

    /**
     * Creates a point descriptor whose index is computed in the graph.
     *
     * @param sameDiff the owning graph
     * @param name     variable name, or {@code null} for a generated one
     * @param index    scalar (or single-element) variable holding the point index
     * @return an INT64 variable laid out as {@code [POINT_TYPE, 1, 1, index, DEFAULT_INCLUSIVE]}
     */
    public static SDVariable createPoint(SameDiff sameDiff, String name, SDVariable index) {
        return sameDiff.concat(name, 0,
                int64Vector(sameDiff, POINT_TYPE),
                int64Vector(sameDiff, 1),
                int64Vector(sameDiff, 1),
                asInt64Vector(index),
                int64Vector(sameDiff, DEFAULT_INCLUSIVE));
    }

    /** Creates an unnamed descriptor selecting an entire dimension. */
    public static SDVariable createAll(SameDiff sameDiff) {
        return createAll(sameDiff, null);
    }

    /**
     * Creates a descriptor selecting an entire dimension.
     *
     * <p>NOTE(review): this uses int varargs, unlike the INT64 point/interval factories.
     * Confirm that the native op accepts both data types.
     */
    public static SDVariable createAll(SameDiff sameDiff, String name) {
        return sameDiff.var(name, Nd4j.createFromArray(ALL_TYPE, 0, 1, DEFAULT_INCLUSIVE));
    }

    /**
     * Creates a descriptor that inserts a new axis.
     *
     * <p>NOTE(review): this uses int varargs, like {@link #createAll(SameDiff, String)}.
     */
    public static SDVariable createNewAxis(SameDiff sameDiff, String name) {
        return sameDiff.var(name, Nd4j.createFromArray(NEW_AXIS, 1, 10, DEFAULT_INCLUSIVE));
    }

    /** Creates an unnamed descriptor that inserts a new axis. */
    public static SDVariable createNewAxis(SameDiff sameDiff) {
        return createNewAxis(sameDiff, null);
    }

    /** Creates an unnamed interval descriptor from fixed bounds. */
    public static SDVariable createInterval(SameDiff sameDiff, long begin, long end, long stride, long inclusive) {
        return createInterval(sameDiff, (String) null, begin, end, stride, inclusive);
    }

    /** Creates an unnamed interval descriptor whose bounds are computed in the graph. */
    public static SDVariable createInterval(SameDiff sameDiff, SDVariable begin, SDVariable end,
                                            SDVariable stride, SDVariable inclusive) {
        return createInterval(sameDiff, (String) null, begin, end, stride, inclusive);
    }

    /**
     * Creates an interval descriptor from fixed bounds.
     *
     * @param sameDiff  the owning graph
     * @param name      variable name, or {@code null} for a generated one
     * @param begin     first index
     * @param end       end index
     * @param stride    step between selected indices
     * @param inclusive values greater than zero make {@code end} inclusive when decoded
     * @return a variable holding {@code [INTERVAL_TYPE, 2, 1, begin, end, stride, inclusive]}
     */
    public static SDVariable createInterval(SameDiff sameDiff, String name, long begin, long end,
                                            long stride, long inclusive) {
        return sameDiff.var(name, Nd4j.createFromArray(INTERVAL_TYPE, 2, 1, begin, end, stride, inclusive));
    }

    /**
     * Creates an interval descriptor whose bounds are computed in the graph.
     *
     * @param sameDiff  the owning graph
     * @param name      variable name, or {@code null} for a generated one
     * @param begin     first index
     * @param end       end index
     * @param stride    step between indices; {@code null} defaults to 1
     * @param inclusive inclusive flag; {@code null} defaults to 0 (exclusive)
     * @return an INT64 variable laid out as {@code [INTERVAL_TYPE, 2, 1, begin, end, stride, inclusive]}
     */
    public static SDVariable createInterval(SameDiff sameDiff, String name, SDVariable begin, SDVariable end,
                                            SDVariable stride, SDVariable inclusive) {
        SDVariable strideVar = stride != null ? stride : int64Vector(sameDiff, 1);
        SDVariable inclusiveVar = inclusive != null ? inclusive : int64Vector(sameDiff, 0);
        return sameDiff.concat(name, 0,
                int64Vector(sameDiff, INTERVAL_TYPE),
                int64Vector(sameDiff, 2),
                int64Vector(sameDiff, 1),
                asInt64Vector(begin),
                asInt64Vector(end),
                asInt64Vector(strideVar),
                asInt64Vector(inclusiveVar));
    }

    /**
     * Eagerly applies the decoded descriptors to {@code input}.
     *
     * @param input       the array to index
     * @param indexArrays encoded index descriptors
     * @return the result of {@code input.get(indices(indexArrays))}
     */
    public static INDArray createFrom(INDArray input, INDArray... indexArrays) {
        return input.get(indices(indexArrays));
    }

    /**
     * Decodes every descriptor with {@link #fromIndexArr(INDArray)}.
     *
     * @param indexArrays encoded index descriptors
     * @return the decoded indices, in the same order
     */
    public static INDArrayIndex[] indices(INDArray... indexArrays) {
        return Arrays.stream(indexArrays)
                .map(CreateView::fromIndexArr)
                .toArray(INDArrayIndex[]::new);
    }

    /**
     * Decodes a single descriptor into an {@link INDArrayIndex}.
     *
     * @param indexArr encoded descriptor (see the class documentation for layouts)
     * @return the decoded index
     * @throws IllegalArgumentException if the descriptor is null, too short, or has an unknown type tag
     */
    public static INDArrayIndex fromIndexArr(INDArray indexArr) {
        if (indexArr == null || indexArr.length() < 1) {
            throw new IllegalArgumentException("Index descriptor must be a non-null, non-empty array");
        }
        int type = indexArr.getInt(0);
        switch (type) {
            case POINT_TYPE:
                requireLength(indexArr, 4, "point");
                return NDArrayIndex.point(indexArr.getLong(3));
            case INTERVAL_TYPE:
                requireLength(indexArr, 7, "interval");
                // Layout is [.., begin(3), end(4), stride(5), inclusive(6)];
                // NDArrayIndex.interval takes (begin, stride, end, inclusive).
                return NDArrayIndex.interval(indexArr.getLong(3), indexArr.getLong(5),
                        indexArr.getLong(4), indexArr.getLong(6) > 0);
            case NEW_AXIS:
                return NDArrayIndex.newAxis();
            case ALL_TYPE:
                return NDArrayIndex.all();
            default:
                throw new IllegalArgumentException("Invalid type " + type + ". Must be 1 of: "
                        + POINT_TYPE + " (point type), " + INTERVAL_TYPE + " (interval type), "
                        + NEW_AXIS + " (new axis), " + ALL_TYPE + " (all)");
        }
    }

    /** Throws if {@code arr} has fewer than {@code minLength} elements for a descriptor of {@code kind}. */
    private static void requireLength(INDArray arr, int minLength, String kind) {
        if (arr.length() < minLength) {
            throw new IllegalArgumentException("Invalid " + kind + " index descriptor: expected at least "
                    + minLength + " elements but got " + arr.length());
        }
    }
}
