/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.queries;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentLocation;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder;
import org.elasticsearch.xpack.ml.queries.TokenPruningConfig;

/**
 * Builder for the {@code weighted_tokens} query. The query matches documents whose field contains at least one
 * of the supplied tokens, and scores each matching token as a term query boosted by that token's weight.
 *
 * <p>When a {@link TokenPruningConfig} is supplied, each token is kept or pruned based on two checks. The first
 * compares the token's document frequency to the field's average frequency ratio. The second compares the
 * token's weight to the highest weight in the request. The config can also invert that selection, so that only
 * the pruned tokens are scored.
 */
public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTokensQueryBuilder> {
    public static final String NAME = "weighted_tokens";
    public static final ParseField TOKENS_FIELD = new ParseField("tokens");

    private final String fieldName;
    private final List<TextExpansionResults.WeightedToken> tokens;
    @Nullable
    private final TokenPruningConfig tokenPruningConfig;

    /**
     * Creates a query that scores every supplied token (no pruning).
     *
     * @param fieldName the field to query; must not be null
     * @param tokens    the weighted tokens to search for; must not be null or empty
     */
    public WeightedTokensQueryBuilder(String fieldName, List<TextExpansionResults.WeightedToken> tokens) {
        this(fieldName, tokens, null);
    }

    /**
     * Creates a query, optionally with token pruning.
     *
     * @param fieldName          the field to query; must not be null
     * @param tokens             the weighted tokens to search for; must not be null or empty. The list is copied,
     *                           so later changes to the caller's list do not affect this builder.
     * @param tokenPruningConfig pruning settings, or {@code null} to query every token
     * @throws NullPointerException     if {@code fieldName}, {@code tokens}, or any element of {@code tokens} is null
     * @throws IllegalArgumentException if {@code tokens} is empty
     */
    public WeightedTokensQueryBuilder(
        String fieldName,
        List<TextExpansionResults.WeightedToken> tokens,
        @Nullable TokenPruningConfig tokenPruningConfig
    ) {
        this.fieldName = Objects.requireNonNull(fieldName, "[weighted_tokens] requires a fieldName");
        Objects.requireNonNull(tokens, "[weighted_tokens] requires tokens");
        if (tokens.isEmpty()) {
            throw new IllegalArgumentException("[weighted_tokens] requires at least one token");
        }
        // Defensive copy: tokens take part in equals/hashCode, so later caller mutation must not leak in.
        this.tokens = List.copyOf(tokens);
        this.tokenPruningConfig = tokenPruningConfig;
    }

    /**
     * Deserializes a query. After the base-class state, it reads the field name, the token list and the optional
     * pruning config, in the same order {@link #doWriteTo(StreamOutput)} writes them.
     */
    public WeightedTokensQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.tokens = in.readCollectionAsList(TextExpansionResults.WeightedToken::new);
        this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new);
    }

    /** Returns the name of the queried field. */
    public String getFieldName() {
        return fieldName;
    }

    /** Returns the pruning configuration, or {@code null} if every token is queried. */
    @Nullable
    public TokenPruningConfig getTokenPruningConfig() {
        return tokenPruningConfig;
    }

    /** Writes the field name, tokens and optional pruning config; mirrors the {@link StreamInput} constructor. */
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(fieldName);
        out.writeCollection(tokens);
        out.writeOptionalWriteable(tokenPruningConfig);
    }

    /**
     * Renders the query in the shape that {@link #fromXContent(XContentParser)} accepts. The output is
     * {@code {"weighted_tokens": {"<field>": {"tokens": {...}, "pruning_config": {...}}}}}, followed by the
     * boost and query name. The pruning config is emitted only when it is set.
     */
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(fieldName);
        builder.startObject(TOKENS_FIELD.getPreferredName());
        for (TextExpansionResults.WeightedToken token : tokens) {
            token.toXContent(builder, params);
        }
        builder.endObject();
        if (tokenPruningConfig != null) {
            builder.field(TextExpansionQueryBuilder.PRUNING_CONFIG.getPreferredName(), tokenPruningConfig);
        }
        boostAndQueryNameToXContent(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Computes the average per-token document-frequency ratio for the field, defined as
     * {@code sumDocFreq / fieldDocCount / numUniqueTokens}.
     *
     * <p>{@code numUniqueTokens} is approximated by the largest per-segment term count for the field. Segments
     * without terms for the field are skipped.
     *
     * @return the ratio, or {@code 0} if the field has no terms or no documents
     */
    private float getAverageTokenFreqRatio(IndexReader reader, int fieldDocCount) throws IOException {
        // long, because Terms#size() is a long and an int accumulator could overflow.
        long numUniqueTokens = 0;
        for (LeafReaderContext leaf : reader.getContext().leaves()) {
            Terms terms = leaf.reader().terms(fieldName);
            if (terms != null) {
                numUniqueTokens = Math.max(numUniqueTokens, terms.size());
            }
        }
        if (numUniqueTokens == 0 || fieldDocCount <= 0) {
            return 0.0f;
        }
        return (float) reader.getSumDocFreq(fieldName) / (float) fieldDocCount / (float) numUniqueTokens;
    }

    /**
     * Decides whether a token survives pruning. Tokens that occur in no document are always pruned.
     *
     * <p>Otherwise a token is kept if either condition holds:
     * <ul>
     *   <li>its document-frequency ratio is below {@code tokensFreqRatioThreshold * averageTokenFreqRatio};</li>
     *   <li>its weight exceeds {@code tokensWeightThreshold * bestWeight}.</li>
     * </ul>
     *
     * @return {@code true} to keep the token; always {@code true} when no pruning config is set
     */
    private boolean shouldKeepToken(
        IndexReader reader,
        TextExpansionResults.WeightedToken token,
        int fieldDocCount,
        float averageTokenFreqRatio,
        float bestWeight
    ) throws IOException {
        if (tokenPruningConfig == null) {
            return true;
        }
        int docFreq = reader.docFreq(new Term(fieldName, token.token()));
        if (docFreq == 0) {
            return false;
        }
        float tokenFreqRatio = (float) docFreq / (float) fieldDocCount;
        return tokenFreqRatio < tokenPruningConfig.getTokensFreqRatioThreshold() * averageTokenFreqRatio
            || token.weight() > tokenPruningConfig.getTokensWeightThreshold() * bestWeight;
    }

    /**
     * Builds the Lucene query.
     *
     * <p>An unmapped field yields a {@link MatchNoDocsQuery}. A mapped field whose type is not allowed by
     * {@link TextExpansionQueryBuilder.AllowedFieldType} raises an {@link ElasticsearchParseException}.
     */
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        MappedFieldType ft = context.getFieldType(fieldName);
        if (ft == null) {
            return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist");
        }
        String fieldTypeName = ft.typeName();
        if (TextExpansionQueryBuilder.AllowedFieldType.isFieldTypeAllowed(fieldTypeName) == false) {
            throw new ElasticsearchParseException(
                "["
                    + fieldTypeName
                    + "] is not an appropriate field type for this query. Allowed field types are ["
                    + TextExpansionQueryBuilder.AllowedFieldType.getAllowedFieldTypesAsString()
                    + "]."
            );
        }
        return tokenPruningConfig == null
            ? queryBuilderWithAllTokens(tokens, ft, context)
            : queryBuilderWithPrunedTokens(tokens, ft, context);
    }

    /** Builds a term query for {@code token} on {@code ft}, boosted by the token's weight. */
    private static Query weightedTermQuery(
        TextExpansionResults.WeightedToken token,
        MappedFieldType ft,
        SearchExecutionContext context
    ) {
        return new BoostQuery(ft.termQuery(token.token(), context), token.weight());
    }

    /** Builds a disjunction of boosted term queries, one per token, that requires at least one match. */
    private Query queryBuilderWithAllTokens(
        List<TextExpansionResults.WeightedToken> tokens,
        MappedFieldType ft,
        SearchExecutionContext context
    ) {
        BooleanQuery.Builder qb = new BooleanQuery.Builder();
        for (TextExpansionResults.WeightedToken token : tokens) {
            qb.add(weightedTermQuery(token, ft, context), BooleanClause.Occur.SHOULD);
        }
        return qb.setMinimumNumberShouldMatch(1).build();
    }

    /**
     * Builds the same kind of disjunction as {@link #queryBuilderWithAllTokens}, but only over the tokens
     * selected by the pruning config. If the config's "only score pruned tokens" flag is set, the selection is
     * inverted.
     *
     * <p>A field with a zero average frequency ratio, meaning it has no terms or no documents, yields a
     * {@link MatchNoDocsQuery}.
     */
    private Query queryBuilderWithPrunedTokens(
        List<TextExpansionResults.WeightedToken> tokens,
        MappedFieldType ft,
        SearchExecutionContext context
    ) throws IOException {
        IndexReader reader = context.getIndexReader();
        int fieldDocCount = reader.getDocCount(fieldName);
        // Same fold as reduce(0f, Math::max): negative weights leave bestWeight at 0.
        float bestWeight = 0.0f;
        for (TextExpansionResults.WeightedToken token : tokens) {
            bestWeight = Math.max(bestWeight, token.weight());
        }
        float averageTokenFreqRatio = getAverageTokenFreqRatio(reader, fieldDocCount);
        if (averageTokenFreqRatio == 0.0f) {
            return new MatchNoDocsQuery("The \"" + getName() + "\" query is against an empty field");
        }
        boolean onlyScorePrunedTokens = tokenPruningConfig.isOnlyScorePrunedTokens();
        BooleanQuery.Builder qb = new BooleanQuery.Builder();
        for (TextExpansionResults.WeightedToken token : tokens) {
            boolean keep = shouldKeepToken(reader, token, fieldDocCount, averageTokenFreqRatio, bestWeight);
            // `!=` acts as XOR here: when onlyScorePrunedTokens is set, the pruned tokens are the ones added.
            if (keep != onlyScorePrunedTokens) {
                qb.add(weightedTermQuery(token, ft, context), BooleanClause.Occur.SHOULD);
            }
        }
        return qb.setMinimumNumberShouldMatch(1).build();
    }

    protected boolean doEquals(WeightedTokensQueryBuilder other) {
        return Objects.equals(fieldName, other.fieldName)
            && Objects.equals(tokenPruningConfig, other.tokenPruningConfig)
            && tokens.equals(other.tokens);
    }

    protected int doHashCode() {
        return Objects.hash(fieldName, tokens, tokenPruningConfig);
    }

    public String getWriteableName() {
        return NAME;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.TEXT_EXPANSION_TOKEN_PRUNING_CONFIG_ADDED;
    }

    /**
     * Converts a raw weight value from the parsed {@code tokens} map into a float.
     *
     * <p>A {@link Number} is narrowed with {@code floatValue()}. A {@link String} goes through
     * {@link Float#parseFloat}, which throws {@link NumberFormatException} for non-numeric text.
     *
     * @throws ElasticsearchParseException for any other value, including {@code null}
     */
    private static float parseWeight(String token, Object weight) throws IOException {
        if (weight instanceof Number asNumber) {
            return asNumber.floatValue();
        }
        if (weight instanceof String asString) {
            return Float.parseFloat(asString);
        }
        // Guard against a null weight, which previously failed with a NullPointerException on getClass().
        String actualType = weight == null ? "null" : weight.getClass().getSimpleName();
        throw new ElasticsearchParseException(
            "Illegal weight for token: [" + token + "], expected floating point got " + actualType
        );
    }

    /**
     * Parses the body of a {@code weighted_tokens} query. The expected shape is
     * {@code {"<field>": {"tokens": {"<token>": <weight>, ...}, "pruning_config": {...}}}}, plus optional boost
     * and query-name fields.
     *
     * @throws ParsingException         on an unknown field, on more than one target field, on a missing field
     *                                  name, or when {@code tokens} or {@code pruning_config} is not an object
     * @throws IllegalArgumentException if a top-level value is not an object, or if no tokens were given
     */
    public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) throws IOException {
        String currentFieldName = null;
        String fieldName = null;
        List<TextExpansionResults.WeightedToken> tokens = new ArrayList<>();
        TokenPruningConfig tokenPruningConfig = null;
        float boost = 1.0f;
        String queryName = null;
        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                    } else if (TextExpansionQueryBuilder.PRUNING_CONFIG.match(currentFieldName, parser.getDeprecationHandler())) {
                        if (token != XContentParser.Token.START_OBJECT) {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[" + TextExpansionQueryBuilder.PRUNING_CONFIG.getPreferredName() + "] should be an object"
                            );
                        }
                        tokenPruningConfig = TokenPruningConfig.fromXContent(parser);
                    } else if (TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        // Reject non-object values up front. Otherwise a scalar or array would leave the parser
                        // walking through that value's tokens while the current field is still "tokens".
                        if (token != XContentParser.Token.START_OBJECT) {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[" + TOKENS_FIELD.getPreferredName() + "] should be an object"
                            );
                        }
                        Map<String, Object> tokensMap = parser.map();
                        for (Map.Entry<String, Object> entry : tokensMap.entrySet()) {
                            String tokenText = entry.getKey();
                            tokens.add(new TextExpansionResults.WeightedToken(tokenText, parseWeight(tokenText, entry.getValue())));
                        }
                    } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        boost = parser.floatValue();
                    } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        queryName = parser.text();
                    } else {
                        throw new ParsingException(parser.getTokenLocation(), "unknown field [" + currentFieldName + "]");
                    }
                }
            } else {
                throw new IllegalArgumentException("invalid query");
            }
        }
        if (fieldName == null) {
            throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query");
        }
        WeightedTokensQueryBuilder qb = new WeightedTokensQueryBuilder(fieldName, tokens, tokenPruningConfig);
        qb.queryName(queryName);
        qb.boost(boost);
        return qb;
    }
}

