/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.function;

import static nz.org.riskscape.rl.ExpressionParser.*;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.eqrule.Coercer;
import nz.org.riskscape.engine.typeset.TypeSet;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.Lambda;

/**
 * Build an {@link AggregationFunction} based on a set of riskscape expressions that reduce the input into an
 * accumulator value, combine accumulated values (to support parallelism) and then emit a collected value.
 *
 * The expressions are supplied via {@link #builder()}:
 * <ul>
 * <li>map - maps each aggregated value in to a value that can be reduced (defaults to an identity mapping)</li>
 * <li>reduce - reduces two mapped values in to one (must be supplied)</li>
 * <li>process - maps the final reduced value in to the result (defaults to an identity mapping)</li>
 * <li>identity - optional constant returned when there would otherwise be no result</li>
 * </ul>
 */
@RequiredArgsConstructor @EqualsAndHashCode
public class ExpressionAggregationFunction implements AggregationFunction {

  /**
   * An expression that always returns its input argument (like
   * {@link java.util.function.Function#identity()}, etc)
   */
  private static final Lambda IDENTITY = ExpressionParser.parseString("v -> v").isA(Lambda.class).get();

  public interface LocalProblems extends ProblemFactory {

    static LocalProblems get() {
      return Problems.get(LocalProblems.class);
    }

    /**
     * Returned when one of the function's expressions (identified by whichOne, e.g. "map" or "identity") could not
     * be realized
     */
    Problem couldNotRealizeExpression(String whichOne, Expression expression);

    /**
     * Returned when an expression can't be realized with the given type
     */
    Problem typeNotSupportedForThisFunction(Type valueType);
  }

  /**
   * Wee builder to decrease chances of bugs due to mis-assigned constructor args.
   */
  public static class Builder {
    private Expression identityExpression;
    private Lambda mapExpression;
    private Lambda reduceExpression;
    private Lambda processExpression;


    /**
     * Supply a single arg lambda expression that maps an input value in to an accumulated value that can be reduced.
     * If a bare (non-lambda) expression is given, it is treated as the body of a lambda with a single `value`
     * argument.  If omitted, an identity expression is used, i.e. no mapping occurs.  For a `mean` function, we can
     * use a mapping expression of `value -> {count: 1, total: value}`
     */
    public Builder map(String expr) {
      mapExpression = parseLambda(expr, "value");
      return this;
    }

    /**
     * Supply a binary arg lambda expression that reduces two mapped values in to one.  This must be supplied.
     * If a bare (non-lambda) expression is given, it is treated as the body of a lambda with `lhs` and `rhs`
     * arguments, so the expression `lhs + rhs` would be the reduce expression for a sum function.  A mean reduction
     * can be expressed like `(l, r) -> {count: l.count + r.count, total: l.total + r.total}`
     */
    public Builder reduce(String expr) {
      reduceExpression = parseLambda(expr, "lhs", "rhs");
      return this;
    }

    /**
     * Supply a single arg lambda expression that maps an accumulated value in to a return value.  If a bare
     * (non-lambda) expression is given, it is treated as the body of a lambda with a single `value` argument.  If
     * omitted, an identity expression is used, i.e. the accumulated value is returned.
     *
     * For the mean example, we would process the accumulated value with `value -> value.total / value.count`
     */
    public Builder process(String expr) {
      processExpression = parseLambda(expr, "value");
      return this;
    }

    /**
     * Supply an identity expression for the function.  This expression should yield a constant value for when the
     * aggregate function would otherwise yield null, e.g. on an empty set.  This value will be coerced to the return
     * type of the function if required, e.g. floating to integer.
     */
    public Builder identity(String expr) {
      identityExpression = parseString(expr);
      return this;
    }

    /**
     * @return a new function built from the expressions supplied so far.  Expressions that were never supplied are
     * passed through as null.
     */
    public ExpressionAggregationFunction build() {
      return new ExpressionAggregationFunction(
          identityExpression,
          mapExpression,
          reduceExpression,
          processExpression
      );
    }

    /**
     * Parses expr as a lambda.  If expr does not parse as a lambda, it is treated as the body of a lambda whose
     * arguments are defaultArgs (joined with ", "), e.g. `lhs + rhs` becomes `lhs, rhs -> lhs + rhs`.
     */
    private Lambda parseLambda(String expr, String... defaultArgs) {
      return parseString(expr)
          .isA(Lambda.class)
          .orElseGet(() -> {
            String argList = Arrays.stream(defaultArgs).collect(Collectors.joining(", "));
            return parseString(argList + " -> " + expr).isA(Lambda.class).get();
          });
    }
  }

  /**
   * Construct a new Builder to build an ExpressionAggregationFunction
   * @return new Builder
   */
  public static Builder builder() {
    return new Builder();
  }

  private final Expression identityExpression;
  private final Lambda mapExpression;
  private final Lambda reduceExpression;
  private final Lambda processExpression;

  /**
   * @return true if an identity expression has been given.  Note that it's possible for this to realize to a nullable
   * type, but it's probably a user error - not sure we want that to happen
   */
  public boolean hasIdentity() {
    return identityExpression != null;
  }

  /**
   * Realizes this function against the given input type.  The single function call argument is realized to extract
   * the value being aggregated; the map, reduce and process lambdas are then realized against that value's type.
   * If an identity expression was given, it is evaluated once here and coerced to the process result type if needed.
   *
   * @return a failed result if the call does not have exactly one argument, or if any part fails to realize
   */
  @Override
  public ResultOrProblems<RealizedAggregateExpression> realize(
      RealizationContext context,
      Type inputType,
      FunctionCall fc
  ) {

    // arity check
    if (fc.getArguments().size() != 1) {
      return ResultOrProblems.failed(ArgsProblems.get().wrongNumber(1, fc.getArguments().size()));
    }

    return ProblemException.catching(() -> {

      // for extracting the value to be aggregated from our input - this is very, very unlikely to fail unless we
      // defer realizing the arg expressions during expression realization
      RealizedExpression valueExpressionR =
          context.getExpressionRealizer().realize(inputType, fc.getArguments().get(0).getExpression())
          .getOrThrow(Problems.foundWith(fc.getArguments().get(0)));

      Type valueType = valueExpressionR.getResultType();

      RiskscapeFunction mapFunction =
          toFunction(context, valueType, "map", mapExpression == null ? IDENTITY : mapExpression, valueType);
      Type mapType = mapFunction.getReturnType();

      // NOTE(review): the reduce result is fed back into reduce (and into process) as if it were of mapType, but
      // nothing here checks reduceFunction.getReturnType() against mapType
      RiskscapeFunction reduceFunction =
          toFunction(context, valueType, "reduce", reduceExpression, mapType, mapType);

      RiskscapeFunction processFunction =
          toFunction(context, valueType, "process", processExpression == null ? IDENTITY : processExpression, mapType);

      Object identityValue;
      Type resultType;
      // identity value is optional, but when given it needs to be checked against the result type and potentially
      // coerced
      if (identityExpression != null) {
        RealizedExpression identityRexpr =
          context.getExpressionRealizer().realize(Struct.EMPTY_STRUCT, identityExpression)
            .getOrThrow(LocalProblems.get().couldNotRealizeExpression("identity", identityExpression));

        identityValue = identityRexpr.evaluate(Tuple.EMPTY_TUPLE);

        TypeSet typeSet = context.getProject().getTypeSet();
        Type identityType = identityRexpr.getResultType();
        Type processType = processFunction.getReturnType();

        if (!typeSet.isAssignable(identityType, processType)) {
          Coercer coercer = typeSet
              .findEquivalenceCoercer(identityType, processType)
              .orElseThrow(() -> new ProblemException(
                  LocalProblems.get().typeNotSupportedForThisFunction(valueType).withChildren(
                      Problems.foundWith(identityExpression,
                          TypeProblems.get().couldNotCoerce(identityType, processType))
                  )
              ));

          identityValue = coercer.apply(identityValue);
        }
        // identity value means we won't ever be null
        resultType = Nullable.strip(processType);
      } else {
        resultType = processFunction.getReturnType();
        identityValue = null;
      }

      Object finalIdentityValue = identityValue;
      return RealizedAggregateExpression.create(inputType, resultType, fc,
          () -> new AccumInstance(valueExpressionR, mapFunction, reduceFunction, processFunction, finalIdentityValue));
    });
  }

  /**
   * Builds a riskscape function from the given lambda expression
   * @param part the part of the function being realized - used for problem reporting
   * @param inputType the type of the value being aggregated - used for problem reporting
   * @param argTypes the types of the lambda's arguments, in order - must match the lambda's arity
   * @throws ProblemException if the arity doesn't match or the lambda's body can not be realized
   */
  private RiskscapeFunction toFunction(
      RealizationContext context,
      Type inputType,
      String part,
      Lambda lambda,
      Type... argTypes
  ) throws ProblemException {
    Struct argsType = context.normalizeStruct(buildStruct(lambda, argTypes));

    RealizedExpression realized =
        context.getExpressionRealizer().realize(argsType, lambda.getExpression())
        .getOrThrow(ps -> LocalProblems.get().typeNotSupportedForThisFunction(inputType).withChildren(
          LocalProblems.get().couldNotRealizeExpression(part, lambda).withChildren(ps)
        ));

    return new RiskscapeFunction() {

      @Override
      public Type getReturnType() {
        return realized.getResultType();
      }

      @Override
      public List<Type> getArgumentTypes() {
        return Arrays.asList(argTypes);
      }

      @Override
      public Object call(List<Object> args) {
        // positional args are copied in to a tuple whose members are named after the lambda's arguments
        Tuple input = new Tuple(argsType);
        input.setAll(args);

        return realized.evaluate(input);
      }
    };
  }

  /**
   * Constructs a scope type for the lambda function, with one member per lambda argument, named after the argument
   * and typed with the corresponding entry in argTypes
   * TODO move this to LambdaType?
   * @throws ProblemException if the number of argTypes does not match the lambda's arity
   */
  private Struct buildStruct(Lambda lambda, Type... argTypes) throws ProblemException {
    if (argTypes.length != lambda.getArguments().size()) {
      throw new ProblemException(ExpressionProblems.get()
          .lambdaArityError(lambda, lambda.getArguments().size(), argTypes.length));
    }

    Struct struct = Struct.of();
    for (int i = 0; i < argTypes.length; i++) {
      struct = struct.add(lambda.getArguments().get(i).getValue(), argTypes[i]);
    }

    return struct;
  }

  @Override
  public String toString() {
    return String.format(
        "ExpressionAggregationFunction(identity=%s, map=%s, reduce=%s, process=%s)",
        identityExpression == null ? "None" : identityExpression.toSource(),
        mapExpression == null ? "None" : mapExpression.toSource(),
        reduceExpression == null ? "None" : reduceExpression.toSource(),
        processExpression == null ? "None" : processExpression.toSource()
    );
  }

  /**
   * Accumulator that maps each non-null input value and reduces it in to a running accumulated value, falling back to
   * the identity value (if any) when nothing was accumulated or processing yields null.
   */
  private static class AccumInstance implements Accumulator {

    private final RealizedExpression valueExpression;
    private final RiskscapeFunction map;
    private final RiskscapeFunction reduce;
    private final RiskscapeFunction process;
    private final Object identity;
    private Object accumulated;

    @Getter
    private boolean empty;

    AccumInstance(RealizedExpression valueExpression, RiskscapeFunction map,
        RiskscapeFunction reduce, RiskscapeFunction process, Object identity
    ) {
      this.valueExpression = valueExpression;
      this.map = map;
      this.reduce = reduce;
      this.process = process;
      this.identity = identity;
      // we aren't empty if there's an identity value - process will return something
      this.empty = identity == null;
    }

    /**
     * Combines this accumulator with another one (assumed to be an AccumInstance from the same function).  May return
     * this, other, or a new accumulator holding the reduction of both accumulated values.
     */
    @Override
    public Accumulator combine(Accumulator other) {
      AccumInstance rhs = (AccumInstance) other;

      if (rhs.accumulated == null) {
        // nothing to reduce from rhs, but if we've never seen any input and rhs has (e.g. only nulls), prefer rhs so
        // the combined result isn't wrongly reported as empty
        return this.empty ? rhs : this;
      } else if (this.accumulated == null) {
        return rhs;
      } else {
        // carry over our identity value - it is what process() falls back to
        AccumInstance combined = new AccumInstance(valueExpression, map, reduce, process, identity);
        combined.accumulated = reduce.call(Arrays.asList(this.accumulated, rhs.accumulated));
        // both sides have accumulated something, so the combination has definitely seen input
        combined.empty = false;
        return combined;
      }
    }

    @Override
    public void accumulate(Object input) {
      Object value = valueExpression.evaluate(input);
      // even a null mapping is not empty
      empty = false;
      // null value, skip it
      // NB is there some way we can signal that the mapping can accept a null input?
      if (value == null) {
        return;
      }

      Object mapped = map.call(Arrays.asList(value));

      if (accumulated == null) {
        accumulated = mapped;
      } else {
        if (mapped != null) {
          // NB is there some way we can signal that reduction can accept a null input?
          accumulated = reduce.call(Arrays.asList(accumulated, mapped));
        }
      }
    }

    /**
     * @return the processed accumulated value, or the identity value (possibly null) if nothing was accumulated or
     * processing yielded null
     */
    @Override
    public Object process() {
      if (accumulated == null) {
        return identity;
      } else {
        Object processed = process.call(Arrays.asList(accumulated));
        if (processed == null) {
          return identity;
        } else {
          return processed;
        }
      }
    }
  }

}
