/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml;

import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.lucene.util.Counter;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.env.Environment;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.protocol.xpack.XPackUsageRequest;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackFeatureSet;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
import org.elasticsearch.xpack.ml.DefaultMachineLearningExtension;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MachineLearningExtension;
import org.elasticsearch.xpack.ml.MachineLearningExtensionHolder;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;

public class MachineLearningUsageTransportAction
extends XPackUsageFeatureTransportAction {
    /** Logger used to report failures while gathering individual parts of the usage. */
    private static final Logger logger = LogManager.getLogger(MachineLearningUsageTransportAction.class);
    /** Client used for all internal requests; the constructor wraps it in an {@code OriginSettingClient} with origin {@code "ml"}. */
    private final Client client;
    /** License state consulted (without tracking) for the "available" flag of the reported usage. */
    private final XPackLicenseState licenseState;
    /** Gives access to the job manager that is used to expand the anomaly detection job configs. */
    private final JobManagerHolder jobManagerHolder;
    /** Decides which ML features (anomaly detection, data frame analytics, NLP) are queried; a default extension is used when the holder is empty. */
    private final MachineLearningExtension machineLearningExtension;
    /** Value of {@code XPackSettings.MACHINE_LEARNING_ENABLED}, read once from the environment settings at construction. */
    private final boolean enabled;

    /**
     * Creates the transport action that reports machine learning usage.
     *
     * <p>The supplied client is wrapped so that every request it sends carries the {@code "ml"} origin.
     * When the extension holder is empty a {@link DefaultMachineLearningExtension} is used instead.
     * The "ML enabled" flag is read once, here, from the environment settings.
     */
    @Inject
    public MachineLearningUsageTransportAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        Environment environment,
        Client client,
        XPackLicenseState licenseState,
        JobManagerHolder jobManagerHolder,
        MachineLearningExtensionHolder machineLearningExtensionHolder
    ) {
        super(
            XPackUsageFeatureAction.MACHINE_LEARNING.name(),
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            indexNameExpressionResolver
        );
        this.client = new OriginSettingClient(client, "ml");
        this.licenseState = licenseState;
        this.jobManagerHolder = jobManagerHolder;
        if (machineLearningExtensionHolder.isEmpty()) {
            this.machineLearningExtension = new DefaultMachineLearningExtension();
        } else {
            this.machineLearningExtension = machineLearningExtensionHolder.getMachineLearningExtension();
        }
        this.enabled = (Boolean)XPackSettings.MACHINE_LEARNING_ENABLED.get(environment.settings());
    }

    /**
     * Gathers machine learning usage and reports it to {@code listener}.
     *
     * <p>If ML is disabled, an empty usage (empty maps, node count 0) is returned immediately.
     * Otherwise the usage is assembled by a chain of asynchronous requests, executed in this order:
     * job stats, job configs, datafeed stats, data frame analytics stats, data frame analytics configs,
     * inference usage. Each stage fills one of the shared usage maps and then triggers the next stage.
     * A failing stage only logs a warning and the chain continues, so the response is always sent,
     * possibly with parts of the usage missing.
     *
     * <p>The listeners below are declared in the reverse of their execution order because each one
     * captures the listener of the stage that follows it.
     *
     * @param task     the task this operation runs under (not used here)
     * @param request  the usage request (not used here)
     * @param state    the cluster state, used to count ML nodes
     * @param listener receives the usage response
     */
    protected void masterOperation(Task task, XPackUsageRequest request, ClusterState state, ActionListener<XPackUsageFeatureResponse> listener) {
        if (!this.enabled) {
            // ML is switched off: report the license availability only, with no usage data.
            MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(MachineLearningField.ML_API_FEATURE.checkWithoutTracking(this.licenseState), this.enabled, Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0);
            listener.onResponse((Object)new XPackUsageFeatureResponse((XPackFeatureSet.Usage)usage));
            return;
        }
        // Shared accumulators that the stages below fill in before the final response is built.
        LinkedHashMap jobsUsage = new LinkedHashMap();
        LinkedHashMap datafeedsUsage = new LinkedHashMap();
        LinkedHashMap analyticsUsage = new LinkedHashMap();
        int nodeCount = MachineLearningUsageTransportAction.mlNodeCount(state);
        // Final stage: build the response; on failure the inference section is reported as an empty map.
        ActionListener inferenceUsageListener = ActionListener.wrap(inferenceUsage -> listener.onResponse((Object)new XPackUsageFeatureResponse((XPackFeatureSet.Usage)new MachineLearningFeatureSetUsage(MachineLearningField.ML_API_FEATURE.checkWithoutTracking(this.licenseState), this.enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount))), e -> {
            logger.warn("Failed to get inference usage to include in ML usage", (Throwable)e);
            listener.onResponse((Object)new XPackUsageFeatureResponse((XPackFeatureSet.Usage)new MachineLearningFeatureSetUsage(MachineLearningField.ML_API_FEATURE.checkWithoutTracking(this.licenseState), this.enabled, jobsUsage, datafeedsUsage, analyticsUsage, Collections.emptyMap(), nodeCount)));
        });
        // Stage: data frame analytics configs -> per-analysis-type counts, then inference usage.
        ActionListener dataframeAnalyticsListener = ActionListener.wrap(response -> {
            MachineLearningUsageTransportAction.addDataFrameAnalyticsUsage(response, analyticsUsage);
            this.addInferenceUsage((ActionListener<Map<String, Object>>)inferenceUsageListener);
        }, e -> {
            logger.warn("Failed to get data frame analytics configs to include in ML usage", (Throwable)e);
            this.addInferenceUsage((ActionListener<Map<String, Object>>)inferenceUsageListener);
        });
        GetDataFrameAnalyticsAction.Request getDfaRequest = new GetDataFrameAnalyticsAction.Request("_all");
        // NOTE(review): presumably PageParams(from, size), i.e. at most the first 10000 configs - confirm.
        getDfaRequest.setPageParams(new PageParams(0, 10000));
        // Stage: data frame analytics stats -> state counts and peak memory, then fetch the configs.
        ActionListener dataframeAnalyticsStatsListener = ActionListener.wrap(response -> {
            MachineLearningUsageTransportAction.addDataFrameAnalyticsStatsUsage(response, analyticsUsage);
            this.client.execute((ActionType)GetDataFrameAnalyticsAction.INSTANCE, (ActionRequest)getDfaRequest, dataframeAnalyticsListener);
        }, e -> {
            logger.warn("Failed to get data frame analytics stats to include in ML usage", (Throwable)e);
            this.client.execute((ActionType)GetDataFrameAnalyticsAction.INSTANCE, (ActionRequest)getDfaRequest, dataframeAnalyticsListener);
        });
        GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest = new GetDataFrameAnalyticsStatsAction.Request("_all");
        dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10000));
        // Stage: datafeed stats -> counts per state; continues with data frame analytics when that
        // feature is enabled, otherwise jumps straight to inference usage.
        ActionListener datafeedStatsListener = ActionListener.wrap(response -> {
            MachineLearningUsageTransportAction.addDatafeedsUsage(response, datafeedsUsage);
            if (this.machineLearningExtension.isDataFrameAnalyticsEnabled()) {
                this.client.execute((ActionType)GetDataFrameAnalyticsStatsAction.INSTANCE, (ActionRequest)dataframeAnalyticsStatsRequest, dataframeAnalyticsStatsListener);
            } else {
                this.addInferenceUsage((ActionListener<Map<String, Object>>)inferenceUsageListener);
            }
        }, e -> {
            logger.warn("Failed to get datafeed stats to include in ML usage", (Throwable)e);
            if (this.machineLearningExtension.isDataFrameAnalyticsEnabled()) {
                this.client.execute((ActionType)GetDataFrameAnalyticsStatsAction.INSTANCE, (ActionRequest)dataframeAnalyticsStatsRequest, dataframeAnalyticsStatsListener);
            } else {
                this.addInferenceUsage((ActionListener<Map<String, Object>>)inferenceUsageListener);
            }
        });
        GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request("_all");
        // Stage: job stats -> expand all job configs, combine both into the jobs usage, then datafeed stats.
        // Both the job-config failure and the job-stats failure fall through to the datafeed stage.
        ActionListener jobStatsListener = ActionListener.wrap(response -> this.jobManagerHolder.getJobManager().expandJobs("_all", true, (ActionListener<QueryPage<Job>>)ActionListener.wrap(jobs -> {
            this.addJobsUsage((GetJobsStatsAction.Response)response, jobs.results(), jobsUsage);
            this.client.execute((ActionType)GetDatafeedsStatsAction.INSTANCE, (ActionRequest)datafeedStatsRequest, datafeedStatsListener);
        }, e -> {
            logger.warn("Failed to get job configs to include in ML usage", (Throwable)e);
            this.client.execute((ActionType)GetDatafeedsStatsAction.INSTANCE, (ActionRequest)datafeedStatsRequest, datafeedStatsListener);
        })), e -> {
            logger.warn("Failed to get job stats to include in ML usage", (Throwable)e);
            this.client.execute((ActionType)GetDatafeedsStatsAction.INSTANCE, (ActionRequest)datafeedStatsRequest, datafeedStatsListener);
        });
        // Kick off the chain at the first stage whose feature is enabled. Note that the datafeed stage is
        // only reachable through the job stats stage, so it is skipped when anomaly detection is disabled.
        if (this.machineLearningExtension.isAnomalyDetectionEnabled()) {
            GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request("_all");
            this.client.execute((ActionType)GetJobsStatsAction.INSTANCE, (ActionRequest)jobStatsRequest, jobStatsListener);
        } else if (this.machineLearningExtension.isDataFrameAnalyticsEnabled()) {
            this.client.execute((ActionType)GetDataFrameAnalyticsStatsAction.INSTANCE, (ActionRequest)dataframeAnalyticsStatsRequest, dataframeAnalyticsStatsListener);
        } else {
            this.addInferenceUsage((ActionListener<Map<String, Object>>)inferenceUsageListener);
        }
    }

    /**
     * Summarises the anomaly detection jobs into {@code jobsUsage}.
     *
     * <p>An {@code "_all"} entry covers every job, and one entry per job state (lower-cased state name)
     * covers the jobs currently in that state. Each entry holds the job count, detector count stats,
     * model size stats, forecast stats and a {@code created_by} histogram.
     *
     * <p>Stats entries whose job id has no matching config in {@code jobs} are ignored. The
     * {@code "_all"} count and {@code created_by} histogram are derived from {@code jobs} alone, so they
     * also include jobs for which no stats were returned.
     *
     * @param response  the job stats response
     * @param jobs      the job configs
     * @param jobsUsage the map the usage entries are written to
     */
    private void addJobsUsage(GetJobsStatsAction.Response response, List<Job> jobs, Map<String, Object> jobsUsage) {
        StatsAccumulator allJobsDetectorsStats = new StatsAccumulator();
        StatsAccumulator allJobsModelSizeStats = new StatsAccumulator();
        ForecastStats allJobsForecastStats = new ForecastStats();

        // Fully typed maps: with raw types neither the loop below nor the created_by counting type-checks.
        Map<JobState, Counter> jobCountByState = new EnumMap<>(JobState.class);
        Map<JobState, StatsAccumulator> detectorStatsByState = new EnumMap<>(JobState.class);
        Map<JobState, StatsAccumulator> modelSizeStatsByState = new EnumMap<>(JobState.class);
        Map<JobState, ForecastStats> forecastStatsByState = new EnumMap<>(JobState.class);
        Map<JobState, Map<String, Long>> createdByByState = new EnumMap<>(JobState.class);

        List<GetJobsStatsAction.Response.JobStats> jobsStats = response.getResponse().results();
        Map<String, Job> jobMap = jobs.stream().collect(Collectors.toMap(Job::getId, Function.identity()));
        Map<String, Long> allJobsCreatedBy = jobs.stream()
            .map(MachineLearningUsageTransportAction::jobCreatedBy)
            .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

        for (GetJobsStatsAction.Response.JobStats jobStats : jobsStats) {
            Job job = jobMap.get(jobStats.getJobId());
            if (job == null) {
                // Stats without a config cannot be attributed to detectors or a creator; skip them.
                continue;
            }
            int detectorsCount = job.getAnalysisConfig().getDetectors().size();
            ModelSizeStats modelSizeStats = jobStats.getModelSizeStats();
            double modelSize = modelSizeStats == null ? 0.0 : (double) modelSizeStats.getModelBytes();

            allJobsForecastStats.merge(jobStats.getForecastStats());
            allJobsDetectorsStats.add(detectorsCount);
            allJobsModelSizeStats.add(modelSize);

            JobState jobState = jobStats.getState();
            jobCountByState.computeIfAbsent(jobState, js -> Counter.newCounter()).addAndGet(1L);
            detectorStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add(detectorsCount);
            modelSizeStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add(modelSize);
            forecastStatsByState.merge(jobState, jobStats.getForecastStats(), ForecastStats::merge);
            createdByByState.computeIfAbsent(jobState, js -> new HashMap<>())
                .merge(MachineLearningUsageTransportAction.jobCreatedBy(job), 1L, Long::sum);
        }

        jobsUsage.put(
            "_all",
            MachineLearningUsageTransportAction.createJobUsageEntry(
                jobs.size(),
                allJobsDetectorsStats,
                allJobsModelSizeStats,
                allJobsForecastStats,
                allJobsCreatedBy
            )
        );
        // Every per-state map is populated together inside the loop, so a state present in one is present in all.
        for (Map.Entry<JobState, Counter> stateCount : jobCountByState.entrySet()) {
            JobState jobState = stateCount.getKey();
            jobsUsage.put(
                jobState.name().toLowerCase(Locale.ROOT),
                MachineLearningUsageTransportAction.createJobUsageEntry(
                    stateCount.getValue().get(),
                    detectorStatsByState.get(jobState),
                    modelSizeStatsByState.get(jobState),
                    forecastStatsByState.get(jobState),
                    createdByByState.get(jobState)
                )
            );
        }
    }

    /**
     * Returns a sanitised label describing what created the given job.
     *
     * <p>The label is the string form of the {@code created_by} custom setting with every non-word
     * character replaced by {@code _}. When the job has no custom settings, no {@code created_by}
     * entry, or a {@code created_by} entry whose value is {@code null}, {@code "unknown"} is returned.
     *
     * @param job the job config to inspect
     * @return the sanitised creator label, never {@code null}
     */
    private static String jobCreatedBy(Job job) {
        Map<String, Object> customSettings = job.getCustomSettings();
        // A single lookup covers both "key absent" and "key mapped to null"; the latter used to throw an NPE.
        Object createdBy = customSettings == null ? null : customSettings.get("created_by");
        if (createdBy == null) {
            return "unknown";
        }
        return createdBy.toString().replaceAll("\\W", "_");
    }

    /**
     * Builds one jobs-usage entry from the given aggregates.
     *
     * @param count          number of jobs covered by the entry
     * @param detectorStats  detector count statistics
     * @param modelSizeStats model size statistics
     * @param forecastStats  forecast statistics
     * @param createdBy      number of jobs per creator label
     * @return a new mutable map with the keys {@code count}, {@code detectors}, {@code model_size},
     *         {@code forecasts} and {@code created_by}
     */
    private static Map<String, Object> createJobUsageEntry(long count, StatsAccumulator detectorStats, StatsAccumulator modelSizeStats, ForecastStats forecastStats, Map<String, Long> createdBy) {
        Map<String, Object> entry = new HashMap<>();
        entry.put("count", Long.valueOf(count));

        Map<String, ?> detectors = detectorStats.asMap();
        entry.put("detectors", detectors);

        Map<String, ?> modelSize = modelSizeStats.asMap();
        entry.put("model_size", modelSize);

        Map<String, ?> forecasts = forecastStats.asMap();
        entry.put("forecasts", forecasts);

        entry.put("created_by", createdBy);
        return entry;
    }

    /**
     * Writes datafeed counts into {@code datafeedsUsage}: an {@code "_all"} entry holding the total
     * reported by the response, plus one entry per datafeed state (lower-cased state name) that occurs
     * in the returned stats.
     *
     * @param response       the datafeed stats response
     * @param datafeedsUsage the map the count entries are written to
     */
    private static void addDatafeedsUsage(GetDatafeedsStatsAction.Response response, Map<String, Object> datafeedsUsage) {
        Map<DatafeedState, Counter> countsByState = new EnumMap<>(DatafeedState.class);
        for (GetDatafeedsStatsAction.Response.DatafeedStats stats : response.getResponse().results()) {
            Counter stateCounter = countsByState.computeIfAbsent(stats.getDatafeedState(), ds -> Counter.newCounter());
            stateCounter.addAndGet(1L);
        }

        long total = response.getResponse().count();
        datafeedsUsage.put("_all", MachineLearningUsageTransportAction.createCountUsageEntry(total));

        for (Map.Entry<DatafeedState, Counter> stateCount : countsByState.entrySet()) {
            String stateName = stateCount.getKey().name().toLowerCase(Locale.ROOT);
            datafeedsUsage.put(stateName, MachineLearningUsageTransportAction.createCountUsageEntry(stateCount.getValue().get()));
        }
    }

    /**
     * Wraps a count into a usage entry.
     *
     * @param count the value to report
     * @return a new mutable map containing only the key {@code count} mapped to the given value
     */
    private static Map<String, Object> createCountUsageEntry(long count) {
        return new HashMap<>(Map.of("count", Long.valueOf(count)));
    }

    /**
     * Writes data frame analytics statistics into {@code dataframeAnalyticsUsage}, in this order:
     * a {@code memory_usage} entry with peak-usage stats (only strictly positive peaks are counted),
     * an {@code "_all"} entry holding the total reported by the response, and one count entry per state
     * (lower-cased state name) that occurs in the returned stats.
     *
     * @param response                the data frame analytics stats response
     * @param dataframeAnalyticsUsage the map the entries are written to
     */
    private static void addDataFrameAnalyticsStatsUsage(GetDataFrameAnalyticsStatsAction.Response response, Map<String, Object> dataframeAnalyticsUsage) {
        Map<DataFrameAnalyticsState, Counter> countsByState = new EnumMap<>(DataFrameAnalyticsState.class);
        StatsAccumulator peakMemoryStats = new StatsAccumulator();

        for (GetDataFrameAnalyticsStatsAction.Response.Stats stats : response.getResponse().results()) {
            Counter stateCounter = countsByState.computeIfAbsent(stats.getState(), ds -> Counter.newCounter());
            stateCounter.addAndGet(1L);

            MemoryUsage memoryUsage = stats.getMemoryUsage();
            if (memoryUsage != null && memoryUsage.getPeakUsageBytes() > 0L) {
                peakMemoryStats.add((double)memoryUsage.getPeakUsageBytes());
            }
        }

        String peakKey = MemoryUsage.PEAK_USAGE_BYTES.getPreferredName();
        dataframeAnalyticsUsage.put("memory_usage", Collections.singletonMap(peakKey, peakMemoryStats.asMap()));
        dataframeAnalyticsUsage.put("_all", MachineLearningUsageTransportAction.createCountUsageEntry(response.getResponse().count()));

        for (Map.Entry<DataFrameAnalyticsState, Counter> stateCount : countsByState.entrySet()) {
            String stateName = stateCount.getKey().name().toLowerCase(Locale.ROOT);
            dataframeAnalyticsUsage.put(stateName, MachineLearningUsageTransportAction.createCountUsageEntry(stateCount.getValue().get()));
        }
    }

    /**
     * Counts the returned data frame analytics configs per analysis type (the analysis' writeable name)
     * and stores the result under {@code analysis_counts}.
     *
     * @param response                the data frame analytics configs response
     * @param dataframeAnalyticsUsage the map the counts are written to
     */
    private static void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsAction.Response response, Map<String, Object> dataframeAnalyticsUsage) {
        Map<String, Integer> countsByAnalysisType = new HashMap<>();
        for (DataFrameAnalyticsConfig config : response.getResources().results()) {
            String analysisType = config.getAnalysis().getWriteableName();
            countsByAnalysisType.merge(analysisType, 1, Integer::sum);
        }
        dataframeAnalyticsUsage.put("analysis_counts", countsByAnalysisType);
    }

    /**
     * Collects the inference usage and hands it to {@code listener}.
     *
     * <p>When neither data frame analytics nor NLP is enabled, an empty map is returned right away.
     * Otherwise the trained model configs are fetched, then the trained model stats, and both responses
     * are turned into the ingest, trained model and deployment sections of the usage. A failure of
     * either request, or of building the sections, is passed on to {@code listener}.
     *
     * @param listener receives the inference usage map
     */
    private void addInferenceUsage(ActionListener<Map<String, Object>> listener) {
        if (!this.machineLearningExtension.isDataFrameAnalyticsEnabled() && !this.machineLearningExtension.isNlpEnabled()) {
            listener.onResponse(Map.of());
            return;
        }

        GetTrainedModelsAction.Request modelsRequest = new GetTrainedModelsAction.Request("*", Collections.emptyList(), Collections.emptySet());
        modelsRequest.setPageParams(new PageParams(0, 10000));

        this.client.execute(GetTrainedModelsAction.INSTANCE, modelsRequest, listener.delegateFailureAndWrap((statsStage, modelsResponse) -> {
            GetTrainedModelsStatsAction.Request statsRequest = new GetTrainedModelsStatsAction.Request("*");
            statsRequest.setPageParams(new PageParams(0, 10000));

            this.client.execute(GetTrainedModelsStatsAction.INSTANCE, statsRequest, statsStage.delegateFailureAndWrap((resultStage, statsResponse) -> {
                Map<String, Object> usage = new LinkedHashMap<>();
                MachineLearningUsageTransportAction.addInferenceIngestUsage(statsResponse, usage);
                MachineLearningUsageTransportAction.addTrainedModelStats(modelsResponse, statsResponse, usage);
                MachineLearningUsageTransportAction.addDeploymentStats(modelsResponse, statsResponse, usage);
                resultStage.onResponse(usage);
            }));
        }));
    }

    /**
     * Adds the {@code deployments} section to {@code inferenceUsage}.
     *
     * <p>Only stats entries that carry deployment stats are considered. The section holds the number of
     * such entries, the inference-count-weighted average inference time, model size stats, the
     * distribution of per-node inference counts, and per-model stats ordered by
     * {@code modelId:taskType}.
     *
     * <p>The task type is the name of the model's inference config. It is {@code null} when the model is
     * not part of {@code modelsResponse} or has no inference config.
     *
     * @param modelsResponse the trained model configs
     * @param statsResponse  the trained model stats
     * @param inferenceUsage the map the section is written to
     */
    private static void addDeploymentStats(GetTrainedModelsAction.Response modelsResponse, GetTrainedModelsStatsAction.Response statsResponse, Map<String, Object> inferenceUsage) {
        // Built with a plain loop rather than Collectors.toMap: a config without an inference config
        // (treated as possible elsewhere in this class) would otherwise throw an NPE here and cause the
        // whole inference usage to be dropped. Such models are simply left out, so their task type
        // reads as null exactly like that of a model missing from the configs response.
        Map<String, String> taskTypes = new HashMap<>();
        for (TrainedModelConfig config : modelsResponse.getResources().results()) {
            InferenceConfig inferenceConfig = config.getInferenceConfig();
            if (inferenceConfig != null) {
                taskTypes.put(config.getModelId(), inferenceConfig.getName());
            }
        }

        StatsAccumulator modelSizes = new StatsAccumulator();
        int deploymentsCount = 0;
        double avgTimeSum = 0.0;
        StatsAccumulator nodeDistribution = new StatsAccumulator();
        Map<String, ModelStats> statsByModel = new TreeMap<>();

        for (GetTrainedModelsStatsAction.Response.TrainedModelStats stats : statsResponse.getResources().results()) {
            AssignmentStats deploymentStats = stats.getDeploymentStats();
            if (deploymentStats == null) {
                continue;
            }
            ++deploymentsCount;
            TrainedModelSizeStats modelSizeStats = stats.getModelSizeStats();
            if (modelSizeStats != null) {
                modelSizes.add((double)modelSizeStats.getModelSizeBytes());
            }
            String modelId = deploymentStats.getModelId();
            String taskType = taskTypes.get(modelId);
            // Several stats entries may refer to the same model; they share one ModelStats accumulator.
            ModelStats modelStats = statsByModel.computeIfAbsent(modelId + ":" + taskType, key -> new ModelStats(modelId, taskType));
            for (AssignmentStats.NodeStats nodeStats : deploymentStats.getNodeStats()) {
                long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                // Weight each node's average by its inference count so the overall average is per inference.
                avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * (double)nodeInferenceCount;
                nodeDistribution.add((double)nodeInferenceCount);
                modelStats.update(nodeStats);
            }
        }

        // Guard against division by zero when no inference has been recorded at all.
        double totalInferences = nodeDistribution.getTotal();
        double avgTimeMs = totalInferences == 0.0 ? 0.0 : avgTimeSum / totalInferences;
        inferenceUsage.put(
            "deployments",
            Map.of(
                "count", deploymentsCount,
                "time_ms", Map.of("avg", avgTimeMs),
                "model_sizes_bytes", modelSizes.asMap(),
                "inference_counts", nodeDistribution.asMap(),
                "stats_by_model", statsByModel.values().stream().map(ModelStats::asMap).collect(Collectors.toList())
            )
        );
    }

    /**
     * Adds the {@code trained_models} section to {@code inferenceUsage}.
     *
     * <p>The section contains an {@code "_all"} count, a {@code count} map (total, one entry per
     * inference config name, {@code prepackaged} and {@code other}), and stats over the estimated
     * operations and model sizes. Models tagged {@code prepackaged} are counted but otherwise excluded
     * from the per-config counts and from the stats.
     *
     * @param modelsResponse the trained model configs
     * @param statsResponse  the trained model stats
     * @param inferenceUsage the map the section is written to
     */
    private static void addTrainedModelStats(GetTrainedModelsAction.Response modelsResponse, GetTrainedModelsStatsAction.Response statsResponse, Map<String, Object> inferenceUsage) {
        List<TrainedModelConfig> trainedModelConfigs = modelsResponse.getResources().results();
        // The stats may hold more than one entry per model id (the deployment stats code in this class
        // aggregates repeated ids), so keep the first one instead of failing on a duplicate key.
        Map<String, GetTrainedModelsStatsAction.Response.TrainedModelStats> statsToModelId = statsResponse.getResources()
            .results()
            .stream()
            .collect(
                Collectors.toMap(
                    GetTrainedModelsStatsAction.Response.TrainedModelStats::getModelId,
                    Function.identity(),
                    (first, second) -> first
                )
            );

        Map<String, Object> trainedModelsUsage = new HashMap<>();
        trainedModelsUsage.put("_all", MachineLearningUsageTransportAction.createCountUsageEntry(trainedModelConfigs.size()));

        StatsAccumulator estimatedOperations = new StatsAccumulator();
        StatsAccumulator estimatedMemoryUsageBytes = new StatsAccumulator();
        int createdByAnalyticsCount = 0;
        int prepackagedCount = 0;
        Map<String, Counter> inferenceConfigCounts = new LinkedHashMap<>();

        for (TrainedModelConfig trainedModelConfig : trainedModelConfigs) {
            if (trainedModelConfig.getTags().contains("prepackaged")) {
                ++prepackagedCount;
                continue;
            }
            InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
            if (inferenceConfig != null) {
                inferenceConfigCounts.computeIfAbsent(inferenceConfig.getName(), s -> Counter.newCounter()).addAndGet(1L);
            }
            if (trainedModelConfig.getMetadata() != null && trainedModelConfig.getMetadata().containsKey("analytics_config")) {
                ++createdByAnalyticsCount;
            }
            estimatedOperations.add((double)trainedModelConfig.getEstimatedOperations());

            GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsToModelId.get(trainedModelConfig.getModelId());
            // Size stats can be absent (the deployment stats code checks for this too); skip instead of throwing.
            TrainedModelSizeStats sizeStats = modelStats == null ? null : modelStats.getModelSizeStats();
            if (sizeStats != null) {
                estimatedMemoryUsageBytes.add((double)sizeStats.getModelSizeBytes());
            }
        }

        // Values are a mix of Integer (sizes/tallies) and Long (Counter values), hence Object.
        Map<String, Object> counts = new HashMap<>();
        counts.put("total", trainedModelConfigs.size());
        inferenceConfigCounts.forEach((configName, count) -> counts.put(configName, count.get()));
        counts.put("prepackaged", prepackagedCount);
        counts.put("other", trainedModelConfigs.size() - createdByAnalyticsCount - prepackagedCount);

        trainedModelsUsage.put("count", counts);
        trainedModelsUsage.put(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations.asMap());
        trainedModelsUsage.put(TrainedModelConfig.MODEL_SIZE_BYTES.getPreferredName(), estimatedMemoryUsageBytes.asMap());
        inferenceUsage.put("trained_models", trainedModelsUsage);
    }

    /**
     * Adds the {@code ingest_processors} section to {@code inferenceUsage}.
     *
     * <p>Across all returned model stats this sums the pipeline counts and accumulates, for every ingest
     * processor stat named {@code inference}, the ingest count, ingest time in milliseconds and failure
     * count. The result is stored under {@code ingest_processors -> _all}.
     *
     * @param statsResponse  the trained model stats
     * @param inferenceUsage the map the section is written to
     */
    private static void addInferenceIngestUsage(GetTrainedModelsStatsAction.Response statsResponse, Map<String, Object> inferenceUsage) {
        StatsAccumulator processedDocs = new StatsAccumulator();
        StatsAccumulator elapsedMillis = new StatsAccumulator();
        StatsAccumulator failures = new StatsAccumulator();
        int totalPipelines = 0;

        for (GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats : statsResponse.getResources().results()) {
            totalPipelines += modelStats.getPipelineCount();
            modelStats.getIngestStats()
                .processorStats()
                .values()
                .stream()
                .flatMap(Collection::stream)
                .filter(processorStat -> processorStat.name().equals("inference"))
                .map(processorStat -> processorStat.stats())
                .forEach(stats -> {
                    processedDocs.add((double)stats.ingestCount());
                    elapsedMillis.add((double)stats.ingestTimeInMillis());
                    failures.add((double)stats.ingestFailedCount());
                });
        }

        Map<String, Object> allProcessors = Maps.newMapWithExpectedSize(6);
        allProcessors.put("pipelines", MachineLearningUsageTransportAction.createCountUsageEntry(totalPipelines));
        allProcessors.put("num_docs_processed", MachineLearningUsageTransportAction.getMinMaxSumAsLongsFromStats(processedDocs));
        allProcessors.put("time_ms", MachineLearningUsageTransportAction.getMinMaxSumAsLongsFromStats(elapsedMillis));
        allProcessors.put("num_failures", MachineLearningUsageTransportAction.getMinMaxSumAsLongsFromStats(failures));
        inferenceUsage.put("ingest_processors", Collections.singletonMap("_all", allProcessors));
    }

    /**
     * Extracts the total, minimum and maximum of the accumulator, each truncated to a {@code long}.
     *
     * @param stats the accumulator to read
     * @return a new map with the keys {@code sum}, {@code min} and {@code max}
     */
    private static Map<String, Object> getMinMaxSumAsLongsFromStats(StatsAccumulator stats) {
        Map<String, Object> summary = Maps.newMapWithExpectedSize(3);

        double total = stats.getTotal();
        summary.put("sum", Long.valueOf((long)total));

        double min = stats.getMin();
        summary.put("min", Long.valueOf((long)min));

        double max = stats.getMax();
        summary.put("max", Long.valueOf((long)max));

        return summary;
    }

    /**
     * Counts the nodes in the cluster state that {@code MachineLearning.isMlNode} accepts.
     *
     * @param clusterState the cluster state whose nodes are inspected
     * @return the number of ML nodes
     */
    private static int mlNodeCount(ClusterState clusterState) {
        int count = 0;
        for (DiscoveryNode candidate : clusterState.getNodes()) {
            if (MachineLearning.isMlNode(candidate)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Per-model accumulator for deployment usage: gathers the per-node inference counts of one model
     * and remembers the most recent access time seen on any node.
     */
    private static class ModelStats {
        private final String modelId;
        private final String taskType;
        private final StatsAccumulator inferenceCounts = new StatsAccumulator();
        /** Latest access seen so far; stays {@code null} until a node reports one. */
        private Instant lastAccess;

        ModelStats(String modelId, String taskType) {
            this.modelId = modelId;
            this.taskType = taskType;
        }

        /**
         * Folds the stats of one node into this accumulator. A missing inference count is treated as 0;
         * the last access is only replaced by a later one.
         */
        void update(AssignmentStats.NodeStats stats) {
            long nodeInferences = stats.getInferenceCount().orElse(0L).longValue();
            this.inferenceCounts.add((double)nodeInferences);

            Instant nodeAccess = stats.getLastAccess();
            if (nodeAccess == null) {
                return;
            }
            if (this.lastAccess == null || nodeAccess.isAfter(this.lastAccess)) {
                this.lastAccess = nodeAccess;
            }
        }

        /**
         * Renders the accumulated values as a map with the keys {@code model_id}, {@code task_type} and
         * {@code inference_counts}, plus {@code last_access} (as a string) when an access was recorded.
         */
        Map<String, Object> asMap() {
            Map<String, Object> rendered = new HashMap<>();
            rendered.put("model_id", this.modelId);
            rendered.put("task_type", this.taskType);
            rendered.put("inference_counts", this.inferenceCounts.asMap());
            if (this.lastAccess != null) {
                rendered.put("last_access", this.lastAccess.toString());
            }
            return rendered;
        }
    }
}

