/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.function;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import com.google.common.collect.Sets;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.FunctionArgument;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.FunctionCall.Argument;
import nz.org.riskscape.rl.ast.StructDeclaration;
import nz.org.riskscape.util.ListUtils;

/**
 * Aggregate function for applying the given `select` aggregate expression to rows whose picked
 * value falls between a range of user-defined numbers. This function is essentially a wrapper
 * around the {@link BucketFunction}, but with a simpler, terse argument syntax that is optimized
 * for bucketing according to values within a range.
 */
@RequiredArgsConstructor
public class BucketRange implements AggregationFunction {

  /**
   * The aggregate function that the rewritten function call gets realized against.
   */
  private final AggregationFunction bucketFunction;

  // these two sets mirror what's returned from maxint etc - if these values are in the range,
  // they are given a simpler/human-readable label of '+' (upper limit) or '<' (lower limit)
  private static final Set<Object> MAX_VALUES = Sets.newHashSet(
      Long.MAX_VALUE,
      Double.POSITIVE_INFINITY,
      Double.MAX_VALUE
  );
  private static final Set<Object> MIN_VALUES = Sets.newHashSet(
      Long.MIN_VALUE,
      Double.NEGATIVE_INFINITY,
      -Double.MAX_VALUE
  );

  // lower/upper limits that get added to an integer range when the add_bounds option is on
  private static final List<?> INTEGER_BOUNDS = Arrays.asList(Long.MIN_VALUE, Long.MAX_VALUE);

  // lower/upper limits that get added to any other (non-integer) numeric range when add_bounds is on
  private static final List<?> FLOATING_BOUNDS = Arrays.asList(-Double.MAX_VALUE, Double.MAX_VALUE);

  @Getter
  private final ArgumentList arguments = ArgumentList.fromArray(
    new FunctionArgument("pick", Types.ANYTHING),
    new FunctionArgument("select", Types.ANYTHING),
    new FunctionArgument("range", RSList.create(Types.ANYTHING)),
    new FunctionArgument("options", Nullable.ANYTHING)
  );

  /**
   * Maps the given args into a {@link BucketFunction} call by turning the `pick` attribute into a
   * lambda expression, and converting the `range` list into a `bucket` struct expression with
   * start/end members for each range band. The `select` argument is passed through unchanged.
   *
   * @param context used to realize and evaluate the argument expressions
   * @param inputType the type that the `pick` expression is realized against
   * @param fc the function call, as given by the user
   * @return the result of realizing the wrapped bucket function against the rewritten function call, or the
   * problems that stopped this from happening
   */
  @Override
  public ResultOrProblems<RealizedAggregateExpression> realize(RealizationContext context, Type inputType,
      FunctionCall fc) {

    return ProblemException.catching(() -> {

      Expression rewrittenPickExpression = rewritePick(context, inputType, fc);

      // bounds get added unless the user has explicitly turned them off
      boolean addBounds = getOption(context, fc, "add_bounds", Boolean.class).orElse(true);
      Expression bucketsExpression =
          createBucketExpression(context, arguments.getRequiredArgument(fc, "range").getOrThrow(), addBounds);

      FunctionCall newFunctionCall = new FunctionCall(fc.getIdentifier(), Arrays.asList(
          new Argument(rewrittenPickExpression, "pick"),
          new Argument(arguments.getRequiredArgument(fc, "select").getOrThrow().getExpression(), "select"),
          new Argument(bucketsExpression, "buckets")
      ));

      return bucketFunction.realize(context, inputType, newFunctionCall).getOrThrow();
    });
  }

  /**
   * Reads a single value out of the optional `options` argument.
   *
   * @param context used to evaluate the `options` argument as a constant
   * @param fc the function call that may contain the `options` argument
   * @param key the name of the option to look up
   * @param expected the java type the option's value must have
   * @return the option's value, or empty if `options` wasn't given or has no value for the key
   * @throws ProblemException if `options` could not be evaluated, or the value is not of the expected type
   */
  private <T> Optional<T> getOption(RealizationContext context, FunctionCall fc, String key, Class<T> expected)
      throws ProblemException {
    // turn the optional options argument into a tuple. NB a failure to evaluate the argument gets reported to the
    // user (rather than silently falling back to the defaults), the same as is done for the range argument
    Map<String, Object> options = arguments.getArgument(fc, "options")
        .map(arg -> arg.evaluateConstant(context, Tuple.class, Struct.EMPTY_STRUCT))
        .orElse(ResultOrProblems.of(Tuple.EMPTY_TUPLE))
        .getOrThrow(Problems.foundWith(arguments.get("options")))
        .toMap();

    // read the desired attribute and check it's in the correct type
    Object value = options.get(key);
    if (value == null) {
      return Optional.empty();
    }
    if (!expected.isInstance(value)) {
      throw new ProblemException(Problems.foundWith(arguments.get("options"),
          TypeProblems.get().mismatch(value, expected, value.getClass())));
    }
    return Optional.of(expected.cast(value));
  }

  /**
   * Turns a range value back in to an expression by parsing its string form.
   */
  private Expression asExpression(Object obj) {
    // NB we should only ever be turning numbers back into an Expression here, so this should be fine
    return ExpressionParser.INSTANCE.parse(obj.toString());
  }

  /**
   * The range argument is used to define the set of buckets used to group rows in the aggregation.
   * Buckets are formed from the sorted list elements, such that `n` elements gives `n - 1` buckets. For example,
   * the three elements `[0, 1, 2]` would yield a bucket struct like:
   * `{range_0_1: {start: 0, end: 1}, range_1_2: {start: 1, end: 2}}`
   *
   * When `addBounds` is set, the lower and upper limits for the list's member type are added to the list first
   * (unless they are already in it), which adds buckets on either end of the range. These limits are labelled
   * `<` and `+`, e.g. `range_<_0` and `range_2_+`.
   *
   * @param context used to evaluate and realize the range argument
   * @param argument the `range` argument from the function call
   * @param addBounds whether the lower/upper limits should be added to the range
   * @return a struct declaration with one member per bucket
   * @throws ProblemException if the argument isn't a constant list, is too short, or isn't a list of numbers
   */
  private Expression createBucketExpression(RealizationContext context, Argument argument, boolean addBounds)
      throws ProblemException {
    // this validates we were given a list
    List<?> values = argument.evaluateConstant(context, List.class, RSList.LIST_ANYTHING)
        .getOrThrow(Problems.foundWith(arguments.get("range")));

    // if the user hasn't asked for the bounds to be omitted, then they just need to give us one value - this would be
    // 'friendly' mode so doesn't make sense to let them create a range that covers *everything*.  If they've asked for
    // the bounds to be omitted then we trust they know what they're doing and all we need to do is construct at least
    // a single valid bucket
    final int minValues = addBounds ? 1 : 2;
    if (values.size() < minValues) {
      throw new ProblemException(
          Problems.foundWith(arguments.get("range"),
              GeneralProblems.get().badListLength(minValues + "+", values.size())
      ));
    }

    // list items need to all be the same type (numeric)
    RealizedExpression realized = context.getExpressionRealizer().realizeConstant(argument.getExpression())
        .getOrThrow();
    Type memberType = realized.getResultType().find(RSList.class).map(list -> list.getMemberType())
        .orElse(Types.ANYTHING);

    if (!memberType.isNumeric()) {
      throw new ProblemException(Problems.foundWith(arguments.get("range"),
          TypeProblems.get().listItemMismatch(Number.class, memberType)));
    }

    // automatically add min/max bounds, so any values outside the given range don't get missed
    if (addBounds) {
      for (Object limit : memberType == Types.INTEGER ? INTEGER_BOUNDS : FLOATING_BOUNDS) {
        if (!values.contains(limit)) {
          values = ListUtils.append(values, limit);
        }
      }
    }

    // make sure ranges are in increasing order, otherwise the results won't really make sense. NB we always sort a
    // copy - the evaluated list could be immutable in some cases (see #1398), and even when it is mutable it isn't
    // ours to reorder
    List<Object> sorted = new ArrayList<>(values);
    sorted.sort((lhs, rhs) -> ((Comparable) lhs).compareTo(rhs));

    List<StructDeclaration.Member> bucketMembers = new ArrayList<>(sorted.size() - 1);

    // turn the range list items into struct members that define the start/end of each range band
    for (int i = 0; i < sorted.size() - 1; i++) {
      Object start = sorted.get(i);
      Object end = sorted.get(i + 1);

      String label = "range_" + valueLabel(start) +  "_" + valueLabel(end);
      StructDeclaration rangeExpression = new StructDeclaration(
          Arrays.asList(
              StructDeclaration.jsonStyleMember("start", asExpression(start)),
              StructDeclaration.jsonStyleMember("end", asExpression(end))
          ),
          Optional.empty()
      );

      bucketMembers.add(StructDeclaration.jsonStyleMember(
          label,
          rangeExpression
      ));
    }

    return new StructDeclaration(bucketMembers, Optional.empty());
  }

  /**
   * Gives the text used for a range value within a bucket's name - `+` for any of the {@link #MAX_VALUES}, `<` for
   * any of the {@link #MIN_VALUES}, otherwise the value's string form with any dots replaced by underscores.
   */
  private String valueLabel(Object toLabel) {
    if (MAX_VALUES.contains(toLabel)) {
      return "+";
    }
    if (MIN_VALUES.contains(toLabel)) {
      return "<";
    }

    return toLabel.toString().replace('.', '_'); // Dots are problematic in attribute names
  }

  /**
   * Takes the given `pick` argument and turns it into a lambda expression that can be used with the
   * {@link BucketFunction}, e.g. `bucket -> (pick) >= bucket.start && (pick) < bucket.end`.
   *
   * @param context used to realize the pick expression
   * @param inputType the type that the pick expression is realized against
   * @param fc the function call containing the `pick` argument
   * @return the lambda expression to use in place of the original pick expression
   * @throws ProblemException if the pick expression can not be realized, or does not produce a numeric type
   */
  private Expression rewritePick(
      RealizationContext context,
      Type inputType,
      FunctionCall fc
    ) throws ProblemException {

    FunctionArgument pickArgument = arguments.get("pick");

    // we check that the pick expression realizes against the input type - we'll leave it to the bucket range function
    // to do further validation, but this should pick up 80% of type style errors without obfuscating the error (in the
    // way that this function will do)
    RealizedExpression realized = pickArgument.mapFunctionCall(fc,
        arg -> context.getExpressionRealizer().realize(inputType, arg.getExpression())
      ).getOrThrow();

    // we're taking the pick argument and turning it into an >= expression behind the scenes.
    // This won't be immediately obvious to the user, and is going to throw up a convoluted error
    // if pick is anything other than a numeric type (i.e. Text, Struct, etc)
    if (!realized.getResultType().isNumeric()) {
      throw new ProblemException(
          ArgsProblems.get().notNumeric(pickArgument, realized.getResultType()));
    }

    return ExpressionParser.parseString(String.format(
        "bucket -> (%s) >= bucket.start && (%s) < bucket.end",
        realized.getExpression().toSource(), realized.getExpression().toSource()
    ));
  }

}
