/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.autoscaling;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.autoscaling.MlAutoscalingStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingContext;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlProcessors;
import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;

public final class MlAutoscalingResourceTracker {
    private static final Logger logger = LogManager.getLogger(MlAutoscalingResourceTracker.class);

    private MlAutoscalingResourceTracker() {
    }

    public static void getMlAutoscalingStats(ClusterState clusterState, ClusterSettings clusterSettings, MlMemoryTracker mlMemoryTracker, Settings settings, ActionListener<MlAutoscalingStats> listener) {
        Map<String, Long> nodeSizeByMlNode = clusterState.nodes().stream().filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE)).collect(Collectors.toMap(DiscoveryNode::getId, node -> NodeLoadDetector.getNodeSize(node).orElse(0L)));
        String firstMlNode = nodeSizeByMlNode.size() > 0 ? nodeSizeByMlNode.keySet().iterator().next() : null;
        long modelMemoryAvailableFirstNode = firstMlNode != null ? NativeMemoryCalculator.allowedBytesForMl(clusterState.nodes().get(firstMlNode), settings).orElse(0L) : 0L;
        int processorsAvailableFirstNode = firstMlNode != null ? MlProcessors.get(clusterState.nodes().get(firstMlNode), (Integer)clusterSettings.get(MachineLearning.ALLOCATED_PROCESSORS_SCALE)).roundUp() : 0;
        MlDummyAutoscalingEntity mlDummyAutoscalingEntity = new MlDummyAutoscalingEntity(Math.max(0L, ((ByteSizeValue)MachineLearning.DUMMY_ENTITY_MEMORY.get(settings)).getBytes()), (Integer)MachineLearning.DUMMY_ENTITY_PROCESSORS.get(settings));
        int maxOpenJobsPerNode = (Integer)MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
        MlAutoscalingResourceTracker.getMemoryAndProcessors(new MlAutoscalingContext(clusterState), mlMemoryTracker, nodeSizeByMlNode, modelMemoryAvailableFirstNode, processorsAvailableFirstNode, maxOpenJobsPerNode, mlDummyAutoscalingEntity, listener);
    }

    static void getMemoryAndProcessors(MlAutoscalingContext autoscalingContext, MlMemoryTracker mlMemoryTracker, Map<String, Long> nodeSizeByMlNode, long perNodeAvailableModelMemoryInBytes, int perNodeAvailableProcessors, int maxOpenJobsPerNode, MlDummyAutoscalingEntity dummyAutoscalingEntity, ActionListener<MlAutoscalingStats> listener) {
        Long jobMemory;
        String jobId;
        MemoryTrackedTaskState state;
        HashMap<String, List<MlJobRequirements>> perNodeModelMemoryInBytes = new HashMap<String, List<MlJobRequirements>>();
        int numberMlNodes = nodeSizeByMlNode.size();
        long perNodeMemoryInBytes = nodeSizeByMlNode.values().stream().distinct().count() != 1L ? 0L : nodeSizeByMlNode.values().iterator().next();
        long modelMemoryBytesSum = 0L;
        long extraSingleNodeModelMemoryInBytes = 0L;
        long extraModelMemoryInBytes = 0L;
        int extraSingleNodeProcessors = 0;
        int extraProcessors = 0;
        int processorsSum = 0;
        logger.debug("getting ml resources, found [{}] ad jobs, [{}] dfa jobs and [{}] inference deployments", (Object)autoscalingContext.anomalyDetectionTasks.size(), (Object)autoscalingContext.dataframeAnalyticsTasks.size(), (Object)autoscalingContext.modelAssignments.size());
        int minNodes = 0;
        for (PersistentTasksCustomMetadata.PersistentTask<?> persistentTask : autoscalingContext.anomalyDetectionTasks) {
            state = MlTasks.getMemoryTrackedTaskState(persistentTask);
            if (state != null && !state.consumesMemory()) continue;
            jobId = ((OpenJobAction.JobParams)persistentTask.getParams()).getJobId();
            jobMemory = mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId);
            if (jobMemory == null) {
                logger.debug("could not find memory requirement for job [{}], returning no-scale", (Object)jobId);
                listener.onResponse((Object)MlAutoscalingResourceTracker.noScaleStats(numberMlNodes));
                return;
            }
            minNodes = 1;
            if (JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.equals((Object)persistentTask.getAssignment())) {
                logger.debug("job [{}] lacks assignment , memory required [{}]", (Object)jobId, (Object)jobMemory);
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, jobMemory);
                extraModelMemoryInBytes += jobMemory.longValue();
                continue;
            }
            logger.debug("job [{}] assigned to [{}], memory required [{}]", (Object)jobId, (Object)persistentTask.getAssignment(), (Object)jobMemory);
            modelMemoryBytesSum += jobMemory.longValue();
            perNodeModelMemoryInBytes.computeIfAbsent(persistentTask.getExecutorNode(), k -> new ArrayList()).add(MlJobRequirements.of(jobMemory, 0));
        }
        for (PersistentTasksCustomMetadata.PersistentTask<?> persistentTask : autoscalingContext.dataframeAnalyticsTasks) {
            state = MlTasks.getMemoryTrackedTaskState(persistentTask);
            if (state != null && !state.consumesMemory()) continue;
            jobId = MlTasks.dataFrameAnalyticsId((String)persistentTask.getId());
            jobMemory = mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(jobId);
            if (jobMemory == null) {
                logger.debug("could not find memory requirement for job [{}], returning no-scale", (Object)jobId);
                listener.onResponse((Object)MlAutoscalingResourceTracker.noScaleStats(numberMlNodes));
                return;
            }
            minNodes = 1;
            if (JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.equals((Object)persistentTask.getAssignment())) {
                logger.debug("dfa job [{}] lacks assignment , memory required [{}]", (Object)jobId, (Object)jobMemory);
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, jobMemory);
                extraModelMemoryInBytes += jobMemory.longValue();
                continue;
            }
            logger.debug("dfa job [{}] assigned to [{}], memory required [{}]", (Object)jobId, (Object)persistentTask.getAssignment(), (Object)jobMemory);
            modelMemoryBytesSum += jobMemory.longValue();
            perNodeModelMemoryInBytes.computeIfAbsent(persistentTask.getExecutorNode(), k -> new ArrayList()).add(MlJobRequirements.of(jobMemory, 0));
        }
        for (Map.Entry entry : autoscalingContext.modelAssignments.entrySet()) {
            TrainedModelAssignment assignment = (TrainedModelAssignment)entry.getValue();
            int numberOfAllocations = assignment.getTaskParams().getNumberOfAllocations();
            int numberOfThreadsPerAllocation = assignment.getTaskParams().getThreadsPerAllocation();
            long estimatedMemoryUsage = assignment.getTaskParams().estimateMemoryUsageBytes();
            if (AssignmentState.STARTING.equals((Object)assignment.getAssignmentState()) && assignment.getNodeRoutingTable().isEmpty()) {
                logger.debug(() -> Strings.format((String)"trained model [%s] lacks assignment , memory required [%d]", (Object[])new Object[]{modelAssignment.getKey(), estimatedMemoryUsage}));
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, estimatedMemoryUsage);
                extraModelMemoryInBytes += estimatedMemoryUsage;
                if (!Priority.LOW.equals((Object)((TrainedModelAssignment)entry.getValue()).getTaskParams().getPriority())) {
                    extraSingleNodeProcessors = Math.max(extraSingleNodeProcessors, numberOfThreadsPerAllocation);
                    extraProcessors += numberOfAllocations * numberOfThreadsPerAllocation;
                }
            } else {
                if (assignment.getNodeRoutingTable().values().stream().allMatch(r -> !r.getState().consumesMemory())) continue;
                logger.debug(() -> Strings.format((String)"trained model [%s] assigned to [%s], memory required [%d]", (Object[])new Object[]{modelAssignment.getKey(), org.elasticsearch.common.Strings.arrayToCommaDelimitedString((Object[])((TrainedModelAssignment)modelAssignment.getValue()).getStartedNodes()), estimatedMemoryUsage}));
                modelMemoryBytesSum += estimatedMemoryUsage;
                processorsSum += numberOfAllocations * numberOfThreadsPerAllocation;
                for (String node : ((TrainedModelAssignment)entry.getValue()).getNodeRoutingTable().keySet()) {
                    perNodeModelMemoryInBytes.computeIfAbsent(node, k -> new ArrayList()).add(MlJobRequirements.of(estimatedMemoryUsage, Priority.LOW.equals((Object)((TrainedModelAssignment)entry.getValue()).getTaskParams().getPriority()) ? 0 : numberOfThreadsPerAllocation));
                }
            }
            minNodes = Math.min(3, Math.max(minNodes, numberOfAllocations));
        }
        if (!MlAutoscalingResourceTracker.dummyEntityFitsOnLeastLoadedNode(perNodeModelMemoryInBytes, perNodeAvailableModelMemoryInBytes, perNodeAvailableProcessors, dummyAutoscalingEntity)) {
            logger.info("Scaling up due to dummy entity: dummyEntityMemory: [{}], dummyEntityProcessors: [{}]", (Object)dummyAutoscalingEntity.memory, (Object)dummyAutoscalingEntity.processors);
            modelMemoryBytesSum += dummyAutoscalingEntity.memory;
            processorsSum += dummyAutoscalingEntity.processors;
        }
        long removeNodeMemoryInBytes = 0L;
        if (perNodeMemoryInBytes > 0L && perNodeAvailableModelMemoryInBytes > 0L && extraModelMemoryInBytes == 0L && extraProcessors == 0 && modelMemoryBytesSum <= perNodeMemoryInBytes * (long)(numberMlNodes - 1) && minNodes < numberMlNodes && (perNodeModelMemoryInBytes.size() < numberMlNodes || MlAutoscalingResourceTracker.checkIfOneNodeCouldBeRemoved(perNodeModelMemoryInBytes, perNodeAvailableModelMemoryInBytes, perNodeAvailableProcessors, maxOpenJobsPerNode, dummyAutoscalingEntity))) {
            removeNodeMemoryInBytes = perNodeMemoryInBytes;
        }
        listener.onResponse((Object)new MlAutoscalingStats(numberMlNodes, perNodeMemoryInBytes, modelMemoryBytesSum, processorsSum, minNodes, extraSingleNodeModelMemoryInBytes, extraSingleNodeProcessors, extraModelMemoryInBytes, extraProcessors, removeNodeMemoryInBytes, MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes()));
    }

    static boolean dummyEntityFitsOnLeastLoadedNode(Map<String, List<MlJobRequirements>> perNodeJobRequirements, long perNodeMemoryInBytes, int perNodeProcessors, MlDummyAutoscalingEntity dummyAutoscalingEntity) {
        if (dummyAutoscalingEntity.processors == 0 && dummyAutoscalingEntity.memory == 0L) {
            return true;
        }
        if (perNodeJobRequirements.size() < 1) {
            return false;
        }
        Optional<MlJobRequirements> leastLoadedNodeRequirements = perNodeJobRequirements.values().stream().map(value -> value.stream().reduce(MlJobRequirements.of(0L, 0, 0), (subtotal, element) -> MlJobRequirements.of(subtotal.memory + element.memory, subtotal.processors + element.processors, subtotal.jobs + element.jobs))).min(Comparator.comparingLong(value -> value.memory));
        assert (leastLoadedNodeRequirements.isPresent());
        assert (leastLoadedNodeRequirements.get().memory >= 0L);
        assert (leastLoadedNodeRequirements.get().processors >= 0);
        if (leastLoadedNodeRequirements.get().memory + dummyAutoscalingEntity.memory > perNodeMemoryInBytes) {
            return false;
        }
        return leastLoadedNodeRequirements.get().processors + dummyAutoscalingEntity.processors <= perNodeProcessors;
    }

    public static MlAutoscalingStats noScaleStats(ClusterState clusterState) {
        int numberMlNodes = (int)clusterState.nodes().stream().filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE)).count();
        return MlAutoscalingResourceTracker.noScaleStats(numberMlNodes);
    }

    private static MlAutoscalingStats noScaleStats(int numberMlNodes) {
        return new MlAutoscalingStats(numberMlNodes, 0L, 0L, 0, Math.min(3, numberMlNodes), 0L, 0, 0L, 0, 0L, MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes());
    }

    static boolean checkIfOneNodeCouldBeRemoved(Map<String, List<MlJobRequirements>> perNodeJobRequirements, long perNodeMemoryInBytes, int perNodeProcessors, int maxOpenJobsPerNode, MlDummyAutoscalingEntity dummyAutoscalingEntity) {
        if (perNodeJobRequirements.size() <= 1) {
            return false;
        }
        Map<String, MlJobRequirements> perNodeMlJobRequirementSum = perNodeJobRequirements.entrySet().stream().map(entry -> Tuple.tuple((Object)((String)entry.getKey()), (Object)((List)entry.getValue()).stream().reduce(MlJobRequirements.of(0L, 0, 0), (subtotal, element) -> MlJobRequirements.of(subtotal.memory + element.memory, subtotal.processors + element.processors, subtotal.jobs + element.jobs)))).collect(Collectors.toMap(Tuple::v1, Tuple::v2));
        Optional<Map.Entry> leastLoadedNodeAndMemoryUsage = perNodeMlJobRequirementSum.entrySet().stream().min(Comparator.comparingLong(entry -> ((MlJobRequirements)entry.getValue()).memory));
        if (!leastLoadedNodeAndMemoryUsage.isPresent()) {
            return false;
        }
        assert (((MlJobRequirements)leastLoadedNodeAndMemoryUsage.get().getValue()).memory >= 0L);
        String candidateNode = (String)leastLoadedNodeAndMemoryUsage.get().getKey();
        List<MlJobRequirements> candidateJobRequirements = perNodeJobRequirements.get(candidateNode);
        if (dummyAutoscalingEntity.memory > 0L || dummyAutoscalingEntity.processors > 0) {
            candidateJobRequirements = new ArrayList<MlJobRequirements>(candidateJobRequirements);
            candidateJobRequirements.add(MlJobRequirements.of(dummyAutoscalingEntity.memory, dummyAutoscalingEntity.processors));
        }
        perNodeMlJobRequirementSum.remove(candidateNode);
        return MlAutoscalingResourceTracker.checkIfJobsCanBeMovedInLeastEfficientWay(candidateJobRequirements, perNodeMlJobRequirementSum, perNodeMemoryInBytes, perNodeProcessors, maxOpenJobsPerNode) == 0L;
    }

    static long checkIfJobsCanBeMovedInLeastEfficientWay(List<MlJobRequirements> candidateJobRequirements, Map<String, MlJobRequirements> perNodeMlJobRequirementsSum, long perNodeMemoryInBytes, int perNodeProcessors, int maxOpenJobsPerNode) {
        if (candidateJobRequirements.size() == 0) {
            return 0L;
        }
        List<MlJobRequirements> candidateNodeMemoryListSorted = candidateJobRequirements.stream().sorted(Comparator.comparingLong(MlJobRequirements::memory)).toList();
        long candidateNodeMemorySum = candidateJobRequirements.stream().mapToLong(MlJobRequirements::memory).sum();
        if (perNodeMlJobRequirementsSum.size() == 0) {
            return candidateNodeMemorySum;
        }
        PriorityQueue nodesWithSpareCapacitySortedByMemory = perNodeMlJobRequirementsSum.values().stream().filter(e -> e.jobs < maxOpenJobsPerNode).collect(Collectors.toCollection(() -> new PriorityQueue(perNodeMlJobRequirementsSum.size(), (c1, c2) -> {
            if (c1.memory == c2.memory) {
                return Integer.compare(c1.processors, c2.processors);
            }
            return Long.compare(c1.memory, c2.memory);
        })));
        for (MlJobRequirements jobRequirement : candidateNodeMemoryListSorted) {
            assert (jobRequirement.jobs == 1);
            if (jobRequirement.processors == 0) {
                MlJobRequirements nodeWithSpareCapacity = (MlJobRequirements)nodesWithSpareCapacitySortedByMemory.poll();
                long memoryAfterAddingJobMemory = nodeWithSpareCapacity.memory + jobRequirement.memory;
                if (memoryAfterAddingJobMemory > perNodeMemoryInBytes) break;
                if (nodeWithSpareCapacity.jobs + jobRequirement.jobs < maxOpenJobsPerNode) {
                    nodesWithSpareCapacitySortedByMemory.add(MlJobRequirements.of(memoryAfterAddingJobMemory, nodeWithSpareCapacity.processors, nodeWithSpareCapacity.jobs + jobRequirement.jobs));
                }
                candidateNodeMemorySum -= jobRequirement.memory;
            } else {
                ArrayList<MlJobRequirements> stash = new ArrayList<MlJobRequirements>();
                boolean foundNodeThatCanTakeTheJob = false;
                while (!nodesWithSpareCapacitySortedByMemory.isEmpty()) {
                    MlJobRequirements nodeWithSpareCapacity = (MlJobRequirements)nodesWithSpareCapacitySortedByMemory.poll();
                    long memoryAfterAddingJobMemory = nodeWithSpareCapacity.memory + jobRequirement.memory;
                    if (memoryAfterAddingJobMemory > perNodeMemoryInBytes) break;
                    if (nodeWithSpareCapacity.processors + jobRequirement.processors <= perNodeProcessors) {
                        if (nodeWithSpareCapacity.jobs + jobRequirement.jobs < maxOpenJobsPerNode) {
                            nodesWithSpareCapacitySortedByMemory.add(MlJobRequirements.of(memoryAfterAddingJobMemory, nodeWithSpareCapacity.processors + jobRequirement.processors, nodeWithSpareCapacity.jobs + jobRequirement.jobs));
                        }
                        candidateNodeMemorySum -= jobRequirement.memory;
                        foundNodeThatCanTakeTheJob = true;
                        break;
                    }
                    stash.add(nodeWithSpareCapacity);
                }
                if (!foundNodeThatCanTakeTheJob) break;
                nodesWithSpareCapacitySortedByMemory.addAll(stash);
            }
            if (!nodesWithSpareCapacitySortedByMemory.isEmpty()) continue;
            break;
        }
        return candidateNodeMemorySum;
    }

    record MlDummyAutoscalingEntity(long memory, int processors) {
        static MlDummyAutoscalingEntity of(long memory, int processors) {
            return new MlDummyAutoscalingEntity(memory, processors);
        }
    }

    record MlJobRequirements(long memory, int processors, int jobs) {
        static MlJobRequirements of(long memory, int processors, int jobs) {
            return new MlJobRequirements(memory, processors, jobs);
        }

        static MlJobRequirements of(long memory, int processors) {
            return new MlJobRequirements(memory, processors, 1);
        }
    }
}

