package org.nd4j.autodiff.samediff.internal.memory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import org.nd4j.autodiff.samediff.internal.IDependeeGroup;
import org.nd4j.autodiff.samediff.internal.IDependencyMap;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/memory/DependencyMap.class */
public class DependencyMap<K extends IDependeeGroup<INDArray>, V> implements IDependencyMap<K, V> {
    private HashMap<Long, HashSet<Pair<Long, V>>> map = new HashMap<>();

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public void clear() {
        this.map.clear();
    }

    public void add(K k, V v) {
        long id = k.getId();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null) {
                HashSet<Pair<Long, V>> hashSet = this.map.get(Long.valueOf(iNDArray.getId()));
                if (hashSet != null) {
                    hashSet.add(Pair.create(Long.valueOf(id), v));
                } else {
                    HashSet<Pair<Long, V>> hashSet2 = new HashSet<>();
                    hashSet2.add(Pair.create(Long.valueOf(id), v));
                    this.map.put(Long.valueOf(iNDArray.getId()), hashSet2);
                }
            }
        }
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public boolean isEmpty() {
        return this.map.isEmpty();
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public Iterable<V> getDependantsForEach(K k) {
        HashSet<Pair<Long, V>> hashSet;
        HashSet hashSet2 = new HashSet();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (hashSet = this.map.get(Long.valueOf(iNDArray.getId()))) != null) {
                Iterator<Pair<Long, V>> it = hashSet.iterator();
                while (it.hasNext()) {
                    hashSet2.add(it.next().getSecond());
                }
            }
        }
        return hashSet2;
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public Iterable<V> getDependantsForGroup(K k) {
        HashSet<Pair<Long, V>> hashSet;
        HashSet hashSet2 = new HashSet();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (hashSet = this.map.get(Long.valueOf(iNDArray.getId()))) != null) {
                Iterator<Pair<Long, V>> it = hashSet.iterator();
                while (it.hasNext()) {
                    Pair<Long, V> next = it.next();
                    if (((Long) next.getFirst()).longValue() == k.getId()) {
                        hashSet2.add(next.getSecond());
                    }
                }
            }
        }
        return hashSet2;
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public boolean containsAnyForGroup(K k) {
        HashSet<Pair<Long, V>> hashSet;
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (hashSet = this.map.get(Long.valueOf(iNDArray.getId()))) != null) {
                Iterator<Pair<Long, V>> it = hashSet.iterator();
                while (it.hasNext()) {
                    if (((Long) it.next().getFirst()).longValue() == k.getId()) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public void removeGroup(K k) {
        HashSet<Pair<Long, V>> hashSet;
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (hashSet = this.map.get(Long.valueOf(iNDArray.getId()))) != null) {
                long size = hashSet.size();
                ArrayList arrayList = new ArrayList();
                Iterator it = hashSet.iterator();
                while (it.hasNext()) {
                    Pair pair = (Pair) it.next();
                    if (((Long) pair.getFirst()).longValue() == k.getId()) {
                        arrayList.add(pair);
                    }
                }
                if (arrayList.size() > 0) {
                    hashSet.removeAll(arrayList);
                    if (size == arrayList.size()) {
                        this.map.remove(Long.valueOf(iNDArray.getId()));
                    }
                }
            }
        }
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public Iterable<V> removeGroupReturn(K k) {
        HashSet<Pair<Long, V>> hashSet;
        HashSet hashSet2 = new HashSet();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (hashSet = this.map.get(Long.valueOf(iNDArray.getId()))) != null) {
                long size = hashSet.size();
                ArrayList arrayList = new ArrayList();
                Iterator it = hashSet.iterator();
                while (it.hasNext()) {
                    Pair pair = (Pair) it.next();
                    if (((Long) pair.getFirst()).longValue() == k.getId()) {
                        arrayList.add(pair);
                        hashSet2.add(pair.getSecond());
                    }
                }
                if (arrayList.size() > 0) {
                    hashSet.removeAll(arrayList);
                    if (size == arrayList.size()) {
                        this.map.remove(Long.valueOf(iNDArray.getId()));
                    }
                }
            }
        }
        return hashSet2;
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public void removeForEach(K k) {
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null) {
                this.map.remove(Long.valueOf(iNDArray.getId()));
            }
        }
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public Iterable<V> removeForEachResult(K k) {
        HashSet<Pair<Long, V>> remove;
        HashSet hashSet = new HashSet();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && (remove = this.map.remove(Long.valueOf(iNDArray.getId()))) != null) {
                Iterator<Pair<Long, V>> it = remove.iterator();
                while (it.hasNext()) {
                    hashSet.add(it.next().getSecond());
                }
                this.map.remove(Long.valueOf(iNDArray.getId()));
            }
        }
        return hashSet;
    }

    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public boolean containsAny(K k) {
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null && this.map.containsKey(Long.valueOf(iNDArray.getId()))) {
                return true;
            }
        }
        return false;
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public Iterable<V> removeGroupReturn(K k, Predicate<V> predicate) {
        HashSet hashSet = new HashSet();
        for (INDArray iNDArray : k.getCollection()) {
            if (iNDArray != null) {
                long id = iNDArray.getId();
                HashSet<Pair<Long, V>> hashSet2 = this.map.get(Long.valueOf(id));
                if (hashSet2 != null) {
                    long size = hashSet2.size();
                    ArrayList arrayList = new ArrayList();
                    Iterator it = hashSet2.iterator();
                    while (it.hasNext()) {
                        Pair pair = (Pair) it.next();
                        if (((Long) pair.getFirst()).longValue() == k.getId() && predicate.test(pair.getSecond())) {
                            arrayList.add(pair);
                            hashSet.add(pair.getSecond());
                        }
                    }
                    if (arrayList.size() > 0) {
                        hashSet2.removeAll(arrayList);
                        if (size == arrayList.size()) {
                            this.map.remove(Long.valueOf(id));
                        }
                    }
                }
            }
        }
        return hashSet;
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.nd4j.autodiff.samediff.internal.IDependencyMap
    public /* bridge */ /* synthetic */ void add(Object obj, Object obj2) {
        add((DependencyMap<K, V>) obj, (IDependeeGroup) obj2);
    }
}
