package org.nd4j.autodiff.samediff.internal.memory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Longs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.class */
/**
 * Session memory manager that recycles released arrays instead of freeing them.
 *
 * <p>Released arrays are kept in a cache bucketed by data type and shape (the shape is keyed by its
 * {@code Arrays.toString(long[])} form). A later {@code allocate} call for the same data type and
 * shape hands back a cached array, with a freshly assigned id, instead of creating a new one.
 *
 * <p>The total size of cached buffers is bounded by {@code maxCacheBytes}, which is
 * {@code maxMemFrac} times the total memory detected at construction ({@link Pointer#maxBytes()} on
 * non-CUDA backends, the first CUDA device's total memory otherwise). When a release would exceed
 * the bound, the least recently released arrays are evicted and closed (when closeable). The
 * following arrays are closed (when closeable) instead of cached:
 * <ul>
 *   <li>UTF8 arrays;</li>
 *   <li>arrays without a data buffer;</li>
 *   <li>arrays whose buffer use count is above one;</li>
 *   <li>arrays larger than the whole cache.</li>
 * </ul>
 *
 * <p>Caching can be switched off with the {@code org.nd4j.autodiff.samediff} system property (read
 * when the manager is constructed) or with {@link #setEnableCache(boolean)}.
 *
 * <p>This class performs no internal synchronization.
 */
public class ArrayCacheMemoryMgr extends AbstractMemoryMgr {
    private static final Logger log = LoggerFactory.getLogger(ArrayCacheMemoryMgr.class);

    /** Fraction of total memory the cache may occupy; fixed at construction. */
    private final double maxMemFrac;
    /** Validated and stored, but not consulted by the allocate/release logic of this class. */
    private long smallArrayThreshold;
    /** Validated and stored, but not consulted by the allocate/release logic of this class. */
    private double largerArrayMaxMultiple;
    /** Upper bound for {@link #currentCacheSize}, in bytes. */
    private long maxCacheBytes;
    /** Total memory detected for the backend at construction, in bytes. */
    private long totalMemBytes;
    /** Sum of the buffer sizes (in bytes) of all arrays currently held in the cache. */
    private long currentCacheSize = 0L;
    /** Ids of cached arrays in release order (oldest first); drives LRU eviction. */
    private LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
    /** Cached arrays by id; holds the same ids as {@link #lruCache}. */
    private Map<Long, INDArray> lruCacheValues = new HashMap<>();
    /** Cached arrays bucketed by data type and {@code Arrays.toString(shape)}. */
    private Table<DataType, String, List<INDArray>> arrays = HashBasedTable.create();
    /** When false, released arrays are closed immediately and allocate never reuses arrays. */
    private boolean enableCache = Boolean.parseBoolean(System.getProperty("org.nd4j.autodiff.samediff", "true"));

    /**
     * Creates a manager with a memory fraction of 0.25, a small-array threshold of 1024, and a
     * larger-array max multiple of 2.0.
     */
    public ArrayCacheMemoryMgr() {
        this(0.25d, 1024L, 2.0d);
    }

    /**
     * Creates a manager whose cache may hold up to {@code maxMemFrac} of the backend's total memory.
     *
     * @param maxMemFrac fraction of total memory usable by the cache; must be strictly between 0.0 and 1.0
     * @param smallArrayThreshold must be {@code >= 0}; stored but not used by this class's caching logic
     * @param largerArrayMaxMultiple must be {@code >= 1.0}; stored but not used by this class's caching logic
     *     (out-of-range arguments are rejected via {@code Preconditions.checkArgument})
     */
    public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
        Preconditions.checkArgument(maxMemFrac > 0.0d && maxMemFrac < 1.0d, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
        Preconditions.checkArgument(smallArrayThreshold >= 0, "Small array threshold must be >= 0, got %s", smallArrayThreshold);
        Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0d, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
        this.maxMemFrac = maxMemFrac;
        this.smallArrayThreshold = smallArrayThreshold;
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
        this.totalMemBytes = isCpu() ? Pointer.maxBytes() : cudaDeviceTotalMemory();
        this.maxCacheBytes = (long) (maxMemFrac * this.totalMemBytes);
    }

    /** Returns true unless the executioner reports the "CUDA" backend. */
    private boolean isCpu() {
        return !"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty(Nd4jEnvironment.BACKEND_KEY));
    }

    /** Reads the total memory of the first CUDA device from the executioner's environment information. */
    private static long cudaDeviceTotalMemory() {
        List<?> devices = (List<?>) Nd4j.getExecutioner().getEnvironmentInformation().get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
        Map<?, ?> firstDevice = (Map<?, ?>) devices.get(0);
        return ((Long) firstDevice.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)).longValue();
    }

    /**
     * Returns an array of the given type and shape, reusing a cached one when available.
     *
     * <p>Newly created arrays are always uninitialized and detached, regardless of {@code detached}.
     *
     * @param detached not consulted by this overload
     * @param dataType data type of the requested array
     * @param shape shape of the requested array
     * @return a recycled cached array (contents unspecified) or a new uninitialized detached array
     */
    @Override // org.nd4j.autodiff.samediff.internal.SessionMemMgr
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        String shapeKey = Arrays.toString(shape);
        if (this.enableCache) {
            INDArray cached = pollCached(dataType, shapeKey);
            if (cached != null && !cached.wasClosed()) {
                log.debug("Cache hit for data type {} and shape {}", dataType, shapeKey);
                return reclaim(cached);
            }
        }
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    /**
     * Returns an array matching the descriptor, reusing a cached one when available.
     *
     * <p>Empty descriptors are created with {@code Nd4j.create} and detached when {@code detached}
     * is true. Scalars (rank 0) and shapes containing a zero dimension are never served from the
     * cache. A reused array has its ordering set to the descriptor's order.
     *
     * @param detached only consulted for empty descriptors
     * @param descriptor shape, data type and order of the requested array
     * @return a recycled cached array (contents unspecified) or a new uninitialized detached array
     */
    @Override // org.nd4j.autodiff.samediff.internal.SessionMemMgr
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        if (descriptor.isEmpty()) {
            INDArray empty = Nd4j.create(descriptor);
            return detached ? empty.detach() : empty;
        }
        DataType dataType = descriptor.dataType();
        long[] shape = descriptor.getShape();
        boolean cacheable = this.enableCache && shape.length > 0 && !Longs.contains(shape, 0L);
        if (cacheable) {
            INDArray cached = pollCached(dataType, Arrays.toString(shape));
            // Only touch the array after confirming it is still open.
            if (cached != null && !cached.wasClosed()) {
                if (cached.ordering() != descriptor.getOrder()) {
                    cached.setOrder(descriptor.getOrder());
                }
                log.debug("Cache hit for data type {} and shape {}", dataType, Arrays.toString(cached.shape()));
                return reclaim(cached);
            }
        }
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    /** Removes and returns the earliest-cached array in the given bucket, or null if the bucket is empty or absent. */
    private INDArray pollCached(DataType dataType, String shapeKey) {
        List<INDArray> bucket = this.arrays.get(dataType, shapeKey);
        return (bucket == null || bucket.isEmpty()) ? null : bucket.remove(0);
    }

    /** Drops a cache hit from the size accounting and LRU bookkeeping, then assigns it a new id. */
    private INDArray reclaim(INDArray cached) {
        this.currentCacheSize -= bufferBytes(cached);
        Long oldId = cached.getId();
        this.lruCache.remove(oldId);
        this.lruCacheValues.remove(oldId);
        // NOTE(review): presumably so id-keyed bookkeeping elsewhere treats the recycled array as new - confirm.
        ((BaseNDArray) cached).assignNewId();
        return cached;
    }

    /**
     * Returns an array to the manager, either caching it for reuse or closing it (when closeable).
     *
     * <p>The array is closed rather than cached in any of these cases:
     * <ul>
     *   <li>caching is disabled;</li>
     *   <li>it has no data buffer;</li>
     *   <li>its buffer's use count is above one;</li>
     *   <li>it is a UTF8 array;</li>
     *   <li>it alone exceeds {@code maxCacheBytes}.</li>
     * </ul>
     * Otherwise, least recently released arrays are evicted and closed until it fits.
     *
     * @param array the array to release; must not be null and must not currently be cached
     */
    @Override // org.nd4j.autodiff.samediff.internal.SessionMemMgr
    public void release(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        long id = array.getId();
        Preconditions.checkState(!this.lruCache.contains(Long.valueOf(id)), "Array was released multiple times: id=%s, shape=%ndShape", Long.valueOf(id), array);
        if (!this.enableCache) {
            closeIfPossible(array);
            return;
        }
        // Without a buffer there is nothing to recycle or to size; never fall through to data().length().
        if (array.data() == null) {
            closeIfPossible(array);
            return;
        }
        // Buffer presumably shared with other arrays (e.g. views) - don't recycle it.
        if (Nd4j.getExecutioner().useCount(array.data()) > 1) {
            closeIfPossible(array);
            return;
        }
        // String arrays are never cached. Return here so the closed array is NOT registered in the LRU structures.
        if (array.dataType() == DataType.UTF8) {
            closeIfPossible(array);
            return;
        }
        long bytes = bufferBytes(array);
        if (this.currentCacheSize + bytes > this.maxCacheBytes) {
            if (bytes > this.maxCacheBytes) {
                // Too large to cache even if the cache were emptied.
                closeIfPossible(array);
                return;
            }
            evictUntilFits(bytes);
        }
        cacheArray(array);
    }

    /** Evicts and closes cached arrays, least recently released first, until {@code incomingBytes} fits. */
    private void evictUntilFits(long incomingBytes) {
        Iterator<Long> it = this.lruCache.iterator();
        while (this.currentCacheSize + incomingBytes > this.maxCacheBytes && it.hasNext()) {
            Long evictId = it.next();
            it.remove();
            INDArray evicted = this.lruCacheValues.remove(evictId);
            if (evicted == null) {
                continue;
            }
            long evictedBytes = bufferBytes(evicted);
            List<INDArray> bucket = this.arrays.get(evicted.dataType(), Arrays.toString(evicted.shape()));
            if (bucket != null) {
                // Remove by identity: List.remove(Object) matches with equals(), which for INDArray may
                // compare contents and so could drop a different array holding equal values.
                bucket.removeIf(candidate -> candidate == evicted);
            }
            this.currentCacheSize -= evictedBytes;
            closeIfPossible(evicted);
        }
    }

    /** Adds the array to its (data type, shape) bucket, counts its bytes and records it as most recently released. */
    private void cacheArray(INDArray array) {
        DataType dataType = array.dataType();
        String shapeKey = Arrays.toString(array.shape());
        List<INDArray> bucket = this.arrays.get(dataType, shapeKey);
        if (bucket == null) {
            bucket = new ArrayList<>();
            this.arrays.put(dataType, shapeKey, bucket);
        }
        bucket.add(array);
        this.currentCacheSize += bufferBytes(array);
        this.lruCache.add(array.getId());
        this.lruCacheValues.put(array.getId(), array);
    }

    /** Size in bytes of the array's whole data buffer (buffer length times element width). */
    private static long bufferBytes(INDArray array) {
        return array.data().length() * array.dataType().width();
    }

    /** Closes the array if it reports itself as closeable. */
    private static void closeIfPossible(INDArray array) {
        if (array.closeable()) {
            array.close();
        }
    }

    /**
     * Closes every cached array that is closeable, then empties the cache and resets the size
     * accounting so no closed array is reused or evicted afterwards.
     */
    @Override // org.nd4j.autodiff.samediff.internal.SessionMemMgr, java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        for (List<INDArray> bucket : this.arrays.values()) {
            for (INDArray cached : bucket) {
                closeIfPossible(cached);
            }
        }
        this.arrays.clear();
        this.lruCache.clear();
        this.lruCacheValues.clear();
        this.currentCacheSize = 0L;
    }

    public double getMaxMemFrac() {
        return this.maxMemFrac;
    }

    public long getSmallArrayThreshold() {
        return this.smallArrayThreshold;
    }

    public double getLargerArrayMaxMultiple() {
        return this.largerArrayMaxMultiple;
    }

    public long getMaxCacheBytes() {
        return this.maxCacheBytes;
    }

    public long getTotalMemBytes() {
        return this.totalMemBytes;
    }

    public long getCurrentCacheSize() {
        return this.currentCacheSize;
    }

    /** Returns the live (mutable) LRU id set, not a copy. */
    public LinkedHashSet<Long> getLruCache() {
        return this.lruCache;
    }

    /** Returns the live (mutable) id-to-array map, not a copy. */
    public Map<Long, INDArray> getLruCacheValues() {
        return this.lruCacheValues;
    }

    /** Returns the live (mutable) bucket table, not a copy. */
    public Table<DataType, String, List<INDArray>> getArrays() {
        return this.arrays;
    }

    public boolean isEnableCache() {
        return this.enableCache;
    }

    public void setSmallArrayThreshold(long smallArrayThreshold) {
        this.smallArrayThreshold = smallArrayThreshold;
    }

    public void setLargerArrayMaxMultiple(double largerArrayMaxMultiple) {
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
    }

    public void setMaxCacheBytes(long maxCacheBytes) {
        this.maxCacheBytes = maxCacheBytes;
    }

    /** Sets the recorded total memory; does not recompute {@code maxCacheBytes}. */
    public void setTotalMemBytes(long totalMemBytes) {
        this.totalMemBytes = totalMemBytes;
    }

    public void setCurrentCacheSize(long currentCacheSize) {
        this.currentCacheSize = currentCacheSize;
    }

    public void setLruCache(LinkedHashSet<Long> lruCache) {
        this.lruCache = lruCache;
    }

    public void setLruCacheValues(Map<Long, INDArray> lruCacheValues) {
        this.lruCacheValues = lruCacheValues;
    }

    public void setArrays(Table<DataType, String, List<INDArray>> arrays) {
        this.arrays = arrays;
    }

    public void setEnableCache(boolean enableCache) {
        this.enableCache = enableCache;
    }
}
