/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.queries;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentLocation;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.ml.queries.TokenPruningConfig;
import org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder;

/**
 * Query builder for the {@code text_expansion} query.
 *
 * <p>The query names a field, the text to expand ({@code model_text}), the trained model used to
 * expand it ({@code model_id}) and an optional {@link TokenPruningConfig}. It cannot be turned into
 * a Lucene query directly ({@link #doToQuery} throws). Instead, {@link #doRewrite}:
 * <ol>
 *   <li>on the first rewrite, registers an asynchronous {@link CoordinatedInferenceAction} request
 *       for the model text and returns a copy of this builder that holds the pending result;</li>
 *   <li>on later rewrites, once the {@link TextExpansionResults} have been stored, replaces itself
 *       with either a {@link WeightedTokensQueryBuilder} (when a pruning config is present) or a
 *       {@code bool} query of per-token {@code term} queries.</li>
 * </ol>
 */
public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansionQueryBuilder> {
    public static final String NAME = "text_expansion";
    public static final ParseField PRUNING_CONFIG = new ParseField("pruning_config");
    public static final ParseField MODEL_TEXT = new ParseField("model_text");
    public static final ParseField MODEL_ID = new ParseField("model_id");

    private final String fieldName;
    private final String modelText;
    private final String modelId;
    /**
     * Holder for the inference output. It is {@code null} on builders that have not been rewritten
     * yet (including deserialized ones). On the copy returned by {@link #doRewrite} it is non-null,
     * and its value is set by the async inference callback.
     */
    private final SetOnce<TextExpansionResults> weightedTokensSupplier;
    @Nullable
    private final TokenPruningConfig tokenPruningConfig;

    /**
     * Creates a builder with no token pruning configuration.
     *
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) {
        this(fieldName, modelText, modelId, null);
    }

    /**
     * @param fieldName          field to query; must not be {@code null}
     * @param modelText          text sent to the model for expansion; must not be {@code null}
     * @param modelId            id of the model used for inference; must not be {@code null}
     * @param tokenPruningConfig optional pruning settings; when non-null the rewritten query is a
     *                           {@link WeightedTokensQueryBuilder} instead of a {@code bool} query
     * @throws IllegalArgumentException if {@code fieldName}, {@code modelText} or {@code modelId} is {@code null}
     */
    public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId, @Nullable TokenPruningConfig tokenPruningConfig) {
        if (fieldName == null) {
            throw new IllegalArgumentException("[text_expansion] requires a fieldName");
        }
        if (modelText == null) {
            throw new IllegalArgumentException("[text_expansion] requires a " + MODEL_TEXT.getPreferredName() + " value");
        }
        if (modelId == null) {
            throw new IllegalArgumentException("[text_expansion] requires a " + MODEL_ID.getPreferredName() + " value");
        }
        this.fieldName = fieldName;
        this.modelText = modelText;
        this.modelId = modelId;
        this.tokenPruningConfig = tokenPruningConfig;
        this.weightedTokensSupplier = null;
    }

    /**
     * Deserializes a builder written by {@link #doWriteTo}.
     *
     * <p>The field order must match {@code doWriteTo}. The pruning config is only present on the
     * wire for transport versions on or after {@code TEXT_EXPANSION_TOKEN_PRUNING_CONFIG_ADDED}.
     */
    public TextExpansionQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.modelText = in.readString();
        this.modelId = in.readString();
        if (in.getTransportVersion().onOrAfter(TransportVersions.TEXT_EXPANSION_TOKEN_PRUNING_CONFIG_ADDED)) {
            this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new);
        } else {
            this.tokenPruningConfig = null;
        }
        // A supplier is never serialized (doWriteTo refuses), so a deserialized builder has none.
        this.weightedTokensSupplier = null;
    }

    /**
     * Copies {@code other}, including its boost and query name, and attaches the supplier that
     * will receive the inference result.
     */
    private TextExpansionQueryBuilder(TextExpansionQueryBuilder other, SetOnce<TextExpansionResults> weightedTokensSupplier) {
        this.fieldName = other.fieldName;
        this.modelText = other.modelText;
        this.modelId = other.modelId;
        this.tokenPruningConfig = other.tokenPruningConfig;
        this.boost = other.boost;
        this.queryName = other.queryName;
        this.weightedTokensSupplier = weightedTokensSupplier;
    }

    String getFieldName() {
        return this.fieldName;
    }

    /** Returns the pruning configuration, or {@code null} if none was given. */
    public TokenPruningConfig getTokenPruningConfig() {
        return this.tokenPruningConfig;
    }

    public String getWriteableName() {
        return NAME;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_8_0;
    }

    /**
     * Writes the field name, model text, model id and (for new enough transport versions) the
     * optional pruning config.
     *
     * @throws IllegalStateException if this builder already holds an inference-result supplier,
     *                               which cannot be serialized
     */
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (this.weightedTokensSupplier != null) {
            throw new IllegalStateException("token supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
        }
        out.writeString(this.fieldName);
        out.writeString(this.modelText);
        out.writeString(this.modelId);
        if (out.getTransportVersion().onOrAfter(TransportVersions.TEXT_EXPANSION_TOKEN_PRUNING_CONFIG_ADDED)) {
            out.writeOptionalWriteable(this.tokenPruningConfig);
        }
    }

    /** Renders the query in the long (object) form accepted by {@link #fromXContent}. */
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        builder.field(MODEL_TEXT.getPreferredName(), this.modelText);
        builder.field(MODEL_ID.getPreferredName(), this.modelId);
        if (this.tokenPruningConfig != null) {
            builder.field(PRUNING_CONFIG.getPreferredName(), this.tokenPruningConfig);
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Drives the two-phase rewrite described on the class.
     *
     * <ul>
     *   <li>If a supplier is attached but still empty, returns {@code this}.</li>
     *   <li>If the supplier holds a result, returns the concrete token query.</li>
     *   <li>Otherwise registers the async inference call and returns a copy holding a fresh supplier.</li>
     * </ul>
     *
     * <p>The async callback fails the listener if the response contains no results, if the first
     * result is a {@link WarningInferenceResults}, or if it is any other non-text-expansion type.
     */
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (this.weightedTokensSupplier != null) {
            TextExpansionResults results = this.weightedTokensSupplier.get();
            if (results == null) {
                // Inference has not completed yet; nothing to rewrite to for now.
                return this;
            }
            return this.weightedTokensToQuery(this.fieldName, results);
        }

        CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
            this.modelId,
            List.of(this.modelText),
            TextExpansionConfigUpdate.EMPTY_UPDATE,
            false,
            InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
        );
        inferRequest.setHighPriority(true);
        inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

        // Typed holder. The decompiled raw SetOnce plus a raw ActionListener cast erased the
        // response type, leaving inferenceResponse typed as Object.
        SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(
            (client, listener) -> ClientHelper.executeAsyncWithOrigin(
                client,
                "ml",
                CoordinatedInferenceAction.INSTANCE,
                inferRequest,
                ActionListener.wrap(inferenceResponse -> {
                    if (inferenceResponse.getInferenceResults().isEmpty()) {
                        listener.onFailure(new IllegalStateException("inference response contain no results"));
                        return;
                    }
                    // Only the first result is considered: a single text input was sent.
                    InferenceResults firstResult = inferenceResponse.getInferenceResults().get(0);
                    if (firstResult instanceof TextExpansionResults) {
                        textExpansionResultsSupplier.set((TextExpansionResults) firstResult);
                        listener.onResponse(null);
                    } else if (firstResult instanceof WarningInferenceResults) {
                        WarningInferenceResults warning = (WarningInferenceResults) firstResult;
                        listener.onFailure(new IllegalStateException(warning.getWarning()));
                    } else {
                        listener.onFailure(
                            new IllegalStateException(
                                "expected a result of type [text_expansion_result] received ["
                                    + firstResult.getWriteableName()
                                    + "]. Is ["
                                    + this.modelId
                                    + "] a compatible model?"
                            )
                        );
                    }
                }, listener::onFailure)
            )
        );
        return new TextExpansionQueryBuilder(this, textExpansionResultsSupplier);
    }

    /**
     * Converts inference output into an executable query, carrying over this builder's boost and
     * query name.
     *
     * <ul>
     *   <li>With a pruning config: a {@link WeightedTokensQueryBuilder} over all tokens.</li>
     *   <li>Without one: a {@code bool} query with one {@code should} {@code term} clause per token.
     *       Each clause is boosted by the token's weight, and {@code minimum_should_match} is 1.</li>
     * </ul>
     */
    private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) {
        if (this.tokenPruningConfig != null) {
            WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
                fieldName,
                textExpansionResults.getWeightedTokens(),
                this.tokenPruningConfig
            );
            weightedTokensQueryBuilder.queryName(this.queryName);
            weightedTokensQueryBuilder.boost(this.boost);
            return weightedTokensQueryBuilder;
        }
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        for (TextExpansionResults.WeightedToken weightedToken : textExpansionResults.getWeightedTokens()) {
            boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
        }
        boolQuery.minimumShouldMatch(1);
        boolQuery.boost(this.boost);
        boolQuery.queryName(this.queryName);
        return boolQuery;
    }

    /**
     * Always throws: this builder must be rewritten into another query before execution.
     *
     * @throws IllegalStateException always
     */
    protected Query doToQuery(SearchExecutionContext context) {
        throw new IllegalStateException("text_expansion should have been rewritten to another query type");
    }

    /**
     * Compares all query parameters plus the inference-result supplier. A builder holding a
     * supplier therefore never equals one without a supplier.
     */
    protected boolean doEquals(TextExpansionQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
            && Objects.equals(this.modelText, other.modelText)
            && Objects.equals(this.modelId, other.modelId)
            && Objects.equals(this.tokenPruningConfig, other.tokenPruningConfig)
            && Objects.equals(this.weightedTokensSupplier, other.weightedTokensSupplier);
    }

    /** Hashes the same fields compared by {@link #doEquals}. */
    protected int doHashCode() {
        return Objects.hash(this.fieldName, this.modelText, this.modelId, this.tokenPruningConfig, this.weightedTokensSupplier);
    }

    /**
     * Parses a {@code text_expansion} query body. Tokens are read until the enclosing
     * {@code END_OBJECT}. The caller presumably positions the parser just inside the query's object.
     *
     * <p>Two shapes are accepted:
     * <ul>
     *   <li>Long form. {@code { "<field>": { ... } }}, where the inner object may contain:
     *     <ul>
     *       <li>{@code model_text} and {@code model_id};</li>
     *       <li>a {@code pruning_config} object;</li>
     *       <li>the fields matched by {@link AbstractQueryBuilder#BOOST_FIELD} and
     *           {@link AbstractQueryBuilder#NAME_FIELD}.</li>
     *     </ul>
     *   </li>
     *   <li>Short form. {@code { "<field>": "<text>" }}, where the value becomes the model text.
     *       No model id can be given this way, so the constructor rejects it.</li>
     * </ul>
     *
     * @throws ParsingException         on unknown fields or tokens, on a second field, or when the
     *                                  model text or field name is missing
     * @throws IllegalArgumentException when no model id was supplied (raised by the constructor)
     */
    public static TextExpansionQueryBuilder fromXContent(XContentParser parser) throws IOException {
        String fieldName = null;
        String modelText = null;
        String modelId = null;
        TokenPruningConfig tokenPruningConfig = null;
        float boost = 1.0f;
        String queryName = null;
        String currentFieldName = null;

        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                // Long form: the object keyed by the field name holds the query parameters.
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                    } else if (token == XContentParser.Token.START_OBJECT) {
                        if (PRUNING_CONFIG.match(currentFieldName, parser.getDeprecationHandler())) {
                            tokenPruningConfig = TokenPruningConfig.fromXContent(parser);
                        } else {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[text_expansion] unknown token [" + token + "] after [" + currentFieldName + "]"
                            );
                        }
                    } else if (token.isValue()) {
                        if (MODEL_TEXT.match(currentFieldName, parser.getDeprecationHandler())) {
                            modelText = parser.text();
                        } else if (MODEL_ID.match(currentFieldName, parser.getDeprecationHandler())) {
                            modelId = parser.text();
                        } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            boost = parser.floatValue();
                        } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            queryName = parser.text();
                        } else {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[text_expansion] query does not support [" + currentFieldName + "]"
                            );
                        }
                    } else {
                        throw new ParsingException(
                            parser.getTokenLocation(),
                            "[text_expansion] unknown token [" + token + "] after [" + currentFieldName + "]"
                        );
                    }
                }
            } else {
                // Short form: the value directly under the field name is the model text.
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
                fieldName = parser.currentName();
                modelText = parser.text();
            }
        }
        if (modelText == null) {
            throw new ParsingException(parser.getTokenLocation(), "No text specified for text query");
        }
        if (fieldName == null) {
            throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query");
        }
        TextExpansionQueryBuilder queryBuilder = new TextExpansionQueryBuilder(fieldName, modelText, modelId, tokenPruningConfig);
        queryBuilder.queryName(queryName);
        queryBuilder.boost(boost);
        return queryBuilder;
    }

    /**
     * Field mapping type names associated with this query ({@code rank_features} and
     * {@code sparse_vector}). The enum is not consulted inside this class; presumably callers use
     * {@link #isFieldTypeAllowed} to validate the target field.
     */
    public enum AllowedFieldType {
        RANK_FEATURES("rank_features"),
        SPARSE_VECTOR("sparse_vector");

        private final String typeName;

        AllowedFieldType(String typeName) {
            this.typeName = typeName;
        }

        public String getTypeName() {
            return this.typeName;
        }

        /** Returns {@code true} if {@code typeName} exactly equals one of the listed type names. */
        public static boolean isFieldTypeAllowed(String typeName) {
            return Arrays.stream(AllowedFieldType.values()).anyMatch(value -> value.typeName.equals(typeName));
        }

        /** Returns the listed type names joined with {@code ", "}, in declaration order. */
        public static String getAllowedFieldTypesAsString() {
            return Arrays.stream(AllowedFieldType.values()).map(value -> value.typeName).collect(Collectors.joining(", "));
        }
    }
}

