/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NullEquals;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.RegexExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.RowExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.ql.common.Failure;
import org.elasticsearch.xpack.ql.common.Failures;
import org.elasticsearch.xpack.ql.expression.AttributeSet;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.function.Function;
import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
import org.elasticsearch.xpack.ql.expression.predicate.Range;
import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
import org.elasticsearch.xpack.ql.plan.QueryPlan;
import org.elasticsearch.xpack.ql.plan.logical.Aggregate;
import org.elasticsearch.xpack.ql.plan.logical.EsRelation;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.util.CollectionUtils;

class OptimizerRules {
    /**
     * Not instantiable: this type only serves as a namespace for the nested optimizer rules and plan checks.
     */
    private OptimizerRules() {
        // intentionally empty
    }

    public static final class PropagateEquals
    extends OptimizerRules.OptimizerExpressionRule<BinaryLogic> {
        PropagateEquals() {
            super(OptimizerRules.TransformDirection.DOWN);
        }

        public Expression rule(BinaryLogic e) {
            if (e instanceof And) {
                return PropagateEquals.propagate((And)e);
            }
            if (e instanceof Or) {
                return PropagateEquals.propagate((Or)e);
            }
            return e;
        }

        private static Expression propagate(And and) {
            ArrayList<Range> ranges = new ArrayList<Range>();
            ArrayList<BinaryComparison> equals = new ArrayList<BinaryComparison>();
            ArrayList<NotEquals> notEquals = new ArrayList<NotEquals>();
            ArrayList<BinaryComparison> inequalities = new ArrayList<BinaryComparison>();
            ArrayList<Object> exps = new ArrayList<Object>();
            boolean changed = false;
            for (Expression ex : Predicates.splitAnd((Expression)and)) {
                if (ex instanceof Range) {
                    ranges.add((Range)ex);
                    continue;
                }
                if (ex instanceof Equals || ex instanceof NullEquals) {
                    BinaryComparison otherEq = (BinaryComparison)ex;
                    if (otherEq.right().foldable() && !DataTypes.isDateTime((DataType)otherEq.left().dataType())) {
                        for (BinaryComparison eq : equals) {
                            Integer comp;
                            if (!otherEq.left().semanticEquals(eq.left()) || (comp = BinaryComparison.compare((Object)eq.right().fold(), (Object)otherEq.right().fold())) == null || comp == 0) continue;
                            return new Literal(and.source(), (Object)Boolean.FALSE, DataTypes.BOOLEAN);
                        }
                        equals.add(otherEq);
                        continue;
                    }
                    exps.add(otherEq);
                    continue;
                }
                if (ex instanceof GreaterThan || ex instanceof GreaterThanOrEqual || ex instanceof LessThan || ex instanceof LessThanOrEqual) {
                    BinaryComparison bc = (BinaryComparison)ex;
                    if (bc.right().foldable()) {
                        inequalities.add(bc);
                        continue;
                    }
                    exps.add(ex);
                    continue;
                }
                if (ex instanceof NotEquals) {
                    NotEquals otherNotEq = (NotEquals)ex;
                    if (otherNotEq.right().foldable()) {
                        notEquals.add(otherNotEq);
                        continue;
                    }
                    exps.add(ex);
                    continue;
                }
                exps.add(ex);
            }
            for (BinaryComparison eq : equals) {
                Integer compare;
                Object eqValue = eq.right().fold();
                Iterator iterator = ranges.iterator();
                while (iterator.hasNext()) {
                    Range range = (Range)iterator.next();
                    if (!range.value().semanticEquals(eq.left())) continue;
                    if (range.lower().foldable() && (compare = BinaryComparison.compare((Object)range.lower().fold(), (Object)eqValue)) != null && (compare > 0 || compare == 0 && !range.includeLower())) {
                        return new Literal(and.source(), (Object)Boolean.FALSE, DataTypes.BOOLEAN);
                    }
                    if (range.upper().foldable() && (compare = BinaryComparison.compare((Object)range.upper().fold(), (Object)eqValue)) != null && (compare < 0 || compare == 0 && !range.includeUpper())) {
                        return new Literal(and.source(), (Object)Boolean.FALSE, DataTypes.BOOLEAN);
                    }
                    iterator.remove();
                    changed = true;
                }
                Iterator iter = notEquals.iterator();
                while (iter.hasNext()) {
                    Integer comp;
                    NotEquals neq = (NotEquals)iter.next();
                    if (!eq.left().semanticEquals(neq.left()) || (comp = BinaryComparison.compare((Object)eqValue, (Object)neq.right().fold())) == null) continue;
                    if (comp == 0) {
                        return new Literal(and.source(), (Object)Boolean.FALSE, DataTypes.BOOLEAN);
                    }
                    iter.remove();
                    changed = true;
                }
                iter = inequalities.iterator();
                while (iter.hasNext()) {
                    BinaryComparison bc = (BinaryComparison)iter.next();
                    if (!eq.left().semanticEquals(bc.left()) || (compare = BinaryComparison.compare((Object)eqValue, (Object)bc.right().fold())) == null) continue;
                    if (bc instanceof LessThan || bc instanceof LessThanOrEqual ? compare == 0 && bc instanceof LessThan || 0 < compare : (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) && (compare == 0 && bc instanceof GreaterThan || compare < 0)) {
                        return new Literal(and.source(), (Object)Boolean.FALSE, DataTypes.BOOLEAN);
                    }
                    iter.remove();
                    changed = true;
                }
            }
            return changed ? Predicates.combineAnd((List)CollectionUtils.combine((Collection[])new Collection[]{exps, equals, notEquals, inequalities, ranges})) : and;
        }

        private static Expression propagate(Or or) {
            Equals eq;
            ArrayList<Expression> exps = new ArrayList<Expression>();
            ArrayList<Equals> equals = new ArrayList<Equals>();
            ArrayList<NotEquals> notEquals = new ArrayList<NotEquals>();
            ArrayList<Range> ranges = new ArrayList<Range>();
            ArrayList<BinaryComparison> inequalities = new ArrayList<BinaryComparison>();
            for (Expression ex : Predicates.splitOr((Expression)or)) {
                if (ex instanceof Equals) {
                    eq = (Equals)ex;
                    if (eq.right().foldable()) {
                        equals.add(eq);
                        continue;
                    }
                    exps.add(ex);
                    continue;
                }
                if (ex instanceof NotEquals) {
                    NotEquals neq = (NotEquals)ex;
                    if (neq.right().foldable()) {
                        notEquals.add(neq);
                        continue;
                    }
                    exps.add(ex);
                    continue;
                }
                if (ex instanceof Range) {
                    ranges.add((Range)ex);
                    continue;
                }
                if (ex instanceof BinaryComparison) {
                    BinaryComparison bc = (BinaryComparison)ex;
                    if (bc.right().foldable()) {
                        inequalities.add(bc);
                        continue;
                    }
                    exps.add(ex);
                    continue;
                }
                exps.add(ex);
            }
            boolean updated = false;
            Iterator iterEq = equals.iterator();
            while (iterEq.hasNext()) {
                int i;
                Integer comp;
                eq = (Equals)iterEq.next();
                Object eqValue = eq.right().fold();
                boolean removeEquals = false;
                for (NotEquals neq : notEquals) {
                    if (!eq.left().semanticEquals(neq.left()) || (comp = BinaryComparison.compare((Object)eqValue, (Object)neq.right().fold())) == null) continue;
                    if (comp == 0) {
                        return Literal.TRUE;
                    }
                    removeEquals = true;
                    break;
                }
                if (removeEquals) {
                    iterEq.remove();
                    updated = true;
                    continue;
                }
                for (i = 0; i < ranges.size(); ++i) {
                    Integer upperComp;
                    Range range = (Range)ranges.get(i);
                    if (!eq.left().semanticEquals(range.value())) continue;
                    Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare((Object)eqValue, (Object)range.lower().fold()) : null;
                    Integer n = upperComp = range.upper().foldable() ? BinaryComparison.compare((Object)eqValue, (Object)range.upper().fold()) : null;
                    if (lowerComp != null && lowerComp == 0) {
                        if (!range.includeLower()) {
                            ranges.set(i, new Range(range.source(), range.value(), range.lower(), true, range.upper(), range.includeUpper(), range.zoneId()));
                        }
                        removeEquals = true;
                        break;
                    }
                    if (upperComp != null && upperComp == 0) {
                        if (!range.includeUpper()) {
                            ranges.set(i, new Range(range.source(), range.value(), range.lower(), range.includeLower(), range.upper(), true, range.zoneId()));
                        }
                        removeEquals = true;
                        break;
                    }
                    if (lowerComp == null || upperComp == null || 0 >= lowerComp || upperComp >= 0) continue;
                    removeEquals = true;
                    break;
                }
                if (removeEquals) {
                    iterEq.remove();
                    updated = true;
                    continue;
                }
                for (i = 0; i < inequalities.size(); ++i) {
                    BinaryComparison bc = (BinaryComparison)inequalities.get(i);
                    if (!eq.left().semanticEquals(bc.left()) || (comp = BinaryComparison.compare((Object)eqValue, (Object)bc.right().fold())) == null) continue;
                    if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
                        if (comp < 0) continue;
                        if (comp == 0 && bc instanceof GreaterThan) {
                            inequalities.set(i, new GreaterThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                        }
                        removeEquals = true;
                        break;
                    }
                    if (!(bc instanceof LessThan) && !(bc instanceof LessThanOrEqual) || comp > 0) continue;
                    if (comp == 0 && bc instanceof LessThan) {
                        inequalities.set(i, new LessThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                    }
                    removeEquals = true;
                    break;
                }
                if (!removeEquals) continue;
                iterEq.remove();
                updated = true;
            }
            return updated ? Predicates.combineOr((List)CollectionUtils.combine((Collection[])new Collection[]{exps, equals, notEquals, inequalities, ranges})) : or;
        }
    }

    public static final class BooleanFunctionEqualsElimination
    extends OptimizerRules.OptimizerExpressionRule<BinaryComparison> {
        BooleanFunctionEqualsElimination() {
            super(OptimizerRules.TransformDirection.UP);
        }

        protected Expression rule(BinaryComparison bc) {
            if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) {
                if (Literal.TRUE.equals((Object)bc.right())) {
                    return bc instanceof Equals ? bc.left() : new Not(bc.left().source(), bc.left());
                }
                if (Literal.FALSE.equals((Object)bc.right())) {
                    return bc instanceof Equals ? new Not(bc.left().source(), bc.left()) : bc.left();
                }
            }
            return bc;
        }
    }

    public static class CombineDisjunctionsToIn
    extends OptimizerRules.OptimizerExpressionRule<Or> {
        CombineDisjunctionsToIn() {
            super(OptimizerRules.TransformDirection.UP);
        }

        protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
            return new In(key.source(), key, values);
        }

        protected Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) {
            return new Equals(k.source(), k, v.iterator().next(), finalZoneId);
        }

        protected Expression rule(Or or) {
            Or e = or;
            List exps = Predicates.splitOr((Expression)e);
            LinkedHashMap<Expression, Set> found = new LinkedHashMap<Expression, Set>();
            ZoneId zoneId = null;
            LinkedList<Expression> ors = new LinkedList<Expression>();
            for (Expression exp : exps) {
                if (exp instanceof Equals) {
                    Equals eq = (Equals)exp;
                    if (eq.right().foldable()) {
                        found.computeIfAbsent(eq.left(), k -> new LinkedHashSet()).add(eq.right());
                    } else {
                        ors.add(exp);
                    }
                    if (zoneId != null) continue;
                    zoneId = eq.zoneId();
                    continue;
                }
                if (exp instanceof In) {
                    In in = (In)exp;
                    found.computeIfAbsent(in.value(), k -> new LinkedHashSet()).addAll(in.list());
                    if (zoneId != null) continue;
                    zoneId = in.zoneId();
                    continue;
                }
                ors.add(exp);
            }
            if (!found.isEmpty()) {
                ZoneId finalZoneId = zoneId;
                found.forEach((k, v) -> ors.add((Expression)(v.size() == 1 ? this.createEquals((Expression)k, (Set<Expression>)v, finalZoneId) : this.createIn((Expression)k, (List<Expression>)new ArrayList<Expression>((Collection<Expression>)v), finalZoneId))));
                Expression combineOr = Predicates.combineOr(ors);
                if (!e.semanticEquals(combineOr)) {
                    e = combineOr;
                }
            }
            return e;
        }
    }

    static class PhysicalPlanDependencyCheck
    extends DependencyConsistency<PhysicalPlan> {
        PhysicalPlanDependencyCheck() {
        }

        @Override
        protected AttributeSet generates(PhysicalPlan physicalPlan) {
            if (physicalPlan instanceof EsSourceExec || physicalPlan instanceof EsStatsQueryExec || physicalPlan instanceof EsQueryExec || physicalPlan instanceof LocalSourceExec || physicalPlan instanceof RowExec || physicalPlan instanceof ExchangeExec || physicalPlan instanceof ExchangeSourceExec || physicalPlan instanceof AggregateExec || physicalPlan instanceof ShowExec) {
                return physicalPlan.outputSet();
            }
            if (physicalPlan instanceof FieldExtractExec) {
                FieldExtractExec fieldExtractExec = (FieldExtractExec)physicalPlan;
                return new AttributeSet(fieldExtractExec.attributesToExtract());
            }
            if (physicalPlan instanceof EvalExec) {
                EvalExec eval = (EvalExec)physicalPlan;
                return new AttributeSet((Collection)Expressions.asAttributes(eval.fields()));
            }
            if (physicalPlan instanceof RegexExtractExec) {
                RegexExtractExec extract = (RegexExtractExec)physicalPlan;
                return new AttributeSet(extract.extractedFields());
            }
            if (physicalPlan instanceof MvExpandExec) {
                MvExpandExec mvExpand = (MvExpandExec)physicalPlan;
                return new AttributeSet(mvExpand.expanded());
            }
            if (physicalPlan instanceof EnrichExec) {
                EnrichExec enrich = (EnrichExec)physicalPlan;
                return new AttributeSet((Collection)Expressions.asAttributes(enrich.enrichFields()));
            }
            return AttributeSet.EMPTY;
        }

        @Override
        protected AttributeSet references(PhysicalPlan plan) {
            AggregateExec aggregate;
            if (plan instanceof AggregateExec && (aggregate = (AggregateExec)plan).getMode() == AggregateExec.Mode.FINAL) {
                return aggregate.inputSet();
            }
            return plan.references();
        }
    }

    /**
     * Dependency check for logical plans that knows which attributes each kind of logical node introduces itself.
     */
    static class LogicalPlanDependencyCheck extends DependencyConsistency<LogicalPlan> {

        /**
         * For {@link Enrich} only the references of its match field are considered; every other node uses the
         * default references.
         */
        @Override
        protected AttributeSet references(LogicalPlan plan) {
            if (plan instanceof Enrich == false) {
                return super.references(plan);
            }
            return ((Enrich) plan).matchField().references();
        }

        /**
         * Returns the attributes introduced by {@code plan} itself: the whole output for relations, rows and
         * aggregates, the newly computed/extracted attributes for eval/extract/expand/enrich nodes, and nothing
         * otherwise.
         */
        @Override
        protected AttributeSet generates(LogicalPlan plan) {
            boolean producesWholeOutput = plan instanceof EsRelation
                || plan instanceof LocalRelation
                || plan instanceof Row
                || plan instanceof Aggregate;
            if (producesWholeOutput) {
                return plan.outputSet();
            }
            if (plan instanceof Eval) {
                return new AttributeSet(Expressions.asAttributes(((Eval) plan).fields()));
            }
            if (plan instanceof RegexExtract) {
                return new AttributeSet(((RegexExtract) plan).extractedFields());
            }
            if (plan instanceof MvExpand) {
                return new AttributeSet(((MvExpand) plan).expanded());
            }
            if (plan instanceof Enrich) {
                return new AttributeSet(Expressions.asAttributes(((Enrich) plan).enrichFields()));
            }
            return AttributeSet.EMPTY;
        }
    }

    /**
     * Checks that every attribute a plan node references is either part of its input or generated by the node itself.
     * Subclasses refine what counts as referenced ({@link #references}) and generated ({@link #generates}).
     *
     * @param <P> the plan type being checked
     */
    static class DependencyConsistency<P extends QueryPlan<P>> {

        /**
         * Adds a failure to {@code failures} when {@code p} references attributes found neither in its input set nor
         * in the attributes it generates.
         */
        void checkPlan(P p, Failures failures) {
            AttributeSet missing = references(p).subtract(p.inputSet()).subtract(generates(p));
            if (missing.isEmpty() == false) {
                failures.add(Failure.fail(p, "Plan [{}] optimized incorrectly due to missing references {}", p.nodeString(), missing));
            }
        }

        /** Attributes referenced by {@code p}; defaults to the plan's own references. */
        protected AttributeSet references(P p) {
            return p.references();
        }

        /** Attributes introduced by {@code p} itself; none by default. */
        protected AttributeSet generates(P p) {
            return AttributeSet.EMPTY;
        }
    }
}

