package org.deeplearning4j.iterator.bert;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/iterator/bert/BertMaskedLMMasker.class */
/**
 * A {@link BertSequenceMasker} that applies BERT-style masked-language-model masking to a token sequence.
 * <p>
 * Each token is independently selected with probability {@code maskProb}. A selected token is then:
 * <ul>
 *   <li>replaced by the mask token with probability {@code maskTokenProb},</li>
 *   <li>replaced by a token drawn uniformly from the vocabulary list with probability {@code randomTokenProb},</li>
 *   <li>left unchanged otherwise.</li>
 * </ul>
 * Every selected position is flagged {@code true} in the returned boolean array, whether or not its token was
 * actually changed.
 */
public class BertMaskedLMMasker implements BertSequenceMasker {
    /** Default probability (15%) that any given token is selected for masking. */
    public static final double DEFAULT_MASK_PROB = 0.15d;
    /** Default probability (80%) that a selected token is replaced by the mask token. */
    public static final double DEFAULT_MASK_TOKEN_PROB = 0.8d;
    /** Default probability (10%) that a selected token is replaced by a random vocabulary token. */
    public static final double DEFAULT_RANDOM_WORD_PROB = 0.1d;

    /** Source of randomness for both token selection and replacement choice. */
    protected final Random r;
    /** Probability that each token is selected; validated to lie strictly within (0, 1). */
    protected final double maskProb;
    /** Probability, given that a token was selected, of substituting the mask token. */
    protected final double maskTokenProb;
    /** Probability, given that a token was selected, of substituting a random vocabulary token. */
    protected final double randomTokenProb;

    /**
     * Creates a masker with a fresh {@link Random} and the default probabilities
     * ({@link #DEFAULT_MASK_PROB}, {@link #DEFAULT_MASK_TOKEN_PROB}, {@link #DEFAULT_RANDOM_WORD_PROB}).
     */
    public BertMaskedLMMasker() {
        this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB);
    }

    /**
     * Creates a masker with explicit randomness and probabilities.
     *
     * @param random          source of randomness; must not be null
     * @param maskProb        probability that each token is selected; must be strictly between 0 and 1
     * @param maskTokenProb   probability (given selection) of substituting the mask token; in [0, 1]
     * @param randomTokenProb probability (given selection) of substituting a random vocabulary token; in [0, 1]
     * @throws IllegalArgumentException if {@code maskProb} is not strictly between 0 and 1
     * @throws IllegalStateException    if {@code maskTokenProb} or {@code randomTokenProb} is outside [0, 1], or
     *                                  their sum exceeds 1 (validated via {@code Preconditions.checkState})
     * @throws NullPointerException     if {@code random} is null
     */
    public BertMaskedLMMasker(Random random, double maskProb, double maskTokenProb, double randomTokenProb) {
        Preconditions.checkArgument(maskProb > 0.0d && maskProb < 1.0d,
                "Probability must be between 0 and 1, got %s", maskProb);
        // checkState (not checkArgument) is kept so the exception type seen by existing callers is unchanged.
        Preconditions.checkState(maskTokenProb >= 0.0d && maskTokenProb <= 1.0d,
                "Mask token probability must be between 0 and 1, got %s", maskTokenProb);
        Preconditions.checkState(randomTokenProb >= 0.0d && randomTokenProb <= 1.0d,
                "Random token probability must be between 0 and 1, got %s", randomTokenProb);
        Preconditions.checkState(maskTokenProb + randomTokenProb <= 1.0d,
                "Sum of maskTokenProb (%s) and randomTokenProb (%s) must be <= 1.0, got sum is %s",
                maskTokenProb, randomTokenProb, maskTokenProb + randomTokenProb);
        // Fail fast here rather than with an NPE on the first call to maskSequence.
        this.r = Objects.requireNonNull(random, "Random must not be null");
        this.maskProb = maskProb;
        this.maskTokenProb = maskTokenProb;
        this.randomTokenProb = randomTokenProb;
    }

    /**
     * Masks the given token sequence.
     *
     * @param input      tokens to mask; not modified
     * @param maskToken  token substituted for positions chosen to be masked
     * @param vocabWords candidate tokens for random replacement; must be non-empty if a random replacement is drawn
     * @return a pair of (new token list of the same length as {@code input}, boolean array flagging every
     *         position that was selected, including selected positions whose token was kept unchanged)
     * @throws IllegalArgumentException if a random replacement is drawn while {@code vocabWords} is empty
     */
    @Override // org.deeplearning4j.iterator.bert.BertSequenceMasker
    public Pair<List<String>, boolean[]> maskSequence(List<String> input, String maskToken, List<String> vocabWords) {
        List<String> out = new ArrayList<>(input.size());
        boolean[] masked = new boolean[input.size()];
        int i = 0;
        // Iterate rather than call input.get(i): avoids O(n^2) on non-random-access lists.
        for (String token : input) {
            if (r.nextDouble() < maskProb) {
                // Second draw picks the replacement: [0, maskTokenProb) -> mask token,
                // [maskTokenProb, maskTokenProb + randomTokenProb) -> random vocab word, otherwise keep the token.
                double choice = r.nextDouble();
                if (choice < maskTokenProb) {
                    out.add(maskToken);
                } else if (choice < maskTokenProb + randomTokenProb) {
                    // Checked lazily so inputs that never reach this branch behave exactly as before.
                    Preconditions.checkArgument(!vocabWords.isEmpty(),
                            "Cannot select a random replacement token: vocabWords is empty");
                    out.add(vocabWords.get(r.nextInt(vocabWords.size())));
                } else {
                    out.add(token);
                }
                // Flag the position as a prediction target even when the original token was kept.
                masked[i] = true;
            } else {
                out.add(token);
            }
            i++;
        }
        return new Pair<>(out, masked);
    }
}
