/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.function;

import java.util.ArrayList;
import java.util.List;

import com.google.common.collect.Range;

import lombok.RequiredArgsConstructor;
import nz.org.riskscape.defaults.curves.CannotFitCurveException;
import nz.org.riskscape.defaults.curves.ObservedPoints;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.BaseRealizableFunction;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.engine.util.SegmentedList;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.FunctionCall;

/**
 * Calculates an AAL for hazard-based probabilistic data. We have a series of x,y datapoints,
 * representing the loss and EP (Exceedance Probability). The AAL is the area under the x,y "curve".
 * NB: we just connect the data-points via a straight-line (i.e. piecewise linear data), which
 * simplifies the AAL calculation.
 *
 * For the AAL calculation we use x=EP, y=loss. This seems to produce better results, because
 * the spacing between X datapoints is consistent. Whereas if we used x=loss, y=EP, if we get
 * duplicate losses then we end up with a zero AAL for the duplicates (i.e. there's *no* area
 * under the curve). Refer GL993 for more details.
 *
 * The function can be used in two ways:
 * <ul>
 * <li>as a scalar function, given a list of losses and a list of EPs of the same length</li>
 * <li>as an aggregate function, where each input row contributes one loss/EP data-point</li>
 * </ul>
 * In both cases every EP must lie strictly between 0 and 1 and must be unique, losses must not be
 * negative (a null loss is treated as zero), and at least {@link #MIN_NUM_EVENTS} data-points are
 * required. Invalid data results in a {@link CannotFitCurveException}.
 */
public class AALHazardBasedFunction extends BaseRealizableFunction implements AggregationFunction {

  /**
   * The full range a probability can take. Note that the EPs given to this function must lie
   * strictly inside this range, i.e. an EP of exactly 0 or 1 is rejected.
   */
  public static final Range<Double> PROBABILITY_RANGE = Range.closed(0D, 1D);

  /**
   * The EP values this function actually accepts. Used both for the check and for the problem
   * reported to the user, so the message matches what is enforced.
   */
  private static final Range<Double> VALID_EP_RANGE = Range.open(0D, 1D);

  /**
   * The minimum number of loss/EP data-points needed before an AAL will be calculated.
   */
  public static final int MIN_NUM_EVENTS = 3;

  public interface LocalProblems extends ProblemFactory {
    Problem duplicateProbability(Double value);

    Problem tooFewEvents(int got, int minimum);
  }

  public static final LocalProblems PROBLEMS = Problems.get(LocalProblems.class);

  // the function's declared arguments, as used by the aggregate version: one loss and one EP per row
  private static final ArgumentList ARGUMENTS = ArgumentList.create(
      "loss", Types.FLOATING,
      "ep", Types.FLOATING
  );

  // the arguments expected by the scalar version: parallel lists of losses and EPs
  static final ArgumentList SCALAR_ARGUMENTS = ArgumentList.create(
      "loss", RSList.create(Types.FLOATING),
      "ep", RSList.create(Types.FLOATING)
  );

  public AALHazardBasedFunction() {
    super(ARGUMENTS, Types.FLOATING);
  }

  /**
   * Realizes the scalar, list-based version of the function. The n-th loss in the first list is
   * paired with the n-th EP in the second list, so both lists must be the same length.
   */
  @Override
  public ResultOrProblems<RiskscapeFunction> realize(RealizationContext context, FunctionCall fc,
      List<Type> givenTypes) {
    return ProblemException.catching(() -> {
      if (givenTypes.size() != arguments.size()) {
        throw new ProblemException(ArgsProblems.get().wrongNumber(arguments.size(), givenTypes.size()));
      }
      RSList lossType = SCALAR_ARGUMENTS.getRequiredAs(givenTypes, 0, RSList.class).getOrThrow();
      RSList epType = SCALAR_ARGUMENTS.getRequiredAs(givenTypes, 1, RSList.class).getOrThrow();

      if (!lossType.getContainedType().isNumeric()) {
        throw new ProblemException(ArgsProblems.mismatch(SCALAR_ARGUMENTS.get(0), givenTypes.get(0)));
      }
      // nullable EPs are not allowed - ignoring a null EP would skew the AAL
      if (!epType.getContainedType().isNumeric() || epType.getContainedType().isNullable()) {
        throw new ProblemException(ArgsProblems.mismatch(SCALAR_ARGUMENTS.get(1), givenTypes.get(1)));
      }
      return RiskscapeFunction.create(this, givenTypes, Types.FLOATING, (args) -> {
        // the lists are only read, so there's no need to copy them
        @SuppressWarnings("unchecked")
        List<Number> losses = (List<Number>) args.get(0);
        @SuppressWarnings("unchecked")
        List<Number> eps = (List<Number>) args.get(1);

        if (losses.size() != eps.size()) {
          throw new CannotFitCurveException(Problems.foundWith(fc,
              GeneralProblems.get().differentListLengths(SCALAR_ARGUMENTS.get(0), losses.size(),
                  SCALAR_ARGUMENTS.get(1), eps.size())));
        }

        AALCalculator calculator = new AALCalculator(fc);
        for (int i = 0; i < losses.size(); i++) {
          calculator.addLossDataPoint(losses.get(i), eps.get(i));
        }
        return calculator.calculateAAL();
      });
    });
  }

  @Override
  public RiskscapeFunction asFunction() {
    return AggregationFunction.addAggregationTo(this, super.asFunction());
  }

  /**
   * Realizes the aggregate version of the function, where the `ep` and `loss` arguments are
   * expressions evaluated against each input row.
   */
  @Override
  public ResultOrProblems<RealizedAggregateExpression> realize(RealizationContext context, Type inputType,
      FunctionCall fc) {

    return ProblemException.catching(() -> {

      RealizedExpression xExpression = realizeArg(context, inputType, fc, "ep");
      RealizedExpression yExpression = realizeArg(context, inputType, fc, "loss");

      // there's no good reason why we should allow a nullable EP. Ignoring a null EP would skew the AAL
      if (xExpression.getResultType().isNullable()) {
        throw new ProblemException(
            TypeProblems.get().mismatch(arguments.get("ep"), Types.FLOATING,
                xExpression.getResultType())
        );
      }
      return RealizedAggregateExpression.create(inputType, Types.FLOATING, fc, () ->
          new AccumImpl(fc, xExpression, yExpression)
      );
    });
  }

  /**
   * Realizes the named argument of the function call against the input type, requiring that it
   * produces a numeric result.
   */
  private RealizedExpression realizeArg(
      RealizationContext context,
      Type inputType,
      FunctionCall fc,
      String argName
  ) throws ProblemException {
    RealizedExpression realized = context.getExpressionRealizer()
        .realize(inputType, arguments.getRequiredArgument(fc, argName).getOrThrow().getExpression())
        .getOrThrow(Problems.foundWith(arguments.get(argName)));

    if (!realized.getResultType().isNumeric()) {
      throw new ProblemException(
          TypeProblems.get().mismatch(arguments.get(argName), Types.FLOATING, realized.getResultType())
      );
    }

    return realized;
  }

  /**
   * Accumulator for the aggregate version. Each accumulated row contributes one loss/EP data-point.
   */
  private static final class AccumImpl extends AALCalculator implements Accumulator {

    final RealizedExpression xExpression;
    final RealizedExpression yExpression;

    AccumImpl(FunctionCall fc, RealizedExpression xExpression, RealizedExpression yExpression) {
      super(fc);
      this.xExpression = xExpression;
      this.yExpression = yExpression;
    }

    @Override
    public Accumulator combine(Accumulator rhs) {
      AccumImpl impl = (AccumImpl) rhs;
      // each side's data-points were validated as they were accumulated, but an EP present on both
      // sides is not caught here - calculateAAL() re-checks the merged data for duplicate EPs
      this.xValues.addAll(impl.xValues);
      this.yValues.addAll(impl.yValues);
      return this;
    }

    @Override
    public void accumulate(Object input) {
      Number probability = (Number) xExpression.evaluate(input);
      Number loss = (Number) yExpression.evaluate(input);
      addLossDataPoint(loss, probability);
    }

    @Override
    public Object process() {
      return calculateAAL();
    }

    @Override
    public boolean isEmpty() {
      return xValues.isEmpty();
    }
  }

  /**
   * Collects loss/EP data-points and calculates the AAL (area under the piecewise-linear EP curve).
   */
  @RequiredArgsConstructor
  private static class AALCalculator {
    // for context when reporting errors
    final FunctionCall fc;

    // EP (x) and loss (y) values - the n-th entry of each list together form one data-point
    List<Double> xValues = SegmentedList.forClass(Double.class);
    List<Double> yValues = SegmentedList.forClass(Double.class);

    /**
     * Validates and records a single data-point.
     *
     * @param loss the loss, which may be null (treated as zero) but must not be negative or NaN
     * @param probability the EP, which must be strictly between 0 and 1 and not already recorded
     * @throws CannotFitCurveException if the data-point is invalid
     */
    public void addLossDataPoint(Number loss, Number probability) {
      // NaN fails every numeric comparison, so it has to be rejected explicitly
      if (probability == null
          || Double.isNaN(probability.doubleValue())
          || !VALID_EP_RANGE.contains(probability.doubleValue())) {
        // help! someone is abusing this function
        // better to throw an error here than just ignore it and output a misleading AAL
        throw new CannotFitCurveException(Problems.foundWith(fc,
            GeneralProblems.get().badValue(probability, ARGUMENTS.get("ep"), VALID_EP_RANGE)));
      }
      double ep = probability.doubleValue();

      if (xValues.contains(ep)) {
        // it makes no sense to have the same probability multiple times. It's going to screw up the
        // AAL calc (i.e. what's the correct area under the curve when there are multiple Y values in the
        // same X location?). It may indicate something screwy is going on with the input/event data,
        // or that the user just forgot to aggregate by event first
        throw new CannotFitCurveException(Problems.foundWith(fc,
            PROBLEMS.duplicateProbability(ep)));
      }

      // the loss coming out of the Python function might be null, in which case
      // treat it as a zero loss. Ignoring it will impact the AAL because there's
      // suddenly less area under the curve to plot
      if (loss == null) {
        loss = 0;
      }

      // written as !(x >= 0) so that a NaN loss is rejected as well as a negative one
      if (!(loss.doubleValue() >= 0)) {
        // mixing negative and positive loss values could result in a misleading AAL
        throw new CannotFitCurveException(Problems.foundWith(fc,
            GeneralProblems.get().badValue(loss, ARGUMENTS.get("loss"), Range.atLeast(0))));
      }

      xValues.add(ep);
      yValues.add(loss.doubleValue());
    }

    /**
     * Calculates the area under the piecewise-linear curve through the given (EP, loss) pairs,
     * plus the inferred tail area between EP=0 and the smallest EP.
     *
     * @throws CannotFitCurveException if there are too few data-points or an EP is duplicated
     */
    private double calculateAAL(List<Pair<Double, Double>> datapoints) {
      if (datapoints.size() < MIN_NUM_EVENTS) {
        // not enough loss data to produce an AAL
        throw new CannotFitCurveException(
            Problems.foundWith(fc, PROBLEMS.tooFewEvents(datapoints.size(), MIN_NUM_EVENTS)));
      }

      List<Pair<Double, Double>> xyPairs = new ArrayList<>(datapoints);

      // the AAL calculation relies on the loss/EP datapoints being "plotted" in order.
      // We order by x-value, i.e. EP
      xyPairs.sort((a, b) -> a.getLeft().compareTo(b.getLeft()));

      // now calculate the AAL, i.e. the area between the datapoints
      double aal = 0D;
      for (int i = 0; i < xyPairs.size() - 1; i++) {
        double x1 = xyPairs.get(i).getLeft();
        double x2 = xyPairs.get(i + 1).getLeft();
        double y1 = xyPairs.get(i).getRight();
        double y2 = xyPairs.get(i + 1).getRight();

        if (x1 == x2) {
          // addLossDataPoint() rejects duplicates as they are added, but data-points merged in from
          // another accumulator via combine() have not been checked against each other. Once sorted,
          // any duplicate EPs end up next to each other
          throw new CannotFitCurveException(Problems.foundWith(fc, PROBLEMS.duplicateProbability(x1)));
        }

        // we're dealing with a straight-line here, so we can just take the average between
        // the 2 y-values and multiply that by the difference between the 2 x-values, and that
        // will give us the area under the "curve" between this data-point and the next
        aal += (x2 - x1) * ((y1 + y2) / 2);
      }

      // the event data won't extend all the way to EP=0, but we can infer the 'missing' area
      // under the curve between EP=0 and EP=minEP. An EP curve will always continue at the same
      // (loss) value or higher out to infinity RP/EP=0. So we assume there's an extra data-point
      // at EP=0 with the same loss as the minEP data-point (for a well-formed EP curve, that is the
      // maximum loss). The trapezoid calculation (minEP - 0) * (loss + loss) / 2 simplifies down to
      // just minEP * loss in this case. Refer to the docs for more on how the hazard-based AAL is
      // calculated
      double minEP = xyPairs.get(0).getLeft();
      double lossAtMinEP = xyPairs.get(0).getRight();
      aal += minEP * lossAtMinEP;

      return aal;
    }

    /**
     * Calculates the AAL from all the data-points recorded so far.
     */
    public double calculateAAL() {
      ObservedPoints points = new ObservedPoints(xValues, yValues);
      return calculateAAL(points.asListOfPairs());
    }
  }
}
