package org.deeplearning4j.models.word2vec;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.AbstractStorage;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Read-only {@link WordVectors} implementation that serves word vectors from an
 * {@link AbstractStorage} keyed by vocabulary index.
 *
 * <p>Lookups can optionally be cached per device (see {@link Builder#setCachePerDevice(long)}).
 * Nearest-neighbour search, accuracy evaluation and lookup-table access are not supported and
 * throw {@link UnsupportedOperationException}; use the regular {@code Word2Vec} implementation
 * for those. Instances are created through {@link Builder}.
 */
public class StaticWord2Vec implements WordVectors {
    private static final Logger log = LoggerFactory.getLogger(StaticWord2Vec.class);

    private static final String NOT_IMPLEMENTED =
            "Method isn't implemented. Please use usual Word2Vec implementation";

    /** Word vectors keyed by vocabulary index. */
    private AbstractStorage<Integer> storage;
    private VocabCache<VocabWord> vocabCache;
    /** One cache map (vocab index -> vector) per device, indexed by device id; filled in {@link #init()}. */
    private final List<Map<Integer, INDArray>> cacheWrtDevice = new ArrayList<>();
    /** Per-device cache budget in bytes; values {@code <= 0} disable caching. */
    private long cachePerDevice = 0;
    /** Fallback label for out-of-vocabulary lookups; {@code null} disables the fallback. */
    private String unk = null;

    /** Builder for {@link StaticWord2Vec}. */
    public static class Builder {
        private AbstractStorage<Integer> storage;
        private long cachePerDevice = 0;
        private VocabCache<VocabWord> vocabCache;

        /**
         * @param abstractStorage vectors keyed by vocabulary index; must hold exactly one entry per
         *     vocabulary word (checked in {@link #build()})
         * @param vocabCache vocabulary mapping words to indices
         */
        public Builder(AbstractStorage<Integer> abstractStorage, VocabCache<VocabWord> vocabCache) {
            this.storage = abstractStorage;
            this.vocabCache = vocabCache;
        }

        /**
         * Sets the per-device cache budget in bytes. The default of 0 disables caching.
         *
         * @param bytes maximum bytes of vectors cached per device
         * @return this builder
         */
        public Builder setCachePerDevice(long bytes) {
            this.cachePerDevice = bytes;
            return this;
        }

        /**
         * Creates the model.
         *
         * @return a new, initialized {@link StaticWord2Vec}
         * @throws IllegalStateException if the storage size differs from the vocabulary size
         */
        public StaticWord2Vec build() {
            StaticWord2Vec model = new StaticWord2Vec();
            model.cachePerDevice = this.cachePerDevice;
            model.storage = this.storage;
            model.vocabCache = this.vocabCache;
            model.init();
            return model;
        }
    }

    private StaticWord2Vec() {
    }

    @Override
    public String getUNK() {
        return this.unk;
    }

    @Override
    public void setUNK(String newUnk) {
        this.unk = newUnk;
    }

    /**
     * Validates that storage and vocabulary agree in size and allocates one cache map per device.
     *
     * @throws IllegalStateException if the number of stored vectors differs from the vocab size
     */
    protected void init() {
        if (storage.size() != vocabCache.numWords()) {
            throw new IllegalStateException("Number of words in Vocab isn't matching number of stored Vectors. vocab: ["
                    + vocabCache.numWords() + "]; storage: [" + storage.size() + "]");
        }
        int devices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int d = 0; d < devices; d++) {
            cacheWrtDevice.add(new ConcurrentHashMap<>());
        }
    }

    @Override
    public boolean hasWord(String word) {
        return vocabCache.containsWord(word);
    }

    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public int indexOf(String word) {
        return vocabCache.indexOf(word);
    }

    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /**
     * Returns the vector for {@code word} as a {@code double[]}.
     *
     * @return the vector's elements, or {@code null} if the word (and UNK fallback) is unavailable
     */
    @Override
    public double[] getWordVector(String word) {
        INDArray vector = getWordVectorMatrix(word);
        if (vector == null) {
            return null;
        }
        // dup() gives a compact copy, so data() holds only this vector's elements even if the
        // stored array happens to be a view over a larger buffer.
        return vector.dup().data().asDouble();
    }

    /**
     * Returns the unit-length version of the vector for {@code word}.
     *
     * @return the normalized vector, or {@code null} if the word (and UNK fallback) is unavailable
     */
    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        INDArray vector = getWordVectorMatrix(word);
        return vector == null ? null : Transforms.unitVec(vector);
    }

    /**
     * Looks up the vector for {@code word}, falling back to the UNK label when the word is not in
     * the vocabulary and UNK is set. When caching is enabled, fetched vectors are kept in the
     * current device's cache while the cache stays under the byte budget.
     *
     * <p>The returned array may be shared with the cache; callers should {@code dup()} before
     * mutating it.
     *
     * @return the vector, or {@code null} if the word is unknown and no UNK is configured, or if
     *     the storage holds no vector for the resolved index
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        final int index;
        if (hasWord(word)) {
            index = vocabCache.indexOf(word);
        } else if (getUNK() != null) {
            index = vocabCache.indexOf(getUNK());
        } else {
            return null;
        }

        if (cachePerDevice <= 0) {
            return storage.get(index);
        }

        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        Map<Integer, INDArray> deviceCache = cacheWrtDevice.get(deviceId);
        // ConcurrentHashMap never stores null values, so a non-null get() is a cache hit.
        INDArray cached = deviceCache.get(index);
        if (cached != null) {
            return cached;
        }

        INDArray vector = storage.get(index);
        if (vector != null) {
            long bytes = vector.length() * vector.data().getElementSize();
            // Budget check assumes every cached vector has the same byte size as this one.
            if (bytes * (deviceCache.size() + 1L) < cachePerDevice) {
                deviceCache.put(index, vector);
            }
        }
        return vector;
    }

    /**
     * Stacks the vectors of the given words into a matrix, one row per resolvable word.
     * Words that resolve to no vector (unknown without UNK, or missing from storage) are skipped.
     *
     * @throws RuntimeException from {@code Nd4j.vstack} if no word resolves to a vector
     */
    @Override
    public INDArray getWordVectors(Collection<String> words) {
        List<INDArray> rows = new ArrayList<>(words.size());
        for (String word : words) {
            INDArray vector = getWordVectorMatrix(word);
            if (vector != null) {
                rows.add(vector);
            }
        }
        return Nd4j.vstack(rows);
    }

    /**
     * Returns the element-wise mean of the vectors of the given words.
     */
    @Override
    public INDArray getWordVectorsMean(Collection<String> words) {
        // Rows are words after vstack, so averaging over dimension 0 yields the mean vector
        // (dimension 1 would instead give one scalar per word).
        return getWordVectors(words).mean(0);
    }

    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public Collection<String> wordsNearest(String word, int n) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /**
     * Cosine similarity between the vectors of two labels.
     *
     * @return {@code NaN} if either label is {@code null} or has no vector, {@code 1.0} if the
     *     labels are equal, otherwise the cosine similarity of the two vectors
     */
    @Override
    public double similarity(String label1, String label2) {
        if (label1 == null || label2 == null) {
            log.debug("LABELS: {}: {};{} vec2:{}", label1, label1 == null ? "null" : BasicModelUtils.EXISTS,
                    label2, label2 == null ? "null" : BasicModelUtils.EXISTS);
            return Double.NaN;
        }
        // Null-check before dup(): getWordVectorMatrix returns null for unavailable words.
        INDArray vec1 = getWordVectorMatrix(label1);
        INDArray vec2 = getWordVectorMatrix(label2);
        if (vec1 == null || vec2 == null) {
            log.debug("{}: {};{} vec2:{}", label1, vec1 == null ? "null" : BasicModelUtils.EXISTS,
                    label2, vec2 == null ? "null" : BasicModelUtils.EXISTS);
            return Double.NaN;
        }
        if (label1.equals(label2)) {
            return 1.0d;
        }
        // dup() keeps possibly-cached vectors untouched by the transforms.
        return Transforms.cosineSim(Transforms.unitVec(vec1.dup()), Transforms.unitVec(vec2.dup()));
    }

    @Override
    public VocabCache vocab() {
        return vocabCache;
    }

    @Override
    public WeightLookupTable lookupTable() {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /** Ignored: this implementation does not support the queries that would use {@link ModelUtils}. */
    @Override
    public void setModelUtils(ModelUtils modelUtils) {
    }

    /**
     * Copies every stored vector into row {@code i} of {@code target}. Indices with no stored
     * vector receive a filler row created with {@link Nd4j#create}.
     *
     * @param target destination matrix; presumably shaped {@code [vocabSize, vectorSize]} — confirm
     *     with callers
     */
    public void loadWeightsInto(INDArray target) {
        int rowCount = (int) vocabSize();
        INDArray fillerRow = null; // created lazily and reused for every missing index
        for (int i = 0; i < rowCount; i++) {
            INDArray row = storage.get(i);
            if (row == null) {
                if (fillerRow == null) {
                    fillerRow = Nd4j.create(target.dataType(), new long[] {1, target.size(1)});
                }
                row = fillerRow;
            }
            target.putRow(i, row);
        }
    }

    /** @return the number of stored vectors */
    public long vocabSize() {
        return storage.size();
    }

    /**
     * Returns the length of the first stored vector found, scanning indices from 0 upwards.
     *
     * @throws UnsupportedOperationException if no index has a stored vector
     */
    public int vectorSize() {
        long count = vocabSize();
        // Always probe index 0 first (even for an empty storage), then the rest in order.
        for (int i = 0; i == 0 || i < count; i++) {
            INDArray vector = storage.get(i);
            if (vector != null) {
                return (int) vector.length();
            }
        }
        throw new UnsupportedOperationException("No vectors found");
    }

    public boolean jsonSerializable() {
        return false;
    }

    @Override
    public boolean outOfVocabularySupported() {
        return false;
    }
}
