/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BpeAnalyzer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BpeTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.RobertaTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * {@link NlpTokenizer} implementation that tokenizes text with a byte-pair-encoding
 * analyzer ({@link BpeAnalyzer}) and uses the RoBERTa special tokens
 * ({@code <s>}, {@code </s>}, {@code <pad>}, {@code <unk>}, {@code <mask>}).
 *
 * <p>Instances are created through {@link #builder(List, List, RobertaTokenization)}.
 * The tokenizer owns a {@link BpeAnalyzer}; call {@link #close()} to release it.
 */
public class RobertaTokenizer
extends NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "<unk>";
    public static final String SEPARATOR_TOKEN = "</s>";
    public static final String PAD_TOKEN = "<pad>";
    public static final String CLASS_TOKEN = "<s>";
    public static final String MASK_TOKEN = "<mask>";
    /** Tokens that are always handed to the analyzer as "never split", in addition to caller-supplied ones. */
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    private final BpeAnalyzer bpeAnalyzer;
    /** The vocabulary in its original order; the list index of a token is its id. */
    protected final List<String> originalVocab;
    /** Token to id lookup derived from the vocabulary. */
    private final SortedMap<String, Integer> vocab;
    protected final boolean withSpecialTokens;
    /** Id of {@code </s>}, or {@code -1} when special tokens are disabled. */
    protected final int sepTokenId;
    /** Id of {@code <s>}, or {@code -1} when special tokens are disabled. */
    private final int clsTokenId;
    protected final int padTokenId;
    private final int maxSequenceLength;

    /**
     * Creates the tokenizer and validates that the vocabulary contains the tokens it relies on.
     *
     * @param originalVocab     vocabulary in id order; passed to the analyzer and to tokenization results
     * @param vocab             token to id lookup; must contain {@code <unk>} and {@code <pad>}, and also
     *                          {@code </s>} and {@code <s>} when {@code withSpecialTokens} is {@code true}
     * @param merges            BPE merges, passed through to the {@link BpeAnalyzer}
     * @param isPrefixSpace     passed through to the {@link BpeAnalyzer}
     * @param withSpecialTokens whether the class/separator tokens are required and their ids resolved
     * @param maxSequenceLength value reported by {@link #maxSequenceLength()}
     * @param neverSplit        additional tokens the analyzer must not split; unioned with {@code <mask>}
     * @throws RuntimeException (as produced by {@code ExceptionsHelper.conflictStatusException}) if a
     *                          required token is missing from {@code vocab}
     */
    protected RobertaTokenizer(
        List<String> originalVocab,
        SortedMap<String, Integer> vocab,
        List<String> merges,
        boolean isPrefixSpace,
        boolean withSpecialTokens,
        int maxSequenceLength,
        Set<String> neverSplit
    ) {
        this.originalVocab = originalVocab;
        this.bpeAnalyzer = new BpeAnalyzer(
            originalVocab,
            merges,
            new ArrayList<>(Sets.union(NEVER_SPLIT, neverSplit)),
            isPrefixSpace,
            UNKNOWN_TOKEN
        );
        this.vocab = vocab;
        this.withSpecialTokens = withSpecialTokens;
        this.maxSequenceLength = maxSequenceLength;

        // The unknown and padding tokens are required regardless of the special-token setting.
        if (!vocab.containsKey(UNKNOWN_TOKEN)) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN);
        }
        if (!vocab.containsKey(PAD_TOKEN)) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN);
        }
        this.padTokenId = vocab.get(PAD_TOKEN);

        if (withSpecialTokens) {
            // Report every missing special token at once rather than failing on the first one.
            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
            if (!missingSpecialTokens.isEmpty()) {
                throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens);
            }
            this.sepTokenId = vocab.get(SEPARATOR_TOKEN);
            this.clsTokenId = vocab.get(CLASS_TOKEN);
        } else {
            // Sentinel ids: the special tokens are not looked up when they are disabled.
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
    }

    /** @return the id of {@code </s>}, or {@code -1} if special tokens are disabled */
    @Override
    int sepTokenId() {
        return this.sepTokenId;
    }

    /** @return the maximum sequence length this tokenizer was configured with */
    @Override
    int maxSequenceLength() {
        return this.maxSequenceLength;
    }

    /** @return whether this tokenizer was configured to use the class/separator tokens */
    @Override
    boolean isWithSpecialTokens() {
        return this.withSpecialTokens;
    }

    /**
     * @return the number of special tokens added around a pair of sequences; always 4
     *         (presumably {@code <s> A </s></s> B </s>} - confirm against the tokens builder)
     */
    @Override
    int getNumExtraTokensForSeqPair() {
        return 4;
    }

    /**
     * @param maxWindowSize the window size to derive a span from
     * @return half of the window that remains after subtracting the single-sequence special tokens
     */
    @Override
    int defaultSpanForChunking(int maxWindowSize) {
        return (maxWindowSize - this.numExtraTokensForSingleSequence()) / 2;
    }

    /**
     * @return the number of special tokens added around a single sequence; always 2
     *         (presumably {@code <s>} and {@code </s>} - confirm against the tokens builder)
     */
    @Override
    int numExtraTokensForSingleSequence() {
        return 2;
    }

    /** @return the id of {@code <s>}, or {@code -1} if special tokens are disabled */
    @Override
    int clsTokenId() {
        return this.clsTokenId;
    }

    /** @return the padding token text, {@code <pad>} */
    @Override
    public String getPadToken() {
        return PAD_TOKEN;
    }

    /** @return the unknown token text, {@code <unk>} */
    public String getUnknownToken() {
        return UNKNOWN_TOKEN;
    }

    /** Closes the underlying {@link BpeAnalyzer}. */
    public void close() {
        this.bpeAnalyzer.close();
    }

    /**
     * Wraps the given tokenizations, together with the vocabulary and the padding token id,
     * in a {@link RobertaTokenizationResult}.
     *
     * @param tokenizations the per-sequence tokenizations to wrap
     * @return the RoBERTa-specific tokenization result
     */
    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new RobertaTokenizationResult(this.originalVocab, tokenizations, this.padTokenId);
    }

    /**
     * Returns a request builder that tokenizes every input (using the input's list index as its
     * sequence id), gathers all resulting tokenizations into one {@link TokenizationResult} and
     * builds the request from it.
     *
     * @return the request builder
     */
    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return (inputs, requestId, truncate, span, windowSize) -> this.buildTokenizationResult(
            IntStream.range(0, inputs.size())
                .boxed()
                // One input may yield several tokenizations, so flatten them into a single list.
                .flatMap(seqId -> this.tokenize(inputs.get(seqId), truncate, span, seqId, windowSize).stream())
                .collect(Collectors.toList())
        ).buildRequest(requestId, truncate);
    }

    /** @return the id of {@code <pad>}; always present because the constructor requires it */
    @Override
    public OptionalInt getPadTokenId() {
        return OptionalInt.of(this.padTokenId);
    }

    /** @return the id of {@code <mask>}, or empty if the vocabulary does not contain it */
    @Override
    public OptionalInt getMaskTokenId() {
        Integer maskId = this.vocab.get(MASK_TOKEN);
        if (maskId == null) {
            return OptionalInt.empty();
        }
        return OptionalInt.of(maskId);
    }

    /** @return the mask token text, {@code <mask>} */
    @Override
    public String getMaskToken() {
        return MASK_TOKEN;
    }

    /** @return the vocabulary list this tokenizer was created with (not copied) */
    @Override
    public List<String> getVocabulary() {
        return this.originalVocab;
    }

    /**
     * Creates the RoBERTa-specific tokens builder.
     *
     * @param clsTokenId        id of the class token to use
     * @param sepTokenId        id of the separator token to use
     * @param withSpecialTokens whether special tokens should be used by the builder
     * @return a new {@link RobertaTokenizationResult.RobertaTokensBuilder}
     */
    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new RobertaTokenizationResult.RobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Runs the BPE analyzer over {@code seq} and records, for every emitted token, its position
     * as accumulated from the stream's position increments.
     *
     * @param seq the text to tokenize
     * @return the analyzer's tokens (copied into a new list) together with the position of each token
     * @throws UncheckedIOException if the token stream throws an {@link IOException}
     */
    @Override
    public NlpTokenizer.InnerTokenization innerTokenize(String seq) {
        List<Integer> tokenPositionMap = new ArrayList<>();
        try (TokenStream ts = this.bpeAnalyzer.tokenStream("input", seq)) {
            ts.reset();
            PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class);
            // Starts at -1 so that the first token's increment is added before its position is recorded.
            int currPos = -1;
            while (ts.incrementToken()) {
                currPos += tokenPos.getPositionIncrement();
                tokenPositionMap.add(currPos);
            }
            // NOTE(review): Lucene's documented consumer workflow calls ts.end() after the last
            // incrementToken(); it is not called here - confirm BpeAnalyzer does not rely on it.
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
        // The tokens are read from the analyzer after the stream has been closed, as before.
        return new NlpTokenizer.InnerTokenization(new ArrayList<BpeTokenizer.BpeToken>(this.bpeAnalyzer.getTokens()), tokenPositionMap);
    }

    /**
     * Creates a builder for a {@link RobertaTokenizer}.
     *
     * @param vocab        vocabulary in id order
     * @param merges       BPE merges
     * @param tokenization configuration supplying the prefix-space, special-token and
     *                     max-sequence-length settings
     * @return a new builder
     */
    public static Builder builder(List<String> vocab, List<String> merges, RobertaTokenization tokenization) {
        return new Builder(vocab, merges, tokenization);
    }

    /** Mutable builder for {@link RobertaTokenizer}; defaults are read from a {@link RobertaTokenization}. */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final List<String> merges;
        protected final SortedMap<String, Integer> vocab;
        protected boolean withSpecialTokens;
        protected boolean prefixSpace;
        protected int maxSequenceLength;
        /** Extra never-split tokens; {@code null} until set, replaced by an empty set in {@link #build()}. */
        protected Set<String> neverSplit;

        /**
         * @param vocab        vocabulary in id order; also indexed into a token to id map
         * @param merges       BPE merges
         * @param tokenization source of the initial prefix-space, special-token and
         *                     max-sequence-length settings
         */
        protected Builder(List<String> vocab, List<String> merges, RobertaTokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = Builder.buildSortedVocab(vocab);
            this.merges = merges;
            this.prefixSpace = tokenization.isAddPrefixSpace();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
        }

        /**
         * Maps each token to its index in {@code vocab}. If a token occurs more than once,
         * the last index wins.
         */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); ++i) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        /**
         * @param neverSplit additional tokens that must not be split; may be {@code null}
         * @return this builder
         */
        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        /**
         * @param maxSequenceLength overrides the value taken from the tokenization config
         * @return this builder
         */
        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        /**
         * @param withSpecialTokens overrides the value taken from the tokenization config
         * @return this builder
         */
        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        /**
         * Builds the tokenizer. A {@code null} never-split set is replaced by an empty set first.
         *
         * @return a new {@link RobertaTokenizer}
         */
        public RobertaTokenizer build() {
            if (this.neverSplit == null) {
                this.neverSplit = Collections.emptySet();
            }
            return new RobertaTokenizer(
                this.originalVocab,
                this.vocab,
                this.merges,
                this.prefixSpace,
                this.withSpecialTokens,
                this.maxSequenceLength,
                this.neverSplit
            );
        }
    }
}

