/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.rl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.ReflectionUtils;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.BinaryOperatorFunction;
import nz.org.riskscape.engine.function.BinaryPredicateFunction;
import nz.org.riskscape.engine.function.NullSafeFunction;
import nz.org.riskscape.engine.function.OperatorResolver;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Struct.StructBuilder;
import nz.org.riskscape.engine.types.Struct.StructMember;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.rl.TokenTypes;
import nz.org.riskscape.rl.ast.BinaryOperation;

/**
 * Core set of binary operators that work with the default types in RiskScape - should cover most of what ECQL does.
 *
 * <p>{@link #resolve} tries, in order:</p>
 * <ol>
 *   <li>the maths, boolean logic, comparison or equality operators, depending on the (normalized) operator token</li>
 *   <li>failing that, when either side is nullable, resolving again with the nullability stripped off and wrapping
 *   any result in a {@link NullSafeFunction}</li>
 *   <li>failing that, when both sides are structs, applying the operator member-by-member, pairing members up by
 *   name</li>
 *   <li>failing that, when only one side is a struct, applying the operator between each of its members and the
 *   other side</li>
 * </ol>
 */
public class DefaultOperators implements OperatorResolver {

  /**
   * Arithmetic operators, resolved against {@link #MATHS_FUNCTIONS}.
   */
  public static final EnumSet<TokenTypes> MATHS_OPERATORS = EnumSet.of(
      TokenTypes.PLUS,
      TokenTypes.MINUS,
      TokenTypes.MULTIPLY,
      TokenTypes.DIVIDE,
      TokenTypes.POW
  );

  /**
   * Ordering comparisons. These resolve when both sides share a {@link Comparable} ancestor, or are both numeric.
   */
  public static final EnumSet<TokenTypes> BOOLEAN_COMPARATORS = EnumSet.of(
      TokenTypes.GREATER_THAN,
      TokenTypes.GREATER_THAN_EQUAL,
      TokenTypes.LESS_THAN,
      TokenTypes.LESS_THAN_EQUAL
  );

  /**
   * Equality operators. Numeric operands are handed off to the comparators; everything else uses
   * {@link Objects#equals(Object, Object)}.
   */
  public static final EnumSet<TokenTypes> EQUALITY_OPERATORS = EnumSet.of(
      TokenTypes.EQUALS,
      TokenTypes.NOT_EQUALS
  );

  /**
   * Boolean logic operators, which only apply to (possibly nullable) boolean operands.
   */
  public static final EnumSet<TokenTypes> BOOLEAN_LOGIC_OPERATORS = EnumSet.of(
      TokenTypes.OR,
      TokenTypes.AND
  );


  /**
   * The candidate functions for each maths operator. Each candidate for an operator works on a different java type,
   * and both arguments share that type.
   */
  public static final Map<TokenTypes, List<RiskscapeFunction>> MATHS_FUNCTIONS =
      ImmutableMap.<TokenTypes, List<RiskscapeFunction>>builder()
      .put(
          TokenTypes.PLUS,
          Arrays.asList(
            operatorFor(TokenTypes.PLUS, Long.class, (a, b) -> a + b),
            operatorFor(TokenTypes.PLUS, String.class, (a, b) -> a + b),
            operatorFor(TokenTypes.PLUS, Double.class, (a, b) -> a + b)
          )
      )
      .put(
          TokenTypes.MINUS,
          Arrays.asList(
            operatorFor(TokenTypes.MINUS, Long.class, (a, b) -> a - b),
            operatorFor(TokenTypes.MINUS, Double.class, (a, b) -> a - b)
          )
      )
      .put(
          TokenTypes.MULTIPLY,
          Arrays.asList(
            operatorFor(TokenTypes.MULTIPLY, Long.class, (a, b) -> a * b),
            operatorFor(TokenTypes.MULTIPLY, Double.class, (a, b) -> a * b)
          )
      )
      .put(
          TokenTypes.DIVIDE,
          Arrays.asList(
            operatorFor(TokenTypes.DIVIDE, Double.class, (a, b) -> a / b)
          )
      )
      .put(
          TokenTypes.POW,
          Arrays.asList(
            operatorFor(TokenTypes.POW, Double.class, (a, b) -> Math.pow(a, b))
          )
      )
      .build();

  public static final DefaultOperators INSTANCE = new DefaultOperators();

  /**
   * Builds a {@link RiskscapeFunction} that applies a boolean predicate to two arguments of the same java type.
   *
   * @param operator the operator token this predicate implements (not used by the returned function)
   * @param javaType the java type both arguments are cast to before testing; it also determines the RiskScape
   *                 argument types via {@link Types#fromJavaType}
   * @param predicate the test to apply to the two cast arguments
   * @return a function taking two arguments of the matching RiskScape type and returning {@link Types#BOOLEAN}
   */
  public <T extends Object> RiskscapeFunction predicateFor(
      TokenTypes operator, Class<T> javaType, BiPredicate<T, T> predicate) {

    Type type = Types.fromJavaType(javaType);
    List<Type> argumentTypes = Arrays.asList(type, type);

    return new RiskscapeFunction() {

      @Override
      public Object call(List<Object> args) {
        return predicate.test(javaType.cast(args.get(0)), javaType.cast(args.get(1)));
      }

      @Override
      public List<Type> getArgumentTypes() {
        return argumentTypes;
      }

      @Override
      public Type getReturnType() {
        return Types.BOOLEAN;
      }
    };
  }

  /**
   * Returns a {@link RiskscapeFunction} implementing OR logic.
   *
   * If nullableInputs is true, an exception is made to the usual "any null input gives a null result" rule (refer to
   * {@link NullSafeFunction}): a boolean is returned as long as one of the arguments is not null. When exactly one
   * argument is null, the other argument is returned as-is, so {@code true OR null} is {@code true} and
   * {@code false OR null} is {@code false}. If both arguments are null then null is returned.
   *
   * The function returned by this method should not be wrapped in a {@link NullSafeFunction} as that
   * will change the OR logic behaviour with one null argument.
   *
   * @param nullableInputs indicates if either input could be null
   * @return or function, whose argument and return types are nullable booleans when nullableInputs is true
   */
  public RiskscapeFunction orFunction(boolean nullableInputs) {

    Type type = Nullable.ifTrue(nullableInputs, Types.BOOLEAN);
    List<Type> argumentTypes = Arrays.asList(type, type);

    return new RiskscapeFunction() {

      @Override
      public Object call(List<Object> args) {
        Object lhs = args.get(0);
        Object rhs = args.get(1);
        if (!nullableInputs || lhs != null && rhs != null) {
          return Boolean.logicalOr(Boolean.class.cast(lhs), Boolean.class.cast(rhs));
        } else if (lhs != null) {
          return lhs;
        } else if (rhs != null) {
          return rhs;
        }
        return null;
      }

      @Override
      public List<Type> getArgumentTypes() {
        return argumentTypes;
      }

      @Override
      public Type getReturnType() {
        return type;
      }
    };
  }

  /**
   * Builds a {@link BinaryOperatorFunction} for the given operator that combines two values of the same java type.
   *
   * @param operationToken the operator the function implements
   * @param javaType the java type of both arguments and of the result
   * @param function the operation itself
   * @return the operator function
   */
  public static <T extends Object> RiskscapeFunction operatorFor(
      TokenTypes operationToken,
      Class<T> javaType,
      BinaryOperator<T> function
  ) {
    return new BinaryOperatorFunction<>(operationToken, function, javaType);
  }

  @Override
  public Optional<RiskscapeFunction> resolve(
      RealizationContext context,
      BinaryOperation operation,
      Type inputType,
      Type lhs,
      Type rhs
  ) {

    TokenTypes tt = operation.getNormalizedOperator();

    Optional<RiskscapeFunction> result = Optional.empty();
    if (MATHS_OPERATORS.contains(tt)) {
      result = resolveMathsOperator(operation, inputType, lhs, rhs);
    } else if (BOOLEAN_LOGIC_OPERATORS.contains(tt)) {
      result = resolveLogicOperator(operation, inputType, lhs, rhs);
    } else if (BOOLEAN_COMPARATORS.contains(tt)) {
      result = resolveComparators(operation, inputType, lhs, rhs);
    } else if (EQUALITY_OPERATORS.contains(tt)) {
      result = resolveEqualityOperator(operation, inputType, lhs, rhs);
    }

    if (result.isPresent()) {
      return result;
    }

    if (Nullable.any(lhs, rhs)) {
      // pull nulls off and try again, wrapping with null safety
      return resolve(context, operation, inputType, Nullable.strip(lhs), Nullable.strip(rhs))
          .map(function -> NullSafeFunction.wrap(function));
    }

    // if both sides are structs, then we might be able to operate on their members...
    Struct lhsStruct = lhs.find(Struct.class).orElse(null);
    Struct rhsStruct = rhs.find(Struct.class).orElse(null);

    if (lhsStruct != null && rhsStruct != null) {
      return buildStructCompositeFunction(context, operation, inputType, lhsStruct, rhsStruct)
          .map(pairs -> new BinaryStructCompositeFunction(context, lhs, rhs, pairs));
    } else if (lhsStruct != null || rhsStruct != null) {
      return buildOneSidedStructCompositeFunction(context, operation, inputType, lhs, rhs, lhsStruct != null);
    }

    return Optional.empty();
  }

  /**
   * Resolves an operation between a struct and a non-struct value by resolving the operation between each member of
   * the struct and the other value.
   *
   * @param lhsIsStruct true if the struct is on the left hand side of the operation, false if it is on the right
   * @return a function returning a struct with the same member keys as the input struct, or empty if the operation
   *         could not be resolved for any one of the members
   */
  private Optional<RiskscapeFunction> buildOneSidedStructCompositeFunction(
      RealizationContext context, BinaryOperation operation, Type inputType,
      Type lhs, Type rhs, boolean lhsIsStruct) {

    Struct struct   = (lhsIsStruct ? lhs : rhs).find(Struct.class).get();
    Type scalarType = (lhsIsStruct ? rhs : lhs);

    Struct returnType = Struct.EMPTY_STRUCT;
    List<RiskscapeFunction> functions = new ArrayList<>(struct.size());
    for (StructMember structMember : struct.getMembers()) {

      // keep the operands in their original order - not every operator is commutative
      Type[] args = lhsIsStruct
          ? new Type[] {structMember.getType(), scalarType}
          : new Type[] {scalarType, structMember.getType()};

      Optional<RiskscapeFunction> built = resolve(context, operation, inputType, args[0], args[1]);
      if (!built.isPresent()) {
        // all-or-nothing: one unsupported member means the whole operation is unsupported
        return Optional.empty();
      }

      RiskscapeFunction memberFunction = built.get();
      functions.add(memberFunction);
      returnType = returnType.add(structMember.getKey(), memberFunction.getReturnType());
    }

    returnType = context.normalizeStruct(returnType);

    return Optional.of(new OneSidedStructCompositeFunction(Arrays.asList(lhs, rhs), returnType, lhsIsStruct,
        functions.toArray(new RiskscapeFunction[0])));
  }

  /**
   * Applies a per-member function between each member of a tuple and a single non-tuple value. The i-th member
   * function is applied to the i-th member of the input tuple, and its result is set as the i-th member of the
   * returned tuple.
   */
  @RequiredArgsConstructor
  private static class OneSidedStructCompositeFunction implements RiskscapeFunction {

    @Getter
    private final List<Type> argumentTypes;

    @Getter
    private final Struct returnType;

    private final boolean lhsIsStruct;

    private final RiskscapeFunction[] memberFunctions;

    @Override
    public Object call(List<Object> args) {
      Object lhs = args.get(0);
      Object rhs = args.get(1);

      Tuple tuple   = (Tuple) (lhsIsStruct ? lhs : rhs);
      Object scalar =         (lhsIsStruct ? rhs : lhs);

      Tuple toReturn = new Tuple(returnType);
      // a single args array is reused for every member call - the scalar slot never changes
      Object[] newArgs = new Object[2];
      List<Object> newArgsList = Arrays.asList(newArgs);

      newArgs[lhsIsStruct ? 1 : 0] = scalar;
      for (int i = 0; i < toReturn.size(); i++) {
        newArgs[lhsIsStruct ? 0 : 1] = tuple.fetch(i);
        toReturn.set(i, memberFunctions[i].call(newArgsList));
      }

      return toReturn;
    }

  }

  /**
   * Resolves EQUALS / NOT_EQUALS. Numeric operands go to {@link #resolveComparators}, which can also compare mixed
   * numeric types. Everything else is compared with {@link Objects#equals(Object, Object)}.
   */
  private Optional<RiskscapeFunction> resolveEqualityOperator(BinaryOperation operation, Type inputType, Type lhs,
      Type rhs) {

    Type unwrappedLhs = Nullable.unwrap(lhs);
    Type unwrappedRhs = Nullable.unwrap(rhs);
    Set<Type> types = Sets.newHashSet(unwrappedLhs, unwrappedRhs);

    if (types.stream().allMatch(Type::isNumeric)) {
      return resolveComparators(operation, inputType, lhs, rhs);
    }

    TokenTypes tt = operation.getNormalizedOperator();
    boolean negate = tt == TokenTypes.NOT_EQUALS;

    RiskscapeFunction equalityFunction = BinaryPredicateFunction.untyped(
        tt, unwrappedLhs, unwrappedRhs,
        (l, r) -> Objects.equals(l, r) ^ negate);

    return Optional.of(equalityFunction)
        // special case - null == null = null - null should always short circuit
        .map(function -> Nullable.any(lhs, rhs) ? NullSafeFunction.wrapIgnoringArgs(function) : function);
  }

  /**
   * Resolves an ordering (or, via {@link #resolveEqualityOperator}, a numeric equality) comparison.
   *
   * Mixed numeric types are compared as doubles. Otherwise both sides must share a {@link Comparable} ancestor, and
   * are compared with {@link Comparable#compareTo(Object)}.
   */
  @SuppressWarnings("rawtypes")
  private Optional<RiskscapeFunction> resolveComparators(BinaryOperation operation, Type inputType, Type lhs,
      Type rhs) {

    Type rawLhs = Nullable.unwrap(lhs);
    Type rawRhs = Nullable.unwrap(rhs);

    TokenTypes tt = operation.getNormalizedOperator();
    boolean hasCommonComparableAncestor = ReflectionUtils
        .findCommonAncestorOfType(rawLhs.internalType(), rawRhs.internalType(), Comparable.class).isPresent();

    @SuppressWarnings("unchecked")
    BiPredicate<Comparable, Comparable> predicate = (l, r) -> {
      // compareTo only promises the sign of its result, not that it is exactly -1/0/1 (String#compareTo returns the
      // difference between the first mismatched chars, for example), so only ever test the sign
      int result = l.compareTo(r);

      switch (tt) {
      case GREATER_THAN:
        return result > 0;
      case LESS_THAN:
        return result < 0;
      case GREATER_THAN_EQUAL:
        return result >= 0;
      case LESS_THAN_EQUAL:
        return result <= 0;
      case EQUALS:
        return result == 0;
      case NOT_EQUALS:
        return result != 0;
      default:
        throw new RiskscapeException("Unexpected token - " + tt);
      }
    };

    // the set de-duplicates, so more than one element means the two sides have different types
    Set<Type> rawTypes = Sets.newHashSet(rawLhs, rawRhs);
    boolean mixedRawTypes = rawTypes.size() > 1;

    // special case - if the types aren't the same, but they are all numeric, then we can do a double comparison.
    // If we end up adding a smallint type, we'll probably want a better mechanism than this to do the most accurate
    // up-cast, e.g. mixing smallint and int goes to int (not to double)
    if (mixedRawTypes && rawTypes.stream().allMatch(Type::isNumeric)) {
      BiPredicate<Comparable, Comparable> wrap = predicate;
      predicate = (l, r) -> wrap.test(((Number) l).doubleValue(), ((Number) r).doubleValue());
    } else if (!hasCommonComparableAncestor) {
      return Optional.empty();
    }

    return Optional.of(
        new BinaryPredicateFunction<>(tt, Comparable.class, rawLhs, Comparable.class, rawRhs, predicate))
        .map(f -> Nullable.any(lhs, rhs) ? NullSafeFunction.wrapIgnoringArgs(f) : f);
  }

  /**
   * Resolves AND / OR between two (possibly nullable) booleans; any other operand types are not supported.
   */
  private Optional<RiskscapeFunction> resolveLogicOperator(BinaryOperation operation, Type inputType, Type lhs,
      Type rhs) {

    Type rawLhs = Nullable.unwrap(lhs);
    Type rawRhs = Nullable.unwrap(rhs);

    if (!rawLhs.equals(Types.BOOLEAN) || !rawRhs.equals(Types.BOOLEAN)) {
      return Optional.empty();
    }

    BiPredicate<Boolean, Boolean> predicate;
    TokenTypes normalizedOperator = operation.getNormalizedOperator();
    switch (normalizedOperator) {
    case OR:
      // Or is a special case that we want to return true if either the left or right are true,
      // regardless of the other side possibly being null.
      return Optional.of(orFunction(Nullable.any(lhs, rhs)));
    case AND:
      predicate = (l, r) -> l && r;
      break;
    default:
      return Optional.empty();
    }

    RiskscapeFunction function = predicateFor(normalizedOperator, Boolean.class, predicate);

    return Optional.of(function)
        .map(f -> Nullable.any(lhs, rhs) ? NullSafeFunction.wrapIgnoringArgs(f) : f);
  }

  /**
   * Resolves a maths operator, first looking for an exact match on the argument types, then falling back to
   * floating point maths when both (non-nullable) arguments are numeric.
   */
  private Optional<RiskscapeFunction> resolveMathsOperator(BinaryOperation functionCall, Type inputType, Type lhs,
      Type rhs) {
    Type rawLhs = lhs.getUnwrappedType();
    Type rawRhs = rhs.getUnwrappedType();

    Optional<RiskscapeFunction> found = rslvRawMathsOperator(functionCall, inputType, rawLhs, rawRhs);

    if (!found.isPresent()) {
      Set<Type> rawTypes = Sets.newHashSet(rawLhs, rawRhs);
      // if all of our args are numeric, we can coerce them to floating/double as a
      // lowest-common-"good enough"-denominator
      if (rawTypes.stream().allMatch(type -> type.isNumeric() && !type.isNullable())) {
        found = rslvRawMathsOperator(functionCall, inputType, Types.FLOATING, Types.FLOATING)
            .map(rf -> coerceIntArgsForOperator(rf));
      }
    }

    return found;
  }

  /**
   * A pair of same-named members from two structs, along with the function that operates on them.
   */
  @RequiredArgsConstructor
  private static class StructPair {
    final StructMember lhs;
    final StructMember rhs;
    final RiskscapeFunction applyTo;
  }

  /**
   * Applies a per-member function to the same-named members of two tuples. The result is a tuple keyed by the lhs
   * member names, in the order of the given pairs.
   */
  private static class BinaryStructCompositeFunction implements RiskscapeFunction {

    @Getter
    private final List<Type> argumentTypes;

    @Getter
    private final Struct returnType;

    private final List<StructPair> pairs;

    BinaryStructCompositeFunction(RealizationContext context, Type lhs, Type rhs, List<StructPair> pairs) {
      this.pairs = pairs;
      this.argumentTypes = Arrays.asList(lhs, rhs);
      StructBuilder builder = Struct.builder();
      for (StructPair pair : pairs) {
        builder.add(pair.lhs.getKey(), pair.applyTo.getReturnType());
      }
      this.returnType = context.normalizeStruct(builder.build());
    }

    @Override
    public Object call(List<Object> args) {
      Tuple lhs = (Tuple) args.get(0);
      Tuple rhs = (Tuple) args.get(1);

      Tuple result = new Tuple(returnType);
      // reused for every member call
      List<Object> argsList = Arrays.asList(null, null);
      int index = 0;
      for (StructPair pair : pairs) {
        argsList.set(0, lhs.fetch(pair.lhs));
        argsList.set(1, rhs.fetch(pair.rhs));
        result.set(index++, pair.applyTo.call(argsList));
      }

      return result;
    }
  }

  /**
   * Pairs up the members of two structs by name and resolves the operation for each pair.
   *
   * @return a pair for every lhs member, or empty if any lhs member has no same-named rhs member, or if the operation
   *         could not be resolved for any pair. Extra rhs members are ignored.
   */
  private Optional<List<StructPair>> buildStructCompositeFunction(
      RealizationContext context,
      BinaryOperation functionCall,
      Type inputType,
      Struct rawLhs,
      Struct rawRhs
  ) {

    // pair up members from the lhs and rhs by name
    List<Pair<StructMember, Optional<StructMember>>> paired = rawLhs.getMembers().stream().map(lhsMember -> {
      Optional<Struct.StructMember> rhsMember = rawRhs.getMembers().stream()
          .filter(rm -> rm.getKey().equals(lhsMember.getKey()))
          .findFirst();

      return Pair.of(lhsMember, rhsMember);
    }).collect(Collectors.toList());

    // we want a match for every lhs member
    if (!paired.stream().allMatch(pair -> pair.getRight().isPresent())) {
      return Optional.empty();
    }

    List<StructPair> pairs = paired.stream()
        // we can now strip out the absent rhs members safely
        .map(pair -> Pair.of(pair.getLeft(), pair.getRight().get()))
        // find a function for each pair
        .map(pair -> resolve(context, functionCall, inputType, pair.getLeft().getType(), pair.getRight().getType())
            .map(function -> new StructPair(pair.getLeft(), pair.getRight(), function)))
        // drop any pairs that had no function
        .filter(Optional::isPresent)
        .map(Optional::get)
        .collect(Collectors.toList());

    // if we didn't match a pair for every member, fail
    if (pairs.size() != rawLhs.size()) {
      return Optional.empty();
    } else {
      return Optional.of(pairs);
    }
  }

  /**
   * Looks up the {@link #MATHS_FUNCTIONS} candidate for the operation whose argument types are exactly
   * {@code [rawLhs, rawRhs]}.
   */
  private Optional<RiskscapeFunction> rslvRawMathsOperator(BinaryOperation functionCall, Type inputType, Type rawLhs,
      Type rawRhs) {

    // use the normalized operator - the same token resolve() dispatched on - so any alias of a maths operator finds
    // the same candidates
    List<RiskscapeFunction> candidates = MATHS_FUNCTIONS.getOrDefault(
        functionCall.getNormalizedOperator(), Collections.emptyList());
    List<Type> args = Arrays.asList(rawLhs, rawRhs);

    // the candidates registered for each operator all use different java types, so at most one matches
    return candidates.stream()
        .filter(function -> function.getArgumentTypes().equals(args))
        .findFirst();
  }

  /**
   * Returns a function that wraps a binary double function so that all args are doubles/floating and result is
   * double/floating. Each (numeric) argument is converted with {@link Number#doubleValue()} before being passed on.
   */
  private RiskscapeFunction coerceIntArgsForOperator(RiskscapeFunction operator) {
    // XXX This could be done in a special case of the BinaryOperatorFunction
    return new RiskscapeFunction() {

      @Override
      public Type getReturnType() {
        return Types.FLOATING;
      }

      @Override
      public List<Type> getArgumentTypes() {
        return Arrays.asList(Types.FLOATING, Types.FLOATING);
      }

      @Override
      public Object call(List<Object> args) {
        Number lhs = (Number) args.get(0);
        Number rhs = (Number) args.get(1);

        return operator.call(Arrays.asList(lhs.doubleValue(), rhs.doubleValue()));
      }
    };
  }
}
