/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import nz.org.riskscape.engine.function.FunctionResolver;
import nz.org.riskscape.engine.function.IdentifiedFunction;
import nz.org.riskscape.engine.function.NullSafeFunction;
import nz.org.riskscape.engine.function.OverloadedFunction;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.rl.MissingFunctionException;
import nz.org.riskscape.engine.rl.RealizableFunction;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.types.eqrule.Coercer;
import nz.org.riskscape.engine.types.varule.Variance;
import nz.org.riskscape.engine.typeset.TypeSet;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemCode;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.FunctionCall.Argument;

@Slf4j
@RequiredArgsConstructor
public class DefaultFunctionResolver implements FunctionResolver {

  /**
   * The problem code produced when an expression refers to a struct member that does not exist.  Computed once so
   * that {@link #evaluateConstant} can recognise "this expression is not constant" failures, i.e. realization problems
   * that only complain about missing members when realized against an empty struct.
   */
  private static final ProblemCode NO_SUCH_MEMBER =
      ExpressionProblems.get().noSuchStructMember("", Collections.emptyList()).getCode();

  /**
   * Helper method for a {@link RealizableFunction} to extract a constant from the arguments with various error
   * handling along the way.  This method will fail if the expression is not a constant expression (e.g. it depends on
   * values extracted from scope).
   * @param <T> the java type of the constant that is returned
   * @param context context in which realization is happening - typically the argument that was passed to
   * {@link RealizableFunction#realize(RealizationContext, FunctionCall, List)}
   * @param functionCall the functionCall expression that contains an argument to be realized constantly
   * @param argIndex the index of the argument (0 based) that should be a constant
   * @param requiredJavaType the desired java type of the constant - must map to a built-in riskscape type via
   * {@link Types#fromJavaType(Class)}
   * @param requiredType the Riskscape version of the requiredJavaType argument.  Is only used for a problem and no
   * riskscape type checking is actually done against this type.
   * @return the constant value from the function call expression
   * @throws ProblemException if the expression was bad, not a constant, or of the wrong type.  Contained problem
   * should be suitable for display to the user without wrapping with more function-call specific context
   */
  public static <T> T evaluateConstant(
      RealizationContext context,
      FunctionCall functionCall,
      int argIndex,
      Class<T> requiredJavaType,
      Type requiredType
  ) throws ProblemException {

    Argument argument = functionCall.getArguments().get(argIndex);
    Expression expression = argument.getExpression();

    // realize against an empty struct - a constant expression must not need anything from scope
    ResultOrProblems<RealizedExpression> constantExpressionOr =
        context.getExpressionRealizer().realize(Struct.EMPTY_STRUCT, expression);

    List<Problem> problems = constantExpressionOr.getProblems();

    if (Problem.hasErrors(problems)) {
      // generally, if it's not a constant expression then it'll have missing attribute errors,
      // e.g. the expression is trying to access foo.bar, which wasn't found
      if (problems.stream().allMatch(prob -> prob.getCode().equals(NO_SUCH_MEMBER))) {
        throw new ProblemException(ExpressionProblems.get().constantRequired(expression));
      }
      // not to do with missing attributes, it might just be an invalid expression.
      // Pass the error back unadulterated
      throw new ProblemException(problems);
    }

    // make sure it's the type the caller asked for, or return an error
    RealizedExpression realized = constantExpressionOr.get();

    if (!requiredJavaType.equals(realized.getResultType().internalType())) {
      throw new ProblemException(
          TypeProblems.get().mismatch(expression, requiredType, realized.getResultType())
      );
    }

    Object returned = realized.evaluate(Tuple.EMPTY_TUPLE);
    return requiredJavaType.cast(returned);
  }

  /**
   * A given (argument) type paired with the receiving (parameter) type it is being matched against, along with the
   * {@link TypeSet} used to answer assignability, variance and equivalence questions about the pair.
   */
  public static class SetPair {

    final TypeSet typeSet;

    @Getter
    final Type given;

    @Getter
    final Type receiver;

    public SetPair(TypeSet typeSet, Type given, Type receiver) {
      this.typeSet = typeSet;
      this.given = given;
      this.receiver = receiver;
    }

    /**
     * @return the next type from the iterator, or {@link Types#NOTHING} if it is exhausted
     */
    private static Type pick(Iterator<Type> types) {
      if (types.hasNext()) {
        return types.next();
      } else {
        return Types.NOTHING;
      }
    }

    /**
     * Pairs the next type from each iterator, using {@link Types#NOTHING} in place of an exhausted side.
     */
    public static SetPair from(TypeSet typeSet, Iterator<Type> given, Iterator<Type> receiver) {
      return new SetPair(typeSet, pick(given), pick(receiver));
    }

    public static SetPair from(TypeSet typeSet, Type given, Type receiver) {
      return new SetPair(typeSet, given, receiver);
    }

    /**
     * @return true if the given type is nullable but the receiving type is not
     */
    public boolean isOnlyGivenNullable() {
      return isGivenNullable() && !isReceiverNullable();
    }

    public boolean isGivenNullable() {
      return Nullable.is(given);
    }

    public boolean isReceiverNullable() {
      return Nullable.is(receiver);
    }

    /**
     * @return true if the type set says the given type can be assigned to the receiving type
     */
    public boolean isAssignable() {
      return typeSet.isAssignable(given, receiver);
    }

    @Override
    public String toString() {
      return String.format("ArgPair(given=%s, receiver=%s)", given, receiver);
    }

    /**
     * @return a coercer that converts the given type to the receiving type, or empty if none exists or if the pair
     * is already assignable
     */
    public Optional<Coercer> findEquivalenceCoercer() {
      // NB - this is a safety precaution in case an equivalence coercer is being found for an already legal assignment
      // if these are returned, the routine for adapting functions can get in to an infinite loop
      return isAssignable() ? Optional.empty() : typeSet.findEquivalenceCoercer(given, receiver);
    }

    public Variance testVariance() {
      return typeSet.testVariance(given, receiver);
    }
  }

  /**
   * Adapts `function` so that it can be called with `argumentTypes`, then checks (and logs about) whether any struct
   * return type is normalized.
   */
  @Override
  public ResultOrProblems<RiskscapeFunction> resolve(
      RealizationContext context,
      FunctionCall functionCall,
      Type inputType,
      List<Type> argumentTypes,
      IdentifiedFunction function
  ) {
    return adaptFunction(context, functionCall, function, inputType, argumentTypes)
        .map(rf -> normalizeReturnType(context, functionCall, rf));
  }

  /**
   * @return true if every item satisfies the predicate (true for an empty list)
   */
  private <T> boolean all(List<T> items, Predicate<? super T> predicate) {
    return items.stream().allMatch(predicate);
  }

  /**
   * @return true if at least one item satisfies the predicate (false for an empty list)
   */
  private <T> boolean any(List<T> items, Predicate<? super T> predicate) {
    return items.stream().anyMatch(predicate);
  }

  /**
   * Finds a version of `targetFunction` that can accept `givenTypes`.  In order, this:
   * <ol>
   * <li>searches the function's overloaded alternatives (if any) for the best match</li>
   * <li>fails if the overloaded function says it should not be resolved itself</li>
   * <li>realizes the function if it is realizable, optionally after type adaptation, and checks the realized
   * function's argument types accept the given types</li>
   * <li>otherwise adapts the function directly</li>
   * </ol>
   */
  private ResultOrProblems<RiskscapeFunction> adaptFunction(
      RealizationContext context,
      FunctionCall functionCall,
      IdentifiedFunction targetFunction,
      Type inputType,
      List<Type> givenTypes
  ) {
    TypeSet typeSet = context.getProject().getTypeSet();

    log.debug("Attempting to adapt function {} against {}", targetFunction, givenTypes);

    List<RiskscapeFunction> alternatives = targetFunction.getOverloaded()
        .map(ol -> ol.getAlternatives())
        .orElse(Collections.emptyList());

    if (alternatives.size() > 0) {
      RiskscapeFunction found = searchForBestAlternative(typeSet, alternatives, givenTypes);
      if (found != null) {
        return ResultOrProblems.of(found);
      }
    }

    log.debug("No overloaded functions from {} match",
        targetFunction);

    if (targetFunction.getOverloaded().map(OverloadedFunction::ignoreThis).orElse(false)) {
      // overloaded function ignoreThis means 'don't try and resolve me', so we fail
      return ResultOrProblems.error(new MissingFunctionException(functionCall, givenTypes, targetFunction));
    }

    RealizableFunction realizable = targetFunction.getRealizable().orElse(null);
    if (realizable != null) {
      log.debug("Function {} is realizable, attempting realization...", targetFunction);

      // first, we attempt to realize the original function to see if we can apply any adaption to the types
      // advertised by the realizable function - we don't want every realizable function to have to reimplement
      // type variance/coercion rules
      RiskscapeFunction adapted;
      if (realizable.isDoTypeAdaptation()) {
        adapted = adaptFunction1(typeSet, targetFunction, givenTypes);
      } else {
        adapted = targetFunction;
      }

      ResultOrProblems<RiskscapeFunction> realizationResult;
      // if original (or adaption found nothing), realize against the given types as-is
      if (adapted == targetFunction || adapted == null) {
        realizationResult = realizable.realize(context, functionCall, givenTypes);
      } else {
        // at this point, we need to see if type adaption has mapped the input types to something else - because we
        // should realize against those mapped types, not the realizable function's advertised types
        // NB at the moment, this is just making sure that realizing gets to see any difference in types caused
        // by covariance - any coerced/equivalent types get transformed and will have been adapted to the realizable
        // functions advertised types
        // I'd like to replace this with something more solid, defined and tested, maybe by being able
        // to see the result of the mapping more clearly, that is, the adapted function can produce a mapping from
        // input type to output type - it's possible that the adapting function routine could be refactored to instead
        // return a description of the transformation, rather than a function that applies, so this can be interrogated
        // and reasoned and manipulated before then building a function
        List<Type> realizeAgainst = new ArrayList<>(givenTypes.size());

        for (int i = 0; i < givenTypes.size(); i++) {
          realizeAgainst.add(getCoercedArgumentType(adapted, i, givenTypes.get(i)));
        }
        realizationResult = realizable.realize(context, functionCall, realizeAgainst)
            .map(rz -> rewrap(targetFunction, rz, adapted));
      }

      if (realizationResult.hasErrors()) {
        return realizationResult.composeProblems((s, c) ->
          Problems.foundWith(RiskscapeFunction.class, targetFunction.getId(), c));
      }
      realizationResult = realizationResult.drainWarnings(context.getProblemSink(),
          (severity, children)
              -> Problems.foundWith(RiskscapeFunction.class, targetFunction.getId())
                  .withSeverity(severity)
                  .withChildren(children)
      );

      // the realized function must actually accept what we were given
      List<SetPair> finalTypes = zipArgTypes(typeSet, givenTypes, realizationResult.get().getArgumentTypes());
      if (!all(finalTypes, SetPair::isAssignable)) {
        // todo vary the code depending on whether realization failed after type adaption
        return ResultOrProblems.failed(
            Problems.get(ArgsProblems.class).realizableDidNotMatch(targetFunction, givenTypes));
      }

      return realizationResult;
    }

    // not realizable or overloaded - use it directly
    RiskscapeFunction toUse = adaptFunction1(typeSet, targetFunction, givenTypes);

    if (toUse == null) {
      return ResultOrProblems.error(new MissingFunctionException(functionCall, givenTypes, targetFunction));
    } else {
      return ResultOrProblems.of(toUse);
    }
  }

  /**
   * Finds the {@link Type} that the given type would be mapped to by any number of {@link NullSafeFunction}
   * or {@link CoercingFunctionWrapper}s that may be wrapping the target function.
   */
  private Type getCoercedArgumentType(RiskscapeFunction function, int argumentIndex, Type givenType) {
    return function.isA(NullSafeFunction.class)
        .map(nullSafe -> {
          boolean stripNullable = nullSafe.getNotNullableIndices()[argumentIndex];
          Type coerced = getCoercedArgumentType(nullSafe.getTarget(), argumentIndex, givenType);
          if (stripNullable) {
            return Nullable.strip(coerced);
          }
          return coerced;
        })
        .orElseGet(() -> {
          return function.isA(CoercingFunctionWrapper.class)
            .map(coerced -> {
              // map to the produced type, or given if not coerced
              Type coercedType = coerced.getCoercers().get(argumentIndex)
                .map(Coercer::getTargetType)
                .orElse(givenType);
              // and check for any deeper mapping
              return getCoercedArgumentType(coerced.getWrapped(), argumentIndex, coercedType);
            })
            .orElse(givenType);
        });
  }

  /**
   * Reapplies wrapping functions recursively from around one function to another.  This method is here to support
   * mixing realizable functions with the adaptFunction routine.  One way to avoid this slightly kludgy method here
   * would be to get adaptFunction to return a description of the wrapping, rather than the wrapped function, but it
   * might be overkill.
   *
   * @param target the function that has already been wrapped
   * @param realized the function that needs to be wrapped in the same way as `target`
   * @param adapted the "outer layer" function.  This might be `target` or it might be a wrapping function
   * @return a function that wraps `realized`, or `realized` itself if `target` isn't wrapped
   * @throws AssertionError if `adapted` is a wrapper type this method doesn't know how to reapply
   */
  private RiskscapeFunction rewrap(
      RiskscapeFunction target,
      RiskscapeFunction realized,
      RiskscapeFunction adapted
  ) {
    if (adapted == target) {
      return realized;
    } else {
      if (adapted instanceof NullSafeFunction) {
        NullSafeFunction nsFunction = (NullSafeFunction) adapted;
        return NullSafeFunction.wrap(
            rewrap(target, realized, nsFunction.getTarget())
          );
      } else if (adapted instanceof CoercingFunctionWrapper) {
        CoercingFunctionWrapper cwFunction = (CoercingFunctionWrapper) adapted;
        return CoercingFunctionWrapper.wrap(
            rewrap(target, realized, cwFunction.getWrapped()),
            cwFunction.getCoercers()
        );
      } else {
        throw new AssertionError("wrong function type " + adapted.getClass());
      }
    }
  }

  /**
   * @return false if more arguments are given than the alternative declares
   */
  private boolean isViableAlternative(RiskscapeFunction alternative, List<Type> givenTypes) {
    // a function alternative may have optional args that can be omitted, but we
    // should never match extra given args against a function that takes fewer
    return givenTypes.size() <= alternative.getArgumentTypes().size();
  }

  /**
   * Searches through the list of riskscape functions, finding the one that best matches the given arguments.  Tries,
   * in order: an exact (equal variance) match, an assignable match, coercion-wrapped alternatives (recursively) and
   * finally null-safe-wrapped alternatives (recursively).
   *
   * @return the best matching (possibly wrapped) function, or null if none matched
   */
  private RiskscapeFunction searchForBestAlternative(
      TypeSet typeSet,
      List<RiskscapeFunction> alternatives,
      List<Type> givenTypes
  ) {
    List<Pair<RiskscapeFunction, List<SetPair>>> functionsAndPairedArgs = alternatives
        .stream()
        .filter(rf -> isViableAlternative(rf, givenTypes))
        .map(rf -> Pair.of(rf,  zipArgTypes(typeSet, givenTypes, rf.getArgumentTypes())))
        .collect(Collectors.toList());
    // find a perfect match
    log.debug("Searching for perfect match for types {}...", givenTypes);
    for (Pair<RiskscapeFunction, List<SetPair>> pair : functionsAndPairedArgs) {
      RiskscapeFunction alternative = pair.getLeft();
      List<SetPair> pairedArgs = pair.getRight();

      if (all(pairedArgs, sp -> sp.testVariance() == Variance.EQUAL)) {
        log.debug("...Found it: {}", alternative.getArgumentTypes());
        return alternative;
      }
    }

    // find a covariant match
    log.debug("Searching a covariant match for types {}...", givenTypes);
    for (Pair<RiskscapeFunction, List<SetPair>> pair : functionsAndPairedArgs) {
      RiskscapeFunction alternative = pair.getLeft();
      List<SetPair> pairedArgs = pair.getRight();

      if (all(pairedArgs, SetPair::isAssignable)) {
        log.debug("...Found it: {}", alternative.getArgumentTypes());
        return alternative;
      }
    }

    // find an equivalent match
    log.debug("Searching for an equivalent match for types {}...", givenTypes);
    List<RiskscapeFunction> coerced = new ArrayList<>(alternatives.size());
    for (Pair<RiskscapeFunction, List<SetPair>> pair : functionsAndPairedArgs) {
      RiskscapeFunction alternative = pair.getLeft();
      List<SetPair> pairedArgs = pair.getRight();

      List<Optional<Coercer>> adapted = pairedArgs
          .stream()
          .map(SetPair::findEquivalenceCoercer)
          .collect(Collectors.toList());

      if (any(adapted, Optional::isPresent)) {
        coerced.add(CoercingFunctionWrapper.wrap(alternative, adapted));
      }
    }

    if (coerced.size() > 0) {
      log.debug("{} alternatives had coercible equivalents, recursively checking these...", coerced.size());
      RiskscapeFunction found = searchForBestAlternative(typeSet, coerced, givenTypes);
      if (found != null) {
        return found;
      }
    }

    log.debug("... no equivalents could be found, last strategy is to try with nulls stripped from given...");

    List<RiskscapeFunction> nullSafes = new ArrayList<>(alternatives.size());
    for (Pair<RiskscapeFunction, List<SetPair>> pair : functionsAndPairedArgs) {
      RiskscapeFunction alternative = pair.getLeft();
      List<SetPair> pairedArgs = pair.getRight();

      if (any(pairedArgs, SetPair::isOnlyGivenNullable)) {
        // TODO we are marking some arguments as nullable, even though they don't need to be, which is
        // adding extra complexity in to the next round of adapting
        nullSafes.add(NullSafeFunction.wrap(alternative));
      }
    }

    if (nullSafes.size() > 0) {
      log.debug("...{} alternatives have possible null-safe version, trying these...", nullSafes.size());
      return searchForBestAlternative(typeSet, nullSafes, givenTypes);
    } else {
      log.debug("...No null safe alternatives found.");

      return null;
    }
  }

  /**
   * Adapts a single function to the given types.
   * @return the (possibly wrapped) function, or null if it can't be adapted
   */
  private RiskscapeFunction adaptFunction1(TypeSet typeSet, RiskscapeFunction rf, List<Type> givenTypes) {
    return searchForBestAlternative(typeSet, Collections.singletonList(rf), givenTypes);
  }

  /**
   * Pairs up given and receiver types positionally.  The result is as long as the longer of the two lists, with
   * {@link Types#NOTHING} standing in for the missing side of any unmatched position.
   */
  private List<SetPair> zipArgTypes(TypeSet typeSet, List<Type> givenTypes, List<Type> receiverTypes) {
    Iterator<Type> given = givenTypes.iterator();
    Iterator<Type> received = receiverTypes.iterator();

    // the loop continues while either side has types, so the result is max(given, receiver) long
    List<SetPair> deltas = new ArrayList<>(Math.max(givenTypes.size(), receiverTypes.size()));
    while (given.hasNext() || received.hasNext()) {
      deltas.add(SetPair.from(typeSet, given, received));
    }

    return deltas;
  }

  /**
   * Normalize the struct return type of the riskscape function so that it's more likely to play nicely with chains of
   * struct-manipulating expressions that expect structs to be normalized.
   *
   * NB the function itself is returned unchanged - this only logs when its struct return type isn't already the
   * normalized instance.
   */
  private RiskscapeFunction normalizeReturnType(RealizationContext context, FunctionCall fc, RiskscapeFunction rf) {
    Struct returnTypeStruct = rf.getReturnType().findAllowNull(Struct.class).orElse(null);

    if (returnTypeStruct != null) {
      Struct normalized = context.normalizeStruct(returnTypeStruct);
      // NB this is the best we can do without having a general purpose type visiting method for composite types.  It's
      // also just a guard against common errors/problems, and not supposed to be a fool-proof mechanism.  Note that the
      // check/warning here doesn't detect cases where a struct wasn't yet normalized but the call to normalize was able
      // to make the given struct the canonical one.
      if (normalized != returnTypeStruct) {
        log.info(
            "function {} ({}) returned a struct which could not be normalized.  RiskscapeFunction implementations"
            + " should normalize any struct return types to avoid possible struct member owner errors",
            fc.getIdentifier(),
            rf
        );
      }
    }

    return rf;
  }

}
