package org.nd4j.linalg.workspace;

import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/workspace/WorkspaceUtils.class */
/**
 * Static assertion helpers for checking the state of ND4J memory workspaces on the
 * current thread, and for checking that an array's workspace pointer is still valid.
 */
public class WorkspaceUtils {
    private WorkspaceUtils() {
    }

    /**
     * Asserts that no workspace is active for the current thread.
     *
     * @param msg prefix of the exception message if the assertion fails
     * @throws ND4JWorkspaceException if any workspace is active for the current thread
     */
    public static void assertNoWorkspacesOpen(String msg) throws ND4JWorkspaceException {
        assertNoWorkspacesOpen(msg, false);
    }

    /**
     * Asserts that no workspace is active for the current thread.
     *
     * @param msg            prefix of the exception message if the assertion fails
     * @param allowScopedOut if {@code true}, the assertion also passes when a workspace is
     *                       active but the memory manager's current workspace is
     *                       {@code null} or a {@link DummyWorkspace}
     * @throws ND4JWorkspaceException if the assertion fails; the message lists the ids of
     *                                all workspaces whose scope is active
     */
    public static void assertNoWorkspacesOpen(String msg, boolean allowScopedOut) throws ND4JWorkspaceException {
        if (!Nd4j.getWorkspaceManager().anyWorkspaceActiveForCurrentThread()) {
            return;
        }
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (allowScopedOut && (currentWorkspace == null || currentWorkspace instanceof DummyWorkspace)) {
            return;
        }
        throw new ND4JWorkspaceException(msg + " - Open/active workspaces: " + allOpenWorkspaces());
    }

    /**
     * Asserts that the named workspace exists and is active for the current thread.
     *
     * @param ws       workspace id to check; must not be null
     * @param errorMsg exception message used if the assertion fails; must not be null
     * @throws NullPointerException   if {@code ws} or {@code errorMsg} is null
     * @throws ND4JWorkspaceException if the workspace does not exist or is not active
     */
    public static void assertOpenAndActive(@NonNull String ws, @NonNull String errorMsg) throws ND4JWorkspaceException {
        // Explicit checks mirror the lombok @NonNull contract (same message format).
        if (ws == null) {
            throw new NullPointerException("ws is marked non-null but is null");
        }
        if (errorMsg == null) {
            throw new NullPointerException("errorMsg is marked non-null but is null");
        }
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ws)) {
            throw new ND4JWorkspaceException(errorMsg);
        }
    }

    /**
     * Asserts that the named workspace exists, is active, and is the memory manager's
     * current workspace.
     *
     * @param ws       workspace id to check; must not be null
     * @param errorMsg prefix of the exception message if the assertion fails; must not be null
     * @throws NullPointerException   if {@code ws} or {@code errorMsg} is null
     * @throws ND4JWorkspaceException if the workspace is not open and active, or is not current
     */
    public static void assertOpenActiveAndCurrent(@NonNull String ws, @NonNull String errorMsg) throws ND4JWorkspaceException {
        if (ws == null) {
            throw new NullPointerException("ws is marked non-null but is null");
        }
        if (errorMsg == null) {
            throw new NullPointerException("errorMsg is marked non-null but is null");
        }
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ws)) {
            throw new ND4JWorkspaceException(errorMsg + " - workspace is not open and active");
        }
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (currentWorkspace == null || !ws.equals(currentWorkspace.getId())) {
            String currentId = currentWorkspace == null ? null : currentWorkspace.getId();
            // Closing ')' was missing from this message previously.
            throw new ND4JWorkspaceException(errorMsg + " - not the current workspace (current workspace: " + currentId + ")");
        }
    }

    /**
     * Checks that an array attached to a workspace still points at valid workspace memory.
     * Null arrays, detached arrays, and arrays whose parent workspace is
     * {@link MemoryWorkspace.Type#CIRCULAR} are not checked.
     *
     * @param array array to validate; may be null
     * @param msg   optional prefix of the exception message; may be null
     * @throws ND4JWorkspaceException if the parent workspace's scope is no longer active, or
     *                                if the array's generation id differs from the parent
     *                                workspace's current generation id
     */
    public static void assertValidArray(INDArray array, String msg) {
        if (array == null || !array.isAttached()) {
            return;
        }
        MemoryWorkspace parentWorkspace = array.data().getParentWorkspace();
        if (parentWorkspace.getWorkspaceType() == MemoryWorkspace.Type.CIRCULAR) {
            return;
        }
        String prefix = msg == null ? "" : msg + ": ";
        if (!parentWorkspace.isScopeActive()) {
            throw new ND4JWorkspaceException(prefix + "Array uses leaked workspace pointer from workspace "
                    + parentWorkspace.getId() + "\nAll open workspaces: " + allOpenWorkspaces());
        }
        long arrayGeneration = array.data().getGenerationId();
        long workspaceGeneration = parentWorkspace.getGenerationId();
        if (workspaceGeneration != arrayGeneration) {
            // Message previously referenced the exception inside its own initializer and
            // printed the workspace generation in place of the open-workspace list.
            throw new ND4JWorkspaceException(prefix + "Array outdated workspace pointer from workspace "
                    + parentWorkspace.getId() + " (array generation " + arrayGeneration
                    + ", current workspace generation " + workspaceGeneration
                    + ")\nAll open workspaces: " + allOpenWorkspaces());
        }
    }

    /**
     * Returns the ids of all workspaces for the current thread whose scope is active.
     */
    private static List<String> allOpenWorkspaces() {
        List<MemoryWorkspace> workspaces = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
        List<String> openIds = new ArrayList<>(workspaces.size());
        for (MemoryWorkspace workspace : workspaces) {
            if (workspace.isScopeActive()) {
                openIds.add(workspace.getId());
            }
        }
        return openIds;
    }
}
