/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.inference;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.ml.aggs.inference.InternalInferenceAggregation;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;

public class InferencePipelineAggregator
extends PipelineAggregator {
    private final Map<String, String> bucketPathMap;
    private final InferenceConfigUpdate configUpdate;
    private final LocalModel model;

    public InferencePipelineAggregator(String name, Map<String, String> bucketPathMap, Map<String, Object> metaData, InferenceConfigUpdate configUpdate, LocalModel model) {
        super(name, bucketPathMap.values().toArray(new String[0]), metaData);
        this.bucketPathMap = bucketPathMap;
        this.configUpdate = configUpdate;
        this.model = model;
    }

    public InternalAggregation reduce(InternalAggregation aggregation, AggregationReduceContext reduceContext) {
        try (LocalModel localModel = this.model;){
            InternalMultiBucketAggregation originalAgg = (InternalMultiBucketAggregation)aggregation;
            List buckets = originalAgg.getBuckets();
            ArrayList<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<InternalMultiBucketAggregation.InternalBucket>();
            for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
                InferenceResults inference;
                HashMap<String, Object> inputFields = new HashMap<String, Object>();
                if (bucket.getDocCount() == 0L && !this.bucketPathMap.containsKey("_count")) {
                    newBuckets.add(bucket);
                    continue;
                }
                for (Map.Entry<String, String> entry : this.bucketPathMap.entrySet()) {
                    double doubleVal;
                    String aggName = entry.getKey();
                    String bucketPath = entry.getValue();
                    Object propertyValue = InferencePipelineAggregator.resolveBucketValue((MultiBucketsAggregation)originalAgg, bucket, bucketPath);
                    if (propertyValue instanceof Number) {
                        Number numberValue = (Number)propertyValue;
                        doubleVal = numberValue.doubleValue();
                        if (!Double.isFinite(doubleVal)) continue;
                        inputFields.put(aggName, doubleVal);
                        continue;
                    }
                    if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
                        InternalNumericMetricsAggregation.SingleValue singleValue = (InternalNumericMetricsAggregation.SingleValue)propertyValue;
                        doubleVal = singleValue.value();
                        if (!Double.isFinite(doubleVal)) continue;
                        inputFields.put(aggName, doubleVal);
                        continue;
                    }
                    if (propertyValue instanceof StringTerms.Bucket) {
                        StringTerms.Bucket b = (StringTerms.Bucket)propertyValue;
                        inputFields.put(aggName, b.getKeyAsString());
                        continue;
                    }
                    if (propertyValue instanceof String) {
                        inputFields.put(aggName, propertyValue);
                        continue;
                    }
                    if (propertyValue == null) continue;
                    throw InferencePipelineAggregator.invalidAggTypeError(bucketPath, propertyValue);
                }
                try {
                    inference = this.model.infer(inputFields, this.configUpdate);
                }
                catch (Exception e) {
                    inference = new WarningInferenceResults(e.getMessage());
                }
                ArrayList<InternalInferenceAggregation> aggs = new ArrayList<InternalInferenceAggregation>(bucket.getAggregations().asList());
                InternalInferenceAggregation aggResult = new InternalInferenceAggregation(this.name(), this.metadata(), inference);
                aggs.add(aggResult);
                InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
                newBuckets.add(newBucket);
            }
            assert (this.model.getReferenceCount() > 0L);
            InternalMultiBucketAggregation internalMultiBucketAggregation = originalAgg.create(newBuckets);
            return internalMultiBucketAggregation;
        }
    }

    public static Object resolveBucketValue(MultiBucketsAggregation agg, InternalMultiBucketAggregation.InternalBucket bucket, String aggPath) {
        List aggPathsList = AggregationPath.parse((String)aggPath).getPathElementsAsStringList();
        return bucket.getProperty(agg.getName(), aggPathsList);
    }

    private static IllegalArgumentException invalidAggTypeError(String aggPath, Object propertyValue) {
        String msg = AbstractPipelineAggregationBuilder.BUCKETS_PATH_FIELD.getPreferredName() + " must reference either a number value, a single value numeric metric aggregation or a string: got [" + propertyValue + "] of type [" + propertyValue.getClass().getSimpleName() + "] ] at aggregation [" + aggPath + "]";
        return new IllegalArgumentException(msg);
    }
}

