/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.common;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.common.WordBoundaryChunker;

public class EmbeddingRequestChunker {
    public static final int DEFAULT_WORDS_PER_CHUNK = 250;
    public static final int DEFAULT_CHUNK_OVERLAP = 100;
    private final List<BatchRequest> batchedRequests = new ArrayList<BatchRequest>();
    private final AtomicInteger resultCount = new AtomicInteger();
    private final int maxNumberOfInputsPerBatch;
    private final int wordsPerChunk;
    private final int chunkOverlap;
    private List<List<String>> chunkedInputs;
    private List<AtomicArray<List<TextEmbeddingResults.Embedding>>> results;
    private AtomicArray<ErrorChunkedInferenceResults> errors;
    private ActionListener<List<ChunkedInferenceServiceResults>> finalListener;

    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch) {
        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
        this.wordsPerChunk = 250;
        this.chunkOverlap = 100;
        this.splitIntoBatchedRequests(inputs);
    }

    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) {
        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
        this.wordsPerChunk = wordsPerChunk;
        this.chunkOverlap = chunkOverlap;
        this.splitIntoBatchedRequests(inputs);
    }

    private void splitIntoBatchedRequests(List<String> inputs) {
        WordBoundaryChunker chunker = new WordBoundaryChunker();
        this.chunkedInputs = new ArrayList<List<String>>(inputs.size());
        this.results = new ArrayList<AtomicArray<List<TextEmbeddingResults.Embedding>>>(inputs.size());
        this.errors = new AtomicArray(inputs.size());
        for (int i = 0; i < inputs.size(); ++i) {
            List<String> chunks = chunker.chunk(inputs.get(i), this.wordsPerChunk, this.chunkOverlap);
            int numberOfSubBatches = this.addToBatches(chunks, i);
            this.results.add((AtomicArray<List<TextEmbeddingResults.Embedding>>)new AtomicArray(numberOfSubBatches));
            this.chunkedInputs.add(chunks);
        }
    }

    private int addToBatches(List<String> chunks, int inputIndex) {
        int toAdd;
        BatchRequest lastBatch;
        if (this.batchedRequests.isEmpty()) {
            lastBatch = new BatchRequest(new ArrayList<SubBatch>());
            this.batchedRequests.add(lastBatch);
        } else {
            lastBatch = this.batchedRequests.get(this.batchedRequests.size() - 1);
        }
        int freeSpace = this.maxNumberOfInputsPerBatch - lastBatch.size();
        assert (freeSpace >= 0);
        int chunkIndex = 0;
        if (freeSpace > 0) {
            int toAdd2 = Math.min(freeSpace, chunks.size());
            lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd2), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd2)));
        }
        for (int start = freeSpace; start < chunks.size(); start += toAdd) {
            toAdd = Math.min(this.maxNumberOfInputsPerBatch, chunks.size() - start);
            BatchRequest batch = new BatchRequest(new ArrayList<SubBatch>());
            batch.addSubBatch(new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
            this.batchedRequests.add(batch);
        }
        return chunkIndex;
    }

    public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<List<ChunkedInferenceServiceResults>> finalListener) {
        this.finalListener = finalListener;
        int numberOfRequests = this.batchedRequests.size();
        ArrayList<BatchRequestAndListener> requests = new ArrayList<BatchRequestAndListener>(numberOfRequests);
        for (BatchRequest batch : this.batchedRequests) {
            requests.add(new BatchRequestAndListener(batch, new DebatchingListener(batch.subBatches().stream().map(SubBatch::positions).collect(Collectors.toList()), numberOfRequests)));
        }
        return requests;
    }

    public record BatchRequest(List<SubBatch> subBatches) {
        public int size() {
            return this.subBatches.stream().mapToInt(SubBatch::size).sum();
        }

        public void addSubBatch(SubBatch sb) {
            this.subBatches.add(sb);
        }

        public List<String> inputs() {
            return this.subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
        }
    }

    record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
        public int size() {
            return this.requests.size();
        }
    }

    record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {
    }

    public record BatchRequestAndListener(BatchRequest batch, ActionListener<InferenceServiceResults> listener) {
    }

    private class DebatchingListener
    implements ActionListener<InferenceServiceResults> {
        private final List<SubBatchPositionsAndCount> positions;
        private final int totalNumberOfRequests;

        DebatchingListener(List<SubBatchPositionsAndCount> positions, int totalNumberOfRequests) {
            this.positions = positions;
            this.totalNumberOfRequests = totalNumberOfRequests;
        }

        public void onResponse(InferenceServiceResults inferenceServiceResults) {
            if (inferenceServiceResults instanceof TextEmbeddingResults) {
                TextEmbeddingResults textEmbeddingResults = (TextEmbeddingResults)inferenceServiceResults;
                int numRequests = this.positions.stream().mapToInt(SubBatchPositionsAndCount::embeddingCount).sum();
                if (numRequests != textEmbeddingResults.embeddings().size()) {
                    this.onFailure((Exception)new ElasticsearchStatusException("Error the number of embedding responses [{}] does not equal the number of requests [{}]", RestStatus.BAD_REQUEST, new Object[]{textEmbeddingResults.embeddings().size(), numRequests}));
                    return;
                }
                int start = 0;
                for (SubBatchPositionsAndCount pos : this.positions) {
                    EmbeddingRequestChunker.this.results.get(pos.inputIndex()).setOnce(pos.chunkIndex(), textEmbeddingResults.embeddings().subList(start, start + pos.embeddingCount()));
                    start += pos.embeddingCount();
                }
            }
            if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                this.sendResponse();
            }
        }

        public void onFailure(Exception e) {
            ErrorChunkedInferenceResults errorResult = new ErrorChunkedInferenceResults(e);
            for (SubBatchPositionsAndCount pos : this.positions) {
                EmbeddingRequestChunker.this.errors.setOnce(pos.inputIndex(), (Object)errorResult);
            }
            if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                this.sendResponse();
            }
        }

        private void sendResponse() {
            ArrayList<Object> response = new ArrayList<Object>(EmbeddingRequestChunker.this.chunkedInputs.size());
            for (int i = 0; i < EmbeddingRequestChunker.this.chunkedInputs.size(); ++i) {
                if (EmbeddingRequestChunker.this.errors.get(i) != null) {
                    response.add((ChunkedInferenceServiceResults)EmbeddingRequestChunker.this.errors.get(i));
                    continue;
                }
                response.add(this.merge(EmbeddingRequestChunker.this.chunkedInputs.get(i), EmbeddingRequestChunker.this.results.get(i)));
            }
            EmbeddingRequestChunker.this.finalListener.onResponse(response);
        }

        private ChunkedTextEmbeddingFloatResults merge(List<String> chunks, AtomicArray<List<TextEmbeddingResults.Embedding>> debatchedResults) {
            ArrayList all = new ArrayList();
            for (int i = 0; i < debatchedResults.length(); ++i) {
                List subBatch = (List)debatchedResults.get(i);
                all.addAll(subBatch);
            }
            assert (chunks.size() == all.size());
            ArrayList<ChunkedTextEmbeddingFloatResults.EmbeddingChunk> embeddingChunks = new ArrayList<ChunkedTextEmbeddingFloatResults.EmbeddingChunk>();
            for (int i = 0; i < chunks.size(); ++i) {
                embeddingChunks.add(new ChunkedTextEmbeddingFloatResults.EmbeddingChunk(chunks.get(i), ((TextEmbeddingResults.Embedding)all.get(i)).values()));
            }
            return new ChunkedTextEmbeddingFloatResults(embeddingChunks);
        }
    }
}

