/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.gt;

import java.util.ArrayList;
import java.util.List;

import org.geotools.api.filter.And;
import org.geotools.api.filter.BinaryComparisonOperator;
import org.geotools.api.filter.BinaryLogicOperator;
import org.geotools.api.filter.ExcludeFilter;
import org.geotools.api.filter.Filter;
import org.geotools.api.filter.FilterVisitor;
import org.geotools.api.filter.Id;
import org.geotools.api.filter.IncludeFilter;
import org.geotools.api.filter.Not;
import org.geotools.api.filter.Or;
import org.geotools.api.filter.PropertyIsBetween;
import org.geotools.api.filter.PropertyIsEqualTo;
import org.geotools.api.filter.PropertyIsGreaterThan;
import org.geotools.api.filter.PropertyIsGreaterThanOrEqualTo;
import org.geotools.api.filter.PropertyIsLessThan;
import org.geotools.api.filter.PropertyIsLessThanOrEqualTo;
import org.geotools.api.filter.PropertyIsLike;
import org.geotools.api.filter.PropertyIsNil;
import org.geotools.api.filter.PropertyIsNotEqualTo;
import org.geotools.api.filter.PropertyIsNull;
import org.geotools.api.filter.expression.Add;
import org.geotools.api.filter.expression.BinaryExpression;
import org.geotools.api.filter.expression.Divide;
import org.geotools.api.filter.expression.Expression;
import org.geotools.api.filter.expression.ExpressionVisitor;
import org.geotools.api.filter.expression.Function;
import org.geotools.api.filter.expression.Literal;
import org.geotools.api.filter.expression.Multiply;
import org.geotools.api.filter.expression.NilExpression;
import org.geotools.api.filter.expression.PropertyName;
import org.geotools.api.filter.expression.Subtract;
import org.geotools.api.filter.spatial.BBOX;
import org.geotools.api.filter.spatial.Beyond;
import org.geotools.api.filter.spatial.BinarySpatialOperator;
import org.geotools.api.filter.spatial.Contains;
import org.geotools.api.filter.spatial.Crosses;
import org.geotools.api.filter.spatial.DWithin;
import org.geotools.api.filter.spatial.Disjoint;
import org.geotools.api.filter.spatial.Equals;
import org.geotools.api.filter.spatial.Intersects;
import org.geotools.api.filter.spatial.Overlaps;
import org.geotools.api.filter.spatial.Touches;
import org.geotools.api.filter.spatial.Within;
import org.geotools.api.filter.temporal.After;
import org.geotools.api.filter.temporal.AnyInteracts;
import org.geotools.api.filter.temporal.Before;
import org.geotools.api.filter.temporal.Begins;
import org.geotools.api.filter.temporal.BegunBy;
import org.geotools.api.filter.temporal.BinaryTemporalOperator;
import org.geotools.api.filter.temporal.During;
import org.geotools.api.filter.temporal.EndedBy;
import org.geotools.api.filter.temporal.Ends;
import org.geotools.api.filter.temporal.Meets;
import org.geotools.api.filter.temporal.MetBy;
import org.geotools.api.filter.temporal.OverlappedBy;
import org.geotools.api.filter.temporal.TContains;
import org.geotools.api.filter.temporal.TEquals;
import org.geotools.api.filter.temporal.TOverlaps;

import com.google.common.collect.Lists;

import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.Unchecked;
import nz.org.riskscape.engine.expr.TypedExpression;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ResultOrProblems;

/**
 * Validates a {@link Filter} or an {@link Expression} against a {@link Struct} to check for potential problems when
 * it is evaluated against a {@link Tuple} of that type.
 *
 * <p>
 * Validation walks the filter/expression tree. Any node that is a {@link TypedExpression} is asked to evaluate its
 * type against the input type, and whatever problems it reports are collected and returned to the caller.
 * </p>
 */
public class FilterValidator {

  /**
   * Name of the GeoTools function whose return type is inferred from the types of its then/else branches.
   */
  private static final String IF_THEN_ELSE = "if_then_else";

  /**
   * Types an arithmetic expression can infer, ordered from least to most precise. A type that is not in this list
   * (index -1) ranks below {@link Types#ANYTHING}.
   */
  private static final List<Type> NUMERIC_TYPE_RANKING =
      Lists.newArrayList(Types.ANYTHING, Types.INTEGER, Types.FLOATING, Types.DECIMAL);

  /**
   * Shared instance. The validator holds no per-call state, so a single instance can be reused.
   */
  public static final FilterValidator INSTANCE = new FilterValidator();

  /**
   * Validates every expression that the given filter references against the given struct.
   *
   * @param struct the type of the tuples the filter will be evaluated against
   * @param filter the filter to validate
   * @return the problems found, empty if there were none. Filters that can not be validated (e.g. an {@link Id}
   *         filter) result in an {@link Unchecked} problem.
   */
  public List<Problem> validateFilter(Struct struct, Filter filter) {
    List<Problem> problems = new ArrayList<>();

    filter.accept(new FilterVisitor() {

      /**
       * Validates each non-null expression against the struct, collecting any problems found.
       */
      private Object checkAll(Expression... expressions) {
        for (Expression expression : expressions) {
          if (expression != null) {
            problems.addAll(validateExpression(struct, expression).getProblems());
          }
        }
        return null;
      }

      private Object check(BinaryComparisonOperator filter) {
        return checkAll(filter.getExpression1(), filter.getExpression2());
      }

      private Object check(BinarySpatialOperator filter) {
        return checkAll(filter.getExpression1(), filter.getExpression2());
      }

      private Object check(BinaryTemporalOperator filter) {
        return checkAll(filter.getExpression1(), filter.getExpression2());
      }

      /**
       * Recursively validates every child of an and/or filter.
       */
      private Object check(BinaryLogicOperator filter) {
        for (Filter child : filter.getChildren()) {
          child.accept(this, null);
        }
        return null;
      }

      /**
       * Records that the given filter can not be validated.
       */
      private Object unsupported(Object filter) {
        problems.add(new Unchecked("Unsupported filter - " + filter));
        return null;
      }

      @Override
      public Object visitNullFilter(Object extraData) {
        return null;
      }

      // ---- temporal operators ----

      @Override
      public Object visit(TOverlaps filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(TEquals filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(TContains filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(OverlappedBy filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(MetBy filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Meets filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Ends filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(EndedBy filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(During filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(BegunBy filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Begins filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Before filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(AnyInteracts filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(After filter, Object extraData) {
        return check(filter);
      }

      // ---- spatial operators ----

      @Override
      public Object visit(Within filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Touches filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Overlaps filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Intersects filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Equals filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(DWithin filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Disjoint filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Crosses filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Contains filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Beyond filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(BBOX filter, Object extraData) {
        return check(filter);
      }

      // ---- property tests ----

      @Override
      public Object visit(PropertyIsNil filter, Object extraData) {
        return checkAll(filter.getExpression());
      }

      @Override
      public Object visit(PropertyIsNull filter, Object extraData) {
        return checkAll(filter.getExpression());
      }

      @Override
      public Object visit(PropertyIsLike filter, Object extraData) {
        // the like pattern is a plain string literal, so only the tested expression needs validating
        return checkAll(filter.getExpression());
      }

      @Override
      public Object visit(PropertyIsBetween filter, Object extraData) {
        // the boundaries are expressions too (they may reference properties or call functions), so validate them
        // along with the tested expression
        return checkAll(filter.getExpression(), filter.getLowerBoundary(), filter.getUpperBoundary());
      }

      // ---- comparison operators ----

      @Override
      public Object visit(PropertyIsLessThanOrEqualTo filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(PropertyIsLessThan filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(PropertyIsGreaterThanOrEqualTo filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(PropertyIsGreaterThan filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(PropertyIsNotEqualTo filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(PropertyIsEqualTo filter, Object extraData) {
        return check(filter);
      }

      // ---- logical operators ----

      @Override
      public Object visit(Or filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(And filter, Object extraData) {
        return check(filter);
      }

      @Override
      public Object visit(Not filter, Object extraData) {
        return filter.getFilter().accept(this, null);
      }

      // ---- everything else ----

      @Override
      public Object visit(Id filter, Object extraData) {
        return unsupported(filter);
      }

      @Override
      public Object visit(IncludeFilter filter, Object extraData) {
        // no expressions to validate
        return Boolean.TRUE;
      }

      @Override
      public Object visit(ExcludeFilter filter, Object extraData) {
        // no expressions to validate
        return Boolean.TRUE;
      }
    }, null);

    return problems;
  }

  /**
   * Check the given {@link Expression} against a {@link Type} and determine whether there are any definite or
   * possible problems with it and infer the return type as best as possible.
   *
   * @param type the type of the input the expression will be evaluated against
   * @param expression the expression to check
   * @return the inferred return type (falling back to {@link Types#ANYTHING} when nothing better is known) along with
   *         any problems found, or a failed result if any of the problems found are errors
   */
  public ResultOrProblems<Type> validateExpression(Type type, Expression expression) {
    List<Problem> problems = new ArrayList<>();

    Type returnType = (Type) expression.accept(new ExpressionVisitor() {

      /**
       * Asks a {@link TypedExpression} to evaluate its type against the input type, collecting any problems it
       * reports.
       *
       * @return the evaluated type, or null if the node isn't a {@link TypedExpression} or its type could not be
       *         evaluated
       */
      private Type evaluateTyped(Expression node) {
        if (!(node instanceof TypedExpression)) {
          return null;
        }
        ResultOrProblems<Type> typeResult = ((TypedExpression) node).evaluateType(type);
        problems.addAll(typeResult.getProblems());
        return typeResult.isPresent() ? typeResult.get() : null;
      }

      /**
       * Same as {@code evaluateTyped}, but falls back to {@link Types#ANYTHING} when no type could be evaluated.
       */
      private Type typeOrAnything(Expression node) {
        Type evaluated = evaluateTyped(node);
        return evaluated == null ? Types.ANYTHING : evaluated;
      }

      /**
       * Validates both operands of an arithmetic expression and returns the more precise of their types, ranked by
       * {@code NUMERIC_TYPE_RANKING}. When both operands rank equally (including when neither is ranked) the first
       * operand's type is returned.
       */
      private Type checkArithmetic(BinaryExpression arithmetic) {
        Type t1 = (Type) arithmetic.getExpression1().accept(this, null);
        Type t2 = (Type) arithmetic.getExpression2().accept(this, null);
        return NUMERIC_TYPE_RANKING.indexOf(t1) >= NUMERIC_TYPE_RANKING.indexOf(t2) ? t1 : t2;
      }

      @Override
      public Object visit(Add expression, Object extraData) {
        return checkArithmetic(expression);
      }

      @Override
      public Object visit(Subtract expression, Object extraData) {
        return checkArithmetic(expression);
      }

      @Override
      public Object visit(Multiply expression, Object extraData) {
        return checkArithmetic(expression);
      }

      @Override
      public Object visit(Divide expression, Object extraData) {
        checkArithmetic(expression);
        // Division can result in higher precision than the input types.
        return Types.FLOATING;
      }

      @Override
      public Object visit(PropertyName expression, Object extraData) {
        return typeOrAnything(expression);
      }

      @Override
      public Object visit(Literal expression, Object extraData) {
        return typeOrAnything(expression);
      }

      @Override
      public Object visit(NilExpression expression, Object extraData) {
        return Types.ANYTHING;
      }

      @Override
      public Object visit(Function expression, Object extraData) {
        Type evaluated = evaluateTyped(expression);
        if (evaluated != null) {
          return evaluated;
        }

        // Untyped functions can't validate their own parameters, so each one is visited here. The resulting types
        // are kept so that if_then_else inference can reuse them rather than visiting (and reporting problems for)
        // its branches a second time.
        List<Type> parameterTypes = null;
        if (!(expression instanceof TypedExpression)) {
          parameterTypes = new ArrayList<>();
          for (Expression parameter : expression.getParameters()) {
            parameterTypes.add((Type) parameter.accept(this, null));
          }
        }

        if (IF_THEN_ELSE.equals(expression.getName())) {
          return findIfThenElseReturnType(expression, parameterTypes);
        }
        // A well behaved function should tell us the class of its return type. Use this to infer the type.
        return findDeclaredReturnType(expression);
      }

      /**
       * Infers a function's return type from the Java class declared in its function name metadata.
       *
       * @return the matching RiskScape type, or {@link Types#ANYTHING} when nothing is declared or the declared
       *         class has no matching type
       */
      private Type findDeclaredReturnType(Function function) {
        Class<?> returnClass = Object.class;
        if (function.getFunctionName() != null && function.getFunctionName().getReturn() != null) {
          returnClass = function.getFunctionName().getReturn().getType();
        }
        // no dice - fall back to anything
        return Types.fromJavaTypeOptional(returnClass).orElse(Types.ANYTHING);
      }

      /**
       * Returns the type of the parameter at the given index, reusing an already computed type when available,
       * otherwise visiting the parameter.
       */
      private Type parameterType(List<Expression> parameters, List<Type> parameterTypes, int index) {
        if (parameterTypes != null) {
          return parameterTypes.get(index);
        }
        return (Type) parameters.get(index).accept(this, null);
      }

      /**
       * Infers the return type of an {@code if_then_else(condition, then, else)} function from the types of its
       * then and else branches.
       *
       * @param function the if_then_else function
       * @param parameterTypes the already computed types of all the function's parameters, or null if the
       *        parameters have not been visited yet
       */
      private Type findIfThenElseReturnType(Function function, List<Type> parameterTypes) {
        List<Expression> parameters = function.getParameters();
        if (parameters == null || parameters.size() < 3) {
          // no then/else branches to infer a type from
          return findDeclaredReturnType(function);
        }

        Type foundThen = parameterType(parameters, parameterTypes, 1);
        Type foundElse = parameterType(parameters, parameterTypes, 2);

        if (foundThen.equals(foundElse)) {
          // exact is perfect
          return foundThen;
        }

        Type merged;
        if (Nullable.unwrap(foundThen).equals(Nullable.unwrap(foundElse))) {
          merged = foundElse;
        } else {
          merged = Types.ANYTHING;
        }

        // respect the nullability of the else case
        // NOTE(review): the nullability of the then branch is not considered here - confirm this is intended
        return foundElse
            .find(Nullable.class)
            .<Type>map(n -> Nullable.of(merged))
            .orElse(merged);
      }
    }, null);

    if (Problem.hasErrors(problems)) {
      return ResultOrProblems.failed(problems);
    }
    return ResultOrProblems.of(returnType, problems);
  }

}
