/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.scalar.date;

import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.time.temporal.IsoFields;
import java.time.temporal.Temporal;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateDiffConstantEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateDiffEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTimeField;
import org.elasticsearch.xpack.ql.InvalidArgumentException;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.tree.Node;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypeConverter;
import org.elasticsearch.xpack.ql.type.DataTypes;

public class DateDiff
extends EsqlScalarFunction {
    public static final ZoneId UTC = ZoneId.of("Z");
    private final Expression unit;
    private final Expression startTimestamp;
    private final Expression endTimestamp;

    @FunctionInfo(returnType={"integer"}, description="Subtract 2 dates and return their difference in multiples of a unit specified in the 1st argument")
    public DateDiff(Source source, @Param(name="unit", type={"keyword", "text"}, description="A valid date unit") Expression unit, @Param(name="startTimestamp", type={"date"}, description="A string representing a start timestamp") Expression startTimestamp, @Param(name="endTimestamp", type={"date"}, description="A string representing an end timestamp") Expression endTimestamp) {
        super(source, List.of(unit, startTimestamp, endTimestamp));
        this.unit = unit;
        this.startTimestamp = startTimestamp;
        this.endTimestamp = endTimestamp;
    }

    static int process(Part datePartFieldUnit, long startTimestamp, long endTimestamp) throws IllegalArgumentException {
        ZonedDateTime zdtStart = ZonedDateTime.ofInstant(Instant.ofEpochMilli(startTimestamp), UTC);
        ZonedDateTime zdtEnd = ZonedDateTime.ofInstant(Instant.ofEpochMilli(endTimestamp), UTC);
        return datePartFieldUnit.diff(zdtStart, zdtEnd);
    }

    static int process(BytesRef unit, long startTimestamp, long endTimestamp) throws IllegalArgumentException {
        return DateDiff.process(Part.resolve(unit.utf8ToString()), startTimestamp, endTimestamp);
    }

    @Override
    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(Function<Expression, EvalOperator.ExpressionEvaluator.Factory> toEvaluator) {
        EvalOperator.ExpressionEvaluator.Factory startTimestampEvaluator = toEvaluator.apply(this.startTimestamp);
        EvalOperator.ExpressionEvaluator.Factory endTimestampEvaluator = toEvaluator.apply(this.endTimestamp);
        if (this.unit.foldable()) {
            try {
                Part datePartField = Part.resolve(((BytesRef)this.unit.fold()).utf8ToString());
                return new DateDiffConstantEvaluator.Factory(this.source(), datePartField, startTimestampEvaluator, endTimestampEvaluator);
            }
            catch (IllegalArgumentException e) {
                throw new InvalidArgumentException("invalid unit format for [{}]: {}", new Object[]{this.sourceText(), e.getMessage()});
            }
        }
        EvalOperator.ExpressionEvaluator.Factory unitEvaluator = toEvaluator.apply(this.unit);
        return new DateDiffEvaluator.Factory(this.source(), unitEvaluator, startTimestampEvaluator, endTimestampEvaluator);
    }

    protected Expression.TypeResolution resolveType() {
        if (!this.childrenResolved()) {
            return new Expression.TypeResolution("Unresolved children");
        }
        Expression.TypeResolution resolution = TypeResolutions.isString((Expression)this.unit, (String)this.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.FIRST).and(TypeResolutions.isDate((Expression)this.startTimestamp, (String)this.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.SECOND)).and(TypeResolutions.isDate((Expression)this.endTimestamp, (String)this.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.THIRD));
        if (resolution.unresolved()) {
            return resolution;
        }
        return Expression.TypeResolution.TYPE_RESOLVED;
    }

    public boolean foldable() {
        return this.unit.foldable() && this.startTimestamp.foldable() && this.endTimestamp.foldable();
    }

    public DataType dataType() {
        return DataTypes.INTEGER;
    }

    public Expression replaceChildren(List<Expression> newChildren) {
        return new DateDiff(this.source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
    }

    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create((Node)this, DateDiff::new, (Object)((Expression)this.children().get(0)), (Object)((Expression)this.children().get(1)), (Object)((Expression)this.children().get(2)));
    }

    public static enum Part implements DateTimeField
    {
        YEAR((start, end) -> end.getYear() - start.getYear(), "years", "yyyy", "yy"),
        QUARTER((start, end) -> DataTypeConverter.safeToInt((long)IsoFields.QUARTER_YEARS.between((Temporal)start, (Temporal)end)), "quarters", "qq", "q"),
        MONTH((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.MONTHS.between((Temporal)start, (Temporal)end)), "months", "mm", "m"),
        DAYOFYEAR((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.DAYS.between((Temporal)start, (Temporal)end)), "dy", "y"),
        DAY(DAYOFYEAR::diff, "days", "dd", "d"),
        WEEK((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.WEEKS.between((Temporal)start, (Temporal)end)), "weeks", "wk", "ww"),
        WEEKDAY(DAYOFYEAR::diff, "weekdays", "dw"),
        HOUR((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.HOURS.between((Temporal)start, (Temporal)end)), "hours", "hh"),
        MINUTE((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.MINUTES.between((Temporal)start, (Temporal)end)), "minutes", "mi", "n"),
        SECOND((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.SECONDS.between((Temporal)start, (Temporal)end)), "seconds", "ss", "s"),
        MILLISECOND((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.MILLIS.between((Temporal)start, (Temporal)end)), "milliseconds", "ms"),
        MICROSECOND((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.MICROS.between((Temporal)start, (Temporal)end)), "microseconds", "mcs"),
        NANOSECOND((start, end) -> DataTypeConverter.safeToInt((long)ChronoUnit.NANOS.between((Temporal)start, (Temporal)end)), "nanoseconds", "ns");

        private static final Map<String, Part> NAME_TO_PART;
        private final BiFunction<ZonedDateTime, ZonedDateTime, Integer> diffFunction;
        private final Set<String> aliases;

        private Part(BiFunction<ZonedDateTime, ZonedDateTime, Integer> diffFunction, String ... aliases) {
            this.diffFunction = diffFunction;
            this.aliases = Set.of(aliases);
        }

        public Integer diff(ZonedDateTime startTimestamp, ZonedDateTime endTimestamp) {
            return this.diffFunction.apply(startTimestamp, endTimestamp);
        }

        @Override
        public Iterable<String> aliases() {
            return this.aliases;
        }

        public static Part resolve(String dateTimeUnit) {
            Part datePartField = DateTimeField.resolveMatch(NAME_TO_PART, dateTimeUnit);
            if (datePartField == null) {
                List<String> similar = DateTimeField.findSimilar(NAME_TO_PART.keySet(), dateTimeUnit);
                String errorMessage = !similar.isEmpty() ? String.format(Locale.ROOT, "Received value [%s] is not valid date part to add; did you mean %s?", dateTimeUnit, similar) : String.format(Locale.ROOT, "A value of %s or their aliases is required; received [%s]", Arrays.asList(Part.values()), dateTimeUnit);
                throw new IllegalArgumentException(errorMessage);
            }
            return datePartField;
        }

        static {
            NAME_TO_PART = DateTimeField.initializeResolutionMap((DateTimeField[])Part.values());
        }
    }
}

