package org.deeplearning4j.models.word2vec.wordstore;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder.class */
public class VocabularyHolder implements Serializable {
    /** External cache this holder was initialised from; stays null when the no-arg constructor is used. */
    private VocabCache vocabCache;
    /** Length of the fixed-size code/point scratch arrays used when building Huffman nodes. */
    private static final int MAX_CODE_LENGTH = 40;
    /** Word label -> vocabulary entry. */
    private final Map<String, VocabularyWord> vocabulary = new ConcurrentHashMap();
    /** Huffman index -> vocabulary entry; rebuilt by updateHuffmanCodes(). Marked transient, so it is not serialized. */
    private transient Map<Integer, VocabularyWord> idxMap = new ConcurrentHashMap();
    /** Default threshold used by truncateVocabulary() and by the scavenger. */
    private int minWordFrequency = 0;
    /** When true, newly added words get a frequency-shift buffer and the scavenger may be triggered from addWord(String). */
    private boolean hugeModelExpected = false;
    /** Length of the per-word frequency-shift buffer inspected by the scavenger. */
    private int retentionDelay = 3;
    /** The scavenger runs each time the number of words added through addWord(String) reaches a multiple of this value. */
    private int scavengerThreshold = 2000000;
    /** Cached result of totalWordsBeyondLimit(); 0 means "not computed yet". */
    private long totalWordOccurrences = 0;
    /** Counts words added through addWord(String) while scavenging is enabled. Marked transient, so it is not serialized. */
    private transient AtomicLong hiddenWordsCounter = new AtomicLong(0);
    /** Not read or updated anywhere in this class. */
    private AtomicInteger totalWordCount = new AtomicInteger(0);
    /** Per-instance logger. */
    private Logger logger = LoggerFactory.getLogger(VocabularyHolder.class);

    /* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder$Builder.class */
    /**
     * Fluent builder for {@link VocabularyHolder}. Every setter returns the builder itself so
     * calls can be chained; {@link #build()} copies the collected settings into a new holder.
     */
    public static class Builder {
        private VocabCache cache;
        private int minWordFrequency;
        private boolean hugeModelExpected;
        private int scavengerThreshold = 2000000;
        private int retentionDelay = 3;

        /**
         * Uses an existing cache as the initial content of the holder being built.
         *
         * @param vocabCache cache to import words from, must not be null
         * @return this builder
         * @throws NullPointerException if {@code vocabCache} is null
         */
        public Builder externalCache(@NonNull VocabCache vocabCache) {
            if (vocabCache != null) {
                this.cache = vocabCache;
                return this;
            }
            throw new NullPointerException("cache is marked non-null but is null");
        }

        /**
         * Sets the minimal word frequency applied by the holder.
         *
         * @param i minimal number of occurrences
         * @return this builder
         */
        public Builder minWordFrequency(int i) {
            this.minWordFrequency = i;
            return this;
        }

        /**
         * Enables or disables the "huge model" mode of the holder.
         *
         * @param z true to enable the mode
         * @return this builder
         */
        public Builder hugeModelExpected(boolean z) {
            this.hugeModelExpected = z;
            return this;
        }

        /**
         * Sets the scavenger activation threshold copied into the holder.
         *
         * @param i threshold value
         * @return this builder
         */
        public Builder scavengerActivationThreshold(int i) {
            this.scavengerThreshold = i;
            return this;
        }

        /**
         * Sets the scavenger retention delay.
         *
         * @param i delay, must be at least 2
         * @return this builder
         * @throws IllegalStateException if {@code i} is lower than 2
         */
        public Builder scavengerRetentionDelay(int i) {
            if (i >= 2) {
                this.retentionDelay = i;
                return this;
            }
            throw new IllegalStateException("Delay < 2 doesn't really makes sense");
        }

        /**
         * Creates the holder: initialised from the external cache when one was supplied,
         * empty otherwise, and then configured with the values collected by this builder.
         *
         * @return a new, configured holder
         */
        public VocabularyHolder build() {
            final VocabularyHolder holder;
            if (this.cache == null) {
                holder = new VocabularyHolder();
            } else {
                holder = new VocabularyHolder(this.cache, true);
            }
            holder.minWordFrequency = this.minWordFrequency;
            holder.hugeModelExpected = this.hugeModelExpected;
            holder.scavengerThreshold = this.scavengerThreshold;
            holder.retentionDelay = this.retentionDelay;
            return holder;
        }
    }

    /**
     * Creates an empty holder with the default field values; no external cache is attached.
     */
    protected VocabularyHolder() {
    }

    protected VocabularyHolder(@NonNull VocabCache<? extends SequenceElement> vocabCache, boolean z) {
        if (vocabCache == null) {
            throw new NullPointerException("cache is marked non-null but is null");
        }
        this.vocabCache = vocabCache;
        for (SequenceElement sequenceElement : vocabCache.tokens()) {
            VocabularyWord vocabularyWord = new VocabularyWord(sequenceElement.getLabel());
            vocabularyWord.setCount((int) sequenceElement.getElementFrequency());
            vocabularyWord.setSpecial(z);
            if (sequenceElement.getPoints() != null && !sequenceElement.getPoints().isEmpty()) {
                vocabularyWord.setHuffmanNode(buildNode(sequenceElement.getCodes(), sequenceElement.getPoints(), sequenceElement.getCodeLength(), sequenceElement.getIndex()));
            }
            this.vocabulary.put(vocabularyWord.getWord(), vocabularyWord);
        }
        if (numWords() > 1) {
            updateHuffmanCodes();
        }
        this.logger.info("Init from VocabCache is complete. " + numWords() + " word(s) were transferred.");
    }

    /**
     * Builds a {@link HuffmanNode} from list-based codes and points.
     *
     * @param list  code bits
     * @param list2 points
     * @param i     code length, narrowed to a byte before it is passed to the node
     * @param i2    index handed to the node
     * @return the newly created node
     */
    public static HuffmanNode buildNode(List<Byte> list, List<Integer> list2, int i, int i2) {
        byte[] codeBits = listToArray(list);
        int[] pointIndices = listToArray(list2, MAX_CODE_LENGTH);
        byte codeLength = (byte) i;
        return new HuffmanNode(codeBits, pointIndices, i2, codeLength);
    }

    /**
     * Transfers the vocabulary into the cache this holder was created with and then empties
     * the holder.
     */
    public void transferBackToVocabCache() {
        final boolean emptyHolderAfterwards = true;
        transferBackToVocabCache(this.vocabCache, emptyHolderAfterwards);
    }

    /**
     * Transfers the vocabulary into the given cache and then empties the holder.
     *
     * @param vocabCache target cache
     */
    public void transferBackToVocabCache(VocabCache vocabCache) {
        final boolean emptyHolderAfterwards = true;
        transferBackToVocabCache(vocabCache, emptyHolderAfterwards);
    }

    /**
     * Copies every non-empty word of this holder into the given cache.
     *
     * <p>Each word is written as a {@link VocabWord} into both the vocabs and the tokens map of
     * the cache. When a Huffman node is present, index, code length, points and codes are copied
     * as well and the word is registered in the cache index. Words with a count above one have
     * the remainder ({@code count - 1}) added through {@code incrementWordCount}.
     *
     * @param vocabCache target cache; only {@link InMemoryLookupCache} is supported
     * @param z          when true, this holder's vocabulary and index map are cleared afterwards
     * @throws IllegalStateException if {@code vocabCache} is not an {@link InMemoryLookupCache}
     */
    public void transferBackToVocabCache(VocabCache vocabCache, boolean z) {
        boolean supported = vocabCache instanceof InMemoryLookupCache;
        if (!supported) {
            throw new IllegalStateException("Sorry, only InMemoryLookupCache use implemented.");
        }
        InMemoryLookupCache target = (InMemoryLookupCache) vocabCache;

        for (VocabularyWord source : words()) {
            String label = source.getWord();
            if (label.isEmpty()) {
                continue;
            }

            VocabWord converted = new VocabWord(1.0d, label);
            if (source.getHistoricalGradient() != null) {
                converted.setHistoricalGradient(Nd4j.create(source.getHistoricalGradient()));
            }
            target.getVocabs().put(label, converted);
            target.getTokens().put(label, converted);

            HuffmanNode node = source.getHuffmanNode();
            if (node != null) {
                converted.setIndex(node.getIdx());
                converted.setCodeLength(node.getLength());
                converted.setPoints(arrayToList(node.getPoint(), node.getLength()));
                converted.setCodes(arrayToList(node.getCode(), (int) node.getLength()));
                vocabCache.addWordToIndex(node.getIdx(), label);
            }

            if (source.getCount() > 1) {
                vocabCache.incrementWordCount(label, source.getCount() - 1);
            }
        }

        if (z) {
            this.idxMap.clear();
            this.vocabulary.clear();
        }
    }

    /**
     * Overrides the scavenger activation threshold.
     *
     * @param i new threshold value
     */
    protected void setScavengerActivationThreshold(int i) {
        scavengerThreshold = i;
    }

    /**
     * Copies the first {@code i} bytes of an array into a new list.
     *
     * @param bArr source array
     * @param i    number of leading elements to copy; zero or negative yields an empty list
     * @return a new mutable list holding the copied values in order
     */
    public static List<Byte> arrayToList(byte[] bArr, int i) {
        List<Byte> boxed = new ArrayList<Byte>();
        int position = 0;
        while (position < i) {
            boxed.add(Byte.valueOf(bArr[position]));
            position++;
        }
        return boxed;
    }

    /**
     * Copies a list of bytes into a new array of exactly {@code MAX_CODE_LENGTH} elements.
     * Slots beyond the list size keep their default value of zero. Because the array length is
     * fixed, a list longer than {@code MAX_CODE_LENGTH} overruns it and fails with
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param list values to copy
     * @return array of length {@code MAX_CODE_LENGTH} starting with the list content
     */
    public static byte[] listToArray(List<Byte> list) {
        byte[] unboxed = new byte[MAX_CODE_LENGTH];
        int position = 0;
        for (Byte value : list) {
            unboxed[position] = value.byteValue();
            position++;
        }
        return unboxed;
    }

    /**
     * Copies a list of integers into a new array of the same size.
     *
     * @param list values to copy
     * @param i    not used by this method; the result is always sized by {@code list.size()}
     * @return array holding the list content in order
     */
    public static int[] listToArray(List<Integer> list, int i) {
        int[] unboxed = new int[list.size()];
        int position = 0;
        for (Integer value : list) {
            unboxed[position] = value.intValue();
            position++;
        }
        return unboxed;
    }

    /**
     * Copies the first {@code i} ints of an array into a new list.
     *
     * @param iArr source array
     * @param i    number of leading elements to copy; zero or negative yields an empty list
     * @return a new mutable list holding the copied values in order
     */
    public static List<Integer> arrayToList(int[] iArr, int i) {
        List<Integer> boxed = new ArrayList<Integer>();
        int position = 0;
        while (position < i) {
            boxed.add(Integer.valueOf(iArr[position]));
            position++;
        }
        return boxed;
    }

    /**
     * Returns the values view of the underlying vocabulary map (not a copy).
     *
     * @return all vocabulary entries
     */
    public Collection<VocabularyWord> getVocabulary() {
        final Collection<VocabularyWord> entries = this.vocabulary.values();
        return entries;
    }

    /**
     * Looks up a vocabulary entry by its label.
     *
     * @param str word label
     * @return the matching entry, or null when the label is unknown
     */
    public VocabularyWord getVocabularyWordByString(String str) {
        final VocabularyWord match = this.vocabulary.get(str);
        return match;
    }

    /**
     * Looks up a vocabulary entry by its Huffman index.
     *
     * @param num index assigned by {@link #updateHuffmanCodes()}
     * @return the matching entry, or null when no word is registered under that index
     */
    public VocabularyWord getVocabularyWordByIdx(Integer num) {
        final VocabularyWord match = this.idxMap.get(num);
        return match;
    }

    /**
     * Tells whether the given label is present in the vocabulary.
     *
     * @param str word label
     * @return true when the vocabulary has an entry for the label
     */
    public boolean containsWord(String str) {
        final boolean known = this.vocabulary.containsKey(str);
        return known;
    }

    /**
     * Increments the occurrence counter of a word; unknown words are ignored.
     *
     * <p>The entry is fetched with a single lookup. Checking with {@code containsKey} and then
     * calling {@code get} would leave a window in which the word could be removed (for example
     * by the scavenger), making the second call return null.
     *
     * @param str word label
     */
    public void incrementWordCounter(String str) {
        VocabularyWord vocabularyWord = this.vocabulary.get(str);
        if (vocabularyWord != null) {
            vocabularyWord.incrementCount();
        }
    }

    /**
     * Adds a new word to the vocabulary; does nothing when the word is already present.
     *
     * <p>When huge-model mode is on, the new entry gets a frequency-shift buffer of
     * {@code retentionDelay} bytes. When, in addition, {@code minWordFrequency} is above one,
     * the added-words counter is advanced and the scavenger is run each time that counter
     * reaches a multiple of {@code scavengerThreshold}.
     *
     * @param str word label
     */
    public void addWord(String str) {
        if (!this.vocabulary.containsKey(str)) {
            VocabularyWord fresh = new VocabularyWord(str);
            if (this.hugeModelExpected) {
                fresh.setFrequencyShift(new byte[this.retentionDelay]);
            }
            this.vocabulary.put(str, fresh);

            boolean scavengingEnabled = this.hugeModelExpected && this.minWordFrequency > 1;
            if (scavengingEnabled) {
                long addedSoFar = this.hiddenWordsCounter.incrementAndGet();
                if (addedSoFar % this.scavengerThreshold == 0) {
                    activateScavenger();
                }
            }
        }
    }

    /**
     * Stores the given entry under its own label, replacing any entry already stored there.
     *
     * @param vocabularyWord entry to store
     */
    public void addWord(VocabularyWord vocabularyWord) {
        final String label = vocabularyWord.getWord();
        this.vocabulary.put(label, vocabularyWord);
    }

    /**
     * Merges the vocabulary of another holder into this one.
     *
     * <p>Words unknown to this holder are added as-is (the same entry instance is shared with
     * the source holder). For words already known here, the incoming count is added to this
     * holder's entry. The previous implementation incremented the counter on the
     * <em>source</em> holder in that case, which modified the argument and left this holder's
     * counts untouched.
     *
     * <p>An entry instance that is already shared by both holders is skipped, so consuming the
     * same holder twice, or a holder consuming itself, does not inflate counts of shared entries.
     *
     * @param vocabularyHolder holder whose words are merged into this one; it is not modified
     */
    public void consumeVocabulary(VocabularyHolder vocabularyHolder) {
        for (VocabularyWord incoming : vocabularyHolder.getVocabulary()) {
            VocabularyWord existing = this.vocabulary.get(incoming.getWord());
            if (existing == null) {
                addWord(incoming);
            } else if (existing != incoming) {
                existing.setCount(existing.getCount() + incoming.getCount());
            }
        }
    }

    /**
     * Removes words that stayed rare over several scavenger runs.
     *
     * <p>Special words, words that already reached {@code minWordFrequency} and words without
     * a frequency-shift buffer are exempt; their buffer is dropped. For every other word the
     * current count is recorded (narrowed to a byte) at the word's retention step. The word is
     * removed when its count is at most the activation value
     * ({@code max(minWordFrequency / 5, 2)}) and the last buffer slot is positive, at most the
     * activation value and equal to the first slot. Afterwards the retention step is advanced,
     * or, once the buffer is full, the buffer is shifted left by one slot.
     *
     * <p>The loop runs over a typed snapshot of the vocabulary. A raw {@code ArrayList} would
     * make the enhanced-for element type {@code Object}, which is not assignable to the
     * {@code VocabularyWord} loop variable.
     */
    protected synchronized void activateScavenger() {
        int sizeBefore = this.vocabulary.size();
        // Loop-invariant: depends only on the configured minimal frequency.
        int activation = Math.max(this.minWordFrequency / 5, 2);
        int lastSlot = this.retentionDelay - 1;

        for (VocabularyWord vocabularyWord : new ArrayList<VocabularyWord>(this.vocabulary.values())) {
            if (vocabularyWord.isSpecial() || vocabularyWord.getCount() >= this.minWordFrequency || vocabularyWord.getFrequencyShift() == null) {
                vocabularyWord.setFrequencyShift(null);
                continue;
            }

            vocabularyWord.getFrequencyShift()[vocabularyWord.getRetentionStep()] = (byte) vocabularyWord.getCount();

            // Skip building the message entirely when debug output is off.
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Current state> Activation: [" + activation + "], retention info: " + Arrays.toString(vocabularyWord.getFrequencyShift()));
            }

            if (vocabularyWord.getCount() <= activation
                    && vocabularyWord.getFrequencyShift()[lastSlot] > 0
                    && vocabularyWord.getFrequencyShift()[lastSlot] <= activation
                    && vocabularyWord.getFrequencyShift()[lastSlot] == vocabularyWord.getFrequencyShift()[0]) {
                this.vocabulary.remove(vocabularyWord.getWord());
            }

            if (vocabularyWord.getRetentionStep() < lastSlot) {
                vocabularyWord.incrementRetentionStep();
            } else {
                // Buffer is full: drop the oldest sample by shifting everything one slot left.
                for (int i = 1; i < this.retentionDelay; i++) {
                    vocabularyWord.getFrequencyShift()[i - 1] = vocabularyWord.getFrequencyShift()[i];
                }
            }
        }

        this.logger.info("Scavenger was activated. Vocab size before: [" + sizeBefore + "],  after: [" + this.vocabulary.size() + "]");
    }

    /**
     * Resets every vocabulary entry: the Huffman node and the frequency-shift buffer are
     * dropped and the count is set back to zero.
     */
    public void resetWordCounters() {
        Iterator<VocabularyWord> entries = getVocabulary().iterator();
        while (entries.hasNext()) {
            VocabularyWord entry = entries.next();
            entry.setHuffmanNode(null);
            entry.setFrequencyShift(null);
            entry.setCount(0);
        }
    }

    /**
     * Returns the number of words currently stored.
     *
     * @return size of the vocabulary map
     */
    public int numWords() {
        final int stored = this.vocabulary.size();
        return stored;
    }

    /**
     * Truncates the vocabulary using the configured {@code minWordFrequency} as threshold.
     */
    public void truncateVocabulary() {
        final int configuredThreshold = this.minWordFrequency;
        truncateVocabulary(configuredThreshold);
    }

    /**
     * Removes every non-special word whose count is below the given threshold. If a removed
     * word has a Huffman node, its entry in the index map is removed as well.
     *
     * @param i minimal count a non-special word needs in order to stay
     */
    public void truncateVocabulary(int i) {
        this.logger.debug("Truncating vocabulary to minWordFrequency: [" + i + "]");
        for (Map.Entry<String, VocabularyWord> entry : this.vocabulary.entrySet()) {
            VocabularyWord candidate = entry.getValue();
            boolean keep = candidate.isSpecial() || candidate.getCount() >= i;
            if (keep) {
                continue;
            }
            this.vocabulary.remove(entry.getKey());
            if (candidate.getHuffmanNode() != null) {
                this.idxMap.remove(Integer.valueOf(candidate.getHuffmanNode().getIdx()));
            }
        }
    }

    /**
     * Builds a binary Huffman tree over the current vocabulary and assigns a fresh
     * {@link HuffmanNode} to every word, then rebuilds the index map.
     *
     * <p>The words are taken from {@link #words()}, i.e. sorted by descending count, and the
     * position of a word in that list becomes its node index.
     *
     * <p>NOTE(review): the scratch arrays used while walking from a leaf to the root have
     * {@code MAX_CODE_LENGTH} slots and the walk is not bounded, so a path longer than that
     * would overrun them. With exactly one word the walk does not appear to reach its
     * termination index either; the constructor only calls this method for more than one
     * word. Confirm before calling it on a single-word vocabulary.
     *
     * @return the words in descending count order, each carrying its new Huffman node
     */
    public List<VocabularyWord> updateHuffmanCodes() {
        int i;
        int i2;
        List<VocabularyWord> words = words();
        // iArr: node weights (leaves first, then internal nodes); iArr2: parent index of each node;
        // bArr: the bit (0/1) a node contributes to the code of the leaves below it.
        int[] iArr = new int[(words.size() * 2) + 1];
        int[] iArr2 = new int[(words.size() * 2) + 1];
        byte[] bArr = new byte[(words.size() * 2) + 1];
        // Leaf weights are the word counts, in descending order.
        for (int i3 = 0; i3 < words.size(); i3++) {
            iArr[i3] = words.get(i3).getCount();
        }
        // Internal nodes start with a sentinel weight so they are never picked before being created.
        for (int size = words.size(); size < words.size() * 2; size++) {
            iArr[size] = Integer.MAX_VALUE;
        }
        // size2 walks the leaves from the rarest word towards the most frequent one,
        // size3 walks the internal nodes in creation order.
        int size2 = words.size() - 1;
        int size3 = words.size();
        for (int i4 = 0; i4 < words.size(); i4++) {
            // Pick the first of the two lightest remaining nodes.
            if (size2 < 0) {
                i = size3;
                size3++;
            } else if (iArr[size2] < iArr[size3]) {
                i = size2;
                size2--;
            } else {
                i = size3;
                size3++;
            }
            // Pick the second one the same way.
            if (size2 < 0) {
                i2 = size3;
                size3++;
            } else if (iArr[size2] < iArr[size3]) {
                i2 = size2;
                size2--;
            } else {
                i2 = size3;
                size3++;
            }
            // Create the parent node: weight is the sum, both children point to it,
            // and the second child is tagged with bit 1 (the first keeps the default 0).
            iArr[words.size() + i4] = iArr[i] + iArr[i2];
            iArr2[i] = words.size() + i4;
            iArr2[i2] = words.size() + i4;
            bArr[i2] = 1;
        }
        // Scratch buffers reused for every word: bits and node indices collected leaf-to-root.
        byte[] bArr2 = new byte[MAX_CODE_LENGTH];
        int[] iArr3 = new int[MAX_CODE_LENGTH];
        for (int i5 = 0; i5 < words.size(); i5++) {
            int i6 = i5;
            int i7 = 0;
            // Per-word result arrays handed to the HuffmanNode.
            byte[] bArr3 = new byte[MAX_CODE_LENGTH];
            int[] iArr4 = new int[MAX_CODE_LENGTH];
            // Climb from the leaf through its parents until index (2 * size - 2) is reached.
            do {
                bArr2[i7] = bArr[i6];
                iArr3[i7] = i6;
                i7++;
                i6 = iArr2[i6];
            } while (i6 != (words.size() * 2) - 2);
            // First point is always (size - 2); the rest is the collected path reversed,
            // with node indices shifted down by the number of words.
            iArr4[0] = words.size() - 2;
            for (int i8 = 0; i8 < i7; i8++) {
                bArr3[(i7 - i8) - 1] = bArr2[i8];
                iArr4[i7 - i8] = iArr3[i8] - words.size();
            }
            // i5 is the word's position in the sorted list, i7 the number of collected bits.
            words.get(i5).setHuffmanNode(new HuffmanNode(bArr3, iArr4, i5, (byte) i7));
        }
        // Rebuild the index -> word lookup from the freshly assigned nodes.
        this.idxMap.clear();
        for (VocabularyWord vocabularyWord : words) {
            this.idxMap.put(Integer.valueOf(vocabularyWord.getHuffmanNode().getIdx()), vocabularyWord);
        }
        return words;
    }

    /**
     * Returns the Huffman index of a word.
     *
     * <p>The entry is fetched with a single lookup so a concurrent removal cannot slip in
     * between an existence check and the read. A word that is known but has no Huffman node
     * (for example before {@link #updateHuffmanCodes()} has run, or after
     * {@link #resetWordCounters()}) has no index; -1 is returned for it instead of failing
     * with a {@link NullPointerException}.
     *
     * @param str word label
     * @return the index stored in the word's Huffman node, or -1 when the word is unknown or
     *         has no node
     */
    public int indexOf(String str) {
        VocabularyWord vocabularyWord = this.vocabulary.get(str);
        if (vocabularyWord == null) {
            return -1;
        }
        HuffmanNode node = vocabularyWord.getHuffmanNode();
        if (node == null) {
            return -1;
        }
        return node.getIdx();
    }

    /**
     * Returns a snapshot of the vocabulary sorted by count, most frequent word first.
     *
     * @return a new list; changing it does not affect the holder
     */
    public List<VocabularyWord> words() {
        Comparator<VocabularyWord> ascendingByCount = new Comparator<VocabularyWord>() {
            @Override
            public int compare(VocabularyWord left, VocabularyWord right) {
                return Integer.compare(left.getCount(), right.getCount());
            }
        };
        List<VocabularyWord> ordered = new ArrayList<VocabularyWord>(this.vocabulary.values());
        // Reversing the ascending comparator gives the descending order.
        Collections.sort(ordered, Collections.reverseOrder(ascendingByCount));
        return ordered;
    }

    /**
     * Returns the sum of the counts of all words currently in the vocabulary.
     *
     * <p>The result is cached in {@code totalWordOccurrences} once it is non-zero and is not
     * recomputed afterwards, even if the vocabulary changes later.
     *
     * <p>The sum is accumulated in a local variable and published with one assignment. Adding
     * directly into the cached field would let concurrent first calls add the same counts
     * twice, or let another caller observe a partial sum as the cached value.
     *
     * @return total number of word occurrences
     */
    public long totalWordsBeyondLimit() {
        long cached = this.totalWordOccurrences;
        if (cached != 0) {
            return cached;
        }
        long sum = 0;
        for (VocabularyWord vocabularyWord : this.vocabulary.values()) {
            sum += vocabularyWord.getCount();
        }
        this.totalWordOccurrences = sum;
        return sum;
    }
}
