/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.Executor;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.Detector;

public class TransportEstimateModelMemoryAction
extends HandledTransportAction<EstimateModelMemoryAction.Request, EstimateModelMemoryAction.Response> {
    static final ByteSizeValue BASIC_REQUIREMENT = ByteSizeValue.ofMb((long)10L);
    static final long BYTES_PER_INFLUENCER_VALUE = ByteSizeValue.ofKb((long)10L).getBytes();
    private static final long BYTES_IN_MB = ByteSizeValue.ofMb((long)1L).getBytes();

    @Inject
    public TransportEstimateModelMemoryAction(TransportService transportService, ActionFilters actionFilters) {
        super("cluster:admin/xpack/ml/job/estimate_model_memory", transportService, actionFilters, EstimateModelMemoryAction.Request::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
    }

    protected void doExecute(Task task, EstimateModelMemoryAction.Request request, ActionListener<EstimateModelMemoryAction.Response> listener) {
        AnalysisConfig analysisConfig = request.getAnalysisConfig();
        Map overallCardinality = request.getOverallCardinality();
        Map maxBucketCardinality = request.getMaxBucketCardinality();
        long answer = BASIC_REQUIREMENT.getBytes();
        answer = TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(answer, TransportEstimateModelMemoryAction.calculateDetectorsRequirementBytes(analysisConfig, overallCardinality));
        answer = TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(answer, TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality));
        answer = TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(answer, TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfig, overallCardinality));
        listener.onResponse((Object)new EstimateModelMemoryAction.Response(TransportEstimateModelMemoryAction.roundUpToNextMb(answer)));
    }

    static long calculateDetectorsRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
        long bucketSpanSeconds = analysisConfig.getBucketSpan().getSeconds();
        return analysisConfig.getDetectors().stream().map(detector -> TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(detector, bucketSpanSeconds, overallCardinality)).reduce(0L, TransportEstimateModelMemoryAction::addNonNegativeLongsWithMaxValueCap);
    }

    static long calculateDetectorRequirementBytes(Detector detector, long bucketSpanSeconds, Map<String, Long> overallCardinality) {
        String overFieldName;
        String byFieldName;
        long answer = 0L;
        boolean addFieldValueWorkspace = false;
        switch (detector.getFunction()) {
            case DISTINCT_COUNT: 
            case LOW_DISTINCT_COUNT: 
            case HIGH_DISTINCT_COUNT: {
                addFieldValueWorkspace = true;
            }
            case COUNT: 
            case LOW_COUNT: 
            case HIGH_COUNT: 
            case NON_ZERO_COUNT: 
            case LOW_NON_ZERO_COUNT: 
            case HIGH_NON_ZERO_COUNT: {
                answer = ByteSizeValue.ofKb((long)32L).getBytes();
                break;
            }
            case RARE: 
            case FREQ_RARE: {
                answer = ByteSizeValue.ofKb((long)2L).getBytes();
                break;
            }
            case INFO_CONTENT: 
            case LOW_INFO_CONTENT: 
            case HIGH_INFO_CONTENT: {
                addFieldValueWorkspace = true;
            }
            case MEAN: 
            case LOW_MEAN: 
            case HIGH_MEAN: 
            case AVG: 
            case LOW_AVG: 
            case HIGH_AVG: 
            case MIN: 
            case MAX: 
            case SUM: 
            case LOW_SUM: 
            case HIGH_SUM: 
            case NON_NULL_SUM: 
            case LOW_NON_NULL_SUM: 
            case HIGH_NON_NULL_SUM: 
            case VARP: 
            case LOW_VARP: 
            case HIGH_VARP: {
                answer = ByteSizeValue.ofKb((long)48L).getBytes();
                break;
            }
            case METRIC: {
                answer = ByteSizeValue.ofKb((long)120L).getBytes();
                break;
            }
            case MEDIAN: 
            case LOW_MEDIAN: 
            case HIGH_MEDIAN: {
                answer = ByteSizeValue.ofKb((long)64L).getBytes();
                break;
            }
            case TIME_OF_DAY: 
            case TIME_OF_WEEK: {
                answer = ByteSizeValue.ofKb((long)10L).getBytes();
                break;
            }
            case LAT_LONG: {
                answer = ByteSizeValue.ofKb((long)64L).getBytes();
                break;
            }
            default: {
                assert (false) : "unhandled detector function: " + detector.getFunction().getFullName();
                break;
            }
        }
        long partitionFieldCardinalityEstimate = 1L;
        String partitionFieldName = detector.getPartitionFieldName();
        if (partitionFieldName != null) {
            partitionFieldCardinalityEstimate = Math.max(1L, TransportEstimateModelMemoryAction.cardinalityEstimate(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true));
        }
        if ((byFieldName = detector.getByFieldName()) != null) {
            long byFieldCardinalityEstimate = TransportEstimateModelMemoryAction.cardinalityEstimate(Detector.BY_FIELD_NAME_FIELD.getPreferredName(), byFieldName, overallCardinality, true);
            double multiplier = Math.ceil(TransportEstimateModelMemoryAction.reducedCardinality(byFieldCardinalityEstimate, partitionFieldCardinalityEstimate, bucketSpanSeconds) * 2.0 / 3.0);
            answer = TransportEstimateModelMemoryAction.multiplyNonNegativeLongsWithMaxValueCap(answer, (long)multiplier);
        }
        if ((overFieldName = detector.getOverFieldName()) != null) {
            long overFieldCardinalityEstimate = TransportEstimateModelMemoryAction.cardinalityEstimate(Detector.OVER_FIELD_NAME_FIELD.getPreferredName(), overFieldName, overallCardinality, true);
            double multiplier = Math.ceil(TransportEstimateModelMemoryAction.reducedCardinality(overFieldCardinalityEstimate, partitionFieldCardinalityEstimate, bucketSpanSeconds));
            answer = TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(answer, TransportEstimateModelMemoryAction.multiplyNonNegativeLongsWithMaxValueCap(768L, (long)multiplier));
        }
        if (partitionFieldName != null) {
            answer = TransportEstimateModelMemoryAction.multiplyNonNegativeLongsWithMaxValueCap(answer, partitionFieldCardinalityEstimate);
        }
        if (addFieldValueWorkspace) {
            answer = TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(answer, ByteSizeValue.ofMb((long)5L).getBytes());
        }
        return answer;
    }

    static long calculateInfluencerRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> maxBucketCardinality) {
        HashSet pureInfluencers = new HashSet(analysisConfig.getInfluencers());
        for (Detector detector : analysisConfig.getDetectors()) {
            pureInfluencers.removeAll(detector.extractAnalysisFields());
        }
        long totalInfluencerCardinality = pureInfluencers.stream().map(influencer -> TransportEstimateModelMemoryAction.cardinalityEstimate(AnalysisConfig.INFLUENCERS.getPreferredName(), influencer, maxBucketCardinality, false)).reduce(0L, TransportEstimateModelMemoryAction::addNonNegativeLongsWithMaxValueCap);
        return TransportEstimateModelMemoryAction.multiplyNonNegativeLongsWithMaxValueCap(BYTES_PER_INFLUENCER_VALUE, totalInfluencerCardinality);
    }

    static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
        if (analysisConfig.getCategorizationFieldName() == null) {
            return 0L;
        }
        long memoryPerPartitionMb = 20L;
        long relevantPartitionFieldCardinalityEstimate = 1L;
        if (analysisConfig.getPerPartitionCategorizationConfig().isEnabled()) {
            for (Detector detector : analysisConfig.getDetectors()) {
                String partitionFieldName = detector.getPartitionFieldName();
                if (partitionFieldName == null) continue;
                relevantPartitionFieldCardinalityEstimate = Math.max(1L, TransportEstimateModelMemoryAction.cardinalityEstimate(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true));
                break;
            }
            if (!analysisConfig.getPerPartitionCategorizationConfig().isStopOnWarn()) {
                memoryPerPartitionMb *= 2L;
            }
        } else {
            memoryPerPartitionMb *= 2L;
        }
        return ByteSizeValue.ofMb((long)(memoryPerPartitionMb * relevantPartitionFieldCardinalityEstimate)).getBytes();
    }

    static long cardinalityEstimate(String description, String fieldName, Map<String, Long> suppliedCardinailityEstimates, boolean isOverall) {
        Long suppliedEstimate = suppliedCardinailityEstimates.get(fieldName);
        if (suppliedEstimate != null) {
            return suppliedEstimate;
        }
        if ("mlcategory".equals(fieldName)) {
            return isOverall ? 500L : 50L;
        }
        throw new IllegalArgumentException("[" + (isOverall ? "Overall" : "Bucket max") + "] cardinality estimate required for [" + description + "] [" + fieldName + "] but not supplied");
    }

    static ByteSizeValue roundUpToNextMb(long bytes) {
        assert (bytes >= 0L) : "negative bytes " + bytes;
        return ByteSizeValue.ofMb((long)(TransportEstimateModelMemoryAction.addNonNegativeLongsWithMaxValueCap(bytes, BYTES_IN_MB - 1L) / BYTES_IN_MB));
    }

    static double reducedCardinality(long cardinalityToReduce, long partitionFieldCardinalityEstimate, long bucketSpanSeconds) {
        assert (cardinalityToReduce >= 0L) : "negative cardinality to reduce " + cardinalityToReduce;
        assert (partitionFieldCardinalityEstimate > 0L) : "non-positive partition field cardinality " + partitionFieldCardinalityEstimate;
        assert (bucketSpanSeconds > 0L) : "non-positive bucket span " + bucketSpanSeconds;
        if (cardinalityToReduce == 0L) {
            return 0.0;
        }
        double power = Math.min(1.0, (Math.log10(bucketSpanSeconds) + 1.0) / 8.0);
        return (double)cardinalityToReduce / Math.pow(Math.min(cardinalityToReduce, partitionFieldCardinalityEstimate), power);
    }

    static long addNonNegativeLongsWithMaxValueCap(long a, long b) {
        assert (a >= 0L);
        assert (b >= 0L);
        if (Long.MAX_VALUE - a - b < 0L) {
            return Long.MAX_VALUE;
        }
        return a + b;
    }

    static long multiplyNonNegativeLongsWithMaxValueCap(long a, long b) {
        assert (a >= 0L);
        assert (b >= 0L);
        if (a == 0L || b == 0L) {
            return 0L;
        }
        if (Long.MAX_VALUE / a < b) {
            return Long.MAX_VALUE;
        }
        return a * b;
    }
}

