/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.defaults.function;

import static nz.org.riskscape.engine.Assert.*;
import static nz.org.riskscape.engine.Matchers.*;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.math3.util.Precision;
import org.junit.Before;
import org.junit.Test;

import com.google.common.collect.Lists;

import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.rl.BaseExpressionRealizerTest;
import nz.org.riskscape.engine.rl.EvalException;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;

/**
 * Tests for {@link LossesByPeriod}, both by calling its helper methods directly and by evaluating
 * the {@code losses_by_period} RiskScape language function (in both the {@code ranked_closest} and
 * {@code percentile} modes).
 */
public class LossesByPeriodTest extends BaseExpressionRealizerTest {

  // the same instance is registered with the project in setup(), so that argument-mismatch
  // assertions below compare against the arguments of the function actually being realized
  final LossesByPeriod function = new LossesByPeriod();

  @Before
  public void setup() {
    project.getFunctionSet().add(function.asFunction().identified("losses_by_period"));
  }

  @Test
  public void canCalculateReturnPeriodsForEvents() {
    // the RP for event n is pretty straight-forward: investigation-time / n
    assertThat(function.calculateReturnPeriodsForEvents(10000, 4),
        is(new double[] {10000, 5000, 10000 / 3D, 2500}));
    assertThat(function.calculateReturnPeriodsForEvents(5000, 5),
        is(new double[] {5000, 2500, 5000 / 3D, 1250, 1000}));
    assertThat(function.calculateReturnPeriodsForEvents(1000, 10),
        is(new double[] {1000, 500, 1000 / 3D, 250, 200, 1000 / 6D, 1000 / 7D, 125, 1000 / 9D, 100}));

    // spot-check a larger event set
    double[] periods = function.calculateReturnPeriodsForEvents(10000, 20000);
    assertThat(Precision.round(periods[138], 8), is(71.94244604D));
    assertThat(Precision.round(periods[44], 8), is(222.22222222D));
    assertThat(Precision.round(periods[20], 8), is(476.19047619D));
  }

  @Test
  public void canFindLossClosestToReturnPeriod() {
    // each loss increments by 1 from 1 (lowest) to 10000 (highest)
    List<Integer> losses = IntStream.range(1, 10001).boxed()
        .sorted(Collections.reverseOrder())
        .collect(Collectors.toList());
    double[] periods = function.calculateReturnPeriodsForEvents(10000, losses.size());

    // RP=5000 == 2nd largest event exactly
    assertThat(function.findLossClosestToReturnPeriod(5000D, periods, losses), is(10001 - 2));
    // RP=2475 ~= 4th largest event (2500 year event)
    assertThat(function.findLossClosestToReturnPeriod(2475D, periods, losses), is(10001 - 4));
    // RP=2916 is still closest to 4th largest event rather than 3rd (3333.33)
    assertThat(function.findLossClosestToReturnPeriod(2916D, periods, losses), is(10001 - 4));
    // RP=100 == 100th largest event exactly
    assertThat(function.findLossClosestToReturnPeriod(100D, periods, losses), is(10001 - 100));
    // RP=72 ~= 139th largest event (71.9 year event)
    assertThat(function.findLossClosestToReturnPeriod(72D, periods, losses), is(10001 - 139));
    // RP=21.7 ~= 461st largest event (21.69 year event)
    assertThat(function.findLossClosestToReturnPeriod(21.7D, periods, losses), is(10001 - 461));

    // let's test the curving end of the distribution. we can also use these results to compare to the
    // results in canGetPercentileLossForReturnPeriod() below to see that both modes have a similar
    // shape to the distribution
    assertThat(function.findLossClosestToReturnPeriod(10D, periods, losses), is(9001));
    assertThat(function.findLossClosestToReturnPeriod(4D, periods, losses), is(7501));
    assertThat(function.findLossClosestToReturnPeriod(2D, periods, losses), is(5001));
    assertThat(function.findLossClosestToReturnPeriod(1.333D, periods, losses), is(2499));
  }

  @Test
  public void canGetPercentileLossForReturnPeriod() {
    // each loss increments by 1 from 1 (lowest) to 10000 (highest)
    List<Number> losses = IntStream.range(1, 10001).boxed()
        .sorted(Collections.reverseOrder())
        .collect(Collectors.toList());
    LossesByPeriod.LossFunction lossFunction = function.percentileBuilder().build(losses, 10000);

    // percentile = (1 - 1/rp) * 100 so
    // 100 -> 99th
    checkPercentileLoss(lossFunction, 100D, 9900.01);
    // 50 -> 98th
    checkPercentileLoss(lossFunction, 50D, 9800.02);
    // 20 -> 95th
    checkPercentileLoss(lossFunction, 20D, 9500.05);
    // 10 -> 90th
    checkPercentileLoss(lossFunction, 10D, 9000.1);
    // 4 -> 75th
    checkPercentileLoss(lossFunction, 4D, 7500.25);
    // 2 -> 50th
    checkPercentileLoss(lossFunction, 2D, 5000.5);
    // 1.333 ~> 25th
    checkPercentileLoss(lossFunction, 1.333D, 2498.874718);

    // and all these are up in the 99th plus range
    checkPercentileLoss(lossFunction, 10000D, 9999.0001);
    checkPercentileLoss(lossFunction, 5000D, 9998.0002);
    checkPercentileLoss(lossFunction, 2500D, 9996.0004);
    checkPercentileLoss(lossFunction, 2000D, 9995.0005);
    checkPercentileLoss(lossFunction, 2125D, 9995.294588);
    checkPercentileLoss(lossFunction, 2250D, 9995.556);
  }

  @Test
  public void percentileLossesTendsToZeroForSmallRP() {
    // let's try with only 100 losses this time
    List<Number> losses = IntStream.range(1, 101).boxed()
        .sorted(Collections.reverseOrder())
        .collect(Collectors.toList());
    LossesByPeriod.LossFunction lossFunction = function.percentileBuilder().build(losses, 100);

    // sanity check those bigger RP values
    checkPercentileLoss(lossFunction, 2500D, 99.9604D);
    checkPercentileLoss(lossFunction, 100D, 99.01D);

    // now let's see how those smaller RPs tend down to zero
    checkPercentileLoss(lossFunction, 5D, 80.2);
    checkPercentileLoss(lossFunction, 4D, 75.25D);
    checkPercentileLoss(lossFunction, 3D, 67.0);
    checkPercentileLoss(lossFunction, 2D, 50.5D);
    checkPercentileLoss(lossFunction, 1.5D, 34.0);
    checkPercentileLoss(lossFunction, 1.2D, 17.4999999);
    checkPercentileLoss(lossFunction, 1.1D, 10.0);
    checkPercentileLoss(lossFunction, 1D, 0D);
  }

  /**
   * Asserts that sampling the given loss function at return period {@code rp} yields a value within
   * 0.00001 of {@code expected}.
   */
  private void checkPercentileLoss(LossesByPeriod.LossFunction lossFunction, double rp, double expected) {
    assertThat(
        lossFunction.sample(rp).doubleValue(),
        closeTo(expected, 0.00001)
    );
  }

  @Test
  public void willReturnZeroIfReturnPeriodExceeded() {
    // let's try with only 100 losses this time
    List<Integer> losses = IntStream.range(1, 101).boxed()
        .sorted(Collections.reverseOrder())
        .collect(Collectors.toList());
    double[] periods = function.calculateReturnPeriodsForEvents(10000, losses.size());

    // RP=2475 ~= 4th largest event (2500 year event)
    assertThat(function.findLossClosestToReturnPeriod(2475D, periods, losses), is(101 - 4));
    // RP=100 == 100th largest event exactly
    assertThat(function.findLossClosestToReturnPeriod(100D, periods, losses), is(1));
    // RP=99.9 exceeds the events we have available (100)
    assertThat(function.findLossClosestToReturnPeriod(99.9D, periods, losses), is(0));
    // RP=72 ~= 139th largest event (not enough events)
    assertThat(function.findLossClosestToReturnPeriod(72D, periods, losses), is(0));
    // RP=21.7 ~= 461st largest event (not enough events)
    assertThat(function.findLossClosestToReturnPeriod(21.7D, periods, losses), is(0));
  }

  @Test
  public void canGetIntegerLossesForReturnPeriods() {
    // note we pass the losses in unsorted here
    List<Long> losses = IntStream.range(1, 101).boxed()
        .map(l -> l * 1000000L)
        .collect(Collectors.toList());
    Tuple input = Tuple.ofValues(Struct.of("losses", RSList.create(Types.INTEGER)), losses);

    // RPs == 4th, 10th, 100th, 200th events
    // which corresponds to losses $(101 - 4)m, $(101 - 10)m, $(101 - 100)m, N/A
    assertThat(evaluate("losses_by_period(losses, [2475, 1000, 100, 50], 10000, {mode: 'ranked_closest'})", input),
        is(Arrays.asList(97000000L, 91000000L, 1000000L, 0L)));
    assertThat(realized.getResultType(), is(RSList.create(Types.INTEGER)));

    // pass RPs as floats and mix args order
    // RPs == 21st, 45th, 139th events, so losses $(101 - 21)m, $(101 - 45)m, N/A
    assertThat(evaluate("losses_by_period(return-periods: [476.19, 222.22, 71.9], "
        + "investigation-time: 10000, losses: losses, options: {mode: 'ranked_closest'})", input),
        is(Arrays.asList(80000000L, 56000000L, 0L)));
    assertThat(realized.getResultType(), is(RSList.create(Types.INTEGER)));
  }

  @Test
  public void canGetFloatingLossesForReturnPeriods() {
    List<Double> losses = IntStream.range(1, 101).boxed()
        .map(l -> l * 1000000.01D)
        .collect(Collectors.toList());
    Tuple input = Tuple.ofValues(Struct.of("losses", RSList.create(Types.FLOATING)), losses);

    // same losses/RPs as last test, except now we expect cents on the end
    assertThat(evaluate("losses_by_period(losses, [2475, 1000, 100, 50], 10000, {mode: 'ranked_closest'})", input),
        is(Arrays.asList(97000000.97D, 91000000.91D, 1000000.01D, 0D)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));

    // same as last test but some of the return periods are floats
    assertThat(evaluate("losses_by_period(losses, [2475.0, 1000.0, 100, 50], 10000, {mode: 'ranked_closest'})", input),
        is(Arrays.asList(97000000.97D, 91000000.91D, 1000000.01D, 0D)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));

    // pass RPs as floats and mix args order
    assertThat(evaluate("losses_by_period(return-periods: [476.19, 222.22, 71.9], "
        + "investigation-time: 10000, losses: losses, options: {mode: 'ranked_closest'})", input),
        is(Arrays.asList(80000000.80D, 56000000.56D, 0D)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));
  }

  @Test
  public void canGetPercentileInterpolatedLosses() {
    List<Double> losses = IntStream.range(1, 101).boxed()
        .map(l -> l * 1000000.01D)
        .collect(Collectors.toList());
    Tuple input = Tuple.ofValues(Struct.of("losses", RSList.create(Types.FLOATING)), losses);

    // same losses/RPs as last test, but now floating point percentile interpolated results
    assertThat(evaluate("losses_by_period(losses, [2475, 1000, 100, 50], 10000, {mode: 'percentile'})", input),
        is(Arrays.asList(9.99600009996E7, 9.990100099901001E7, 9.90100009901E7, 9.80200009802E7)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));

    // percentile interpolation is the default
    assertThat(evaluate("losses_by_period(losses, [2475, 1000, 100, 50], 10000)", input),
        is(Arrays.asList(9.99600009996E7, 9.990100099901001E7, 9.90100009901E7, 9.80200009802E7)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));

    // pass args by name in a different order (and rely on the default mode)
    assertThat(evaluate("losses_by_period(return-periods: [2475, 1000, 100, 50], "
        + "investigation-time: 10000, losses: losses)", input),
        is(Arrays.asList(9.99600009996E7, 9.990100099901001E7, 9.90100009901E7, 9.80200009802E7)));
    assertThat(realized.getResultType(), is(RSList.create(Types.FLOATING)));
  }

  @Test
  public void mustHaveCorrectNumberOfArgs() {
    evaluate("losses_by_period()", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumberRange(3, 4, 0))
    )));
    evaluate("losses_by_period(33,33)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumberRange(3, 4, 2))
    )));
    evaluate("losses_by_period([1, 2, 3], [4, 5, 5], 1000, 1001, 1001)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumberRange(3, 4, 5))
    )));
  }

  @Test
  public void argsMustBeNumeric() {
    evaluate("losses_by_period(['foo'], [1000], 10000)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(function.getArguments().get(0), RSList.create(Types.TEXT)))
    )));

    // RPs is not a list
    evaluate("losses_by_period([1.0, 2.0], 1000, 10000)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(function.getArguments().get(1), Types.INTEGER))
    )));

    // investigation-time here is floating (we don't bother supporting that for now)
    evaluate("losses_by_period([1.0, 2.0], [100], 1000.0)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(function.getArguments().get(2), Types.FLOATING))
    )));
  }

  @Test
  public void failsWhenOptionsAreBad() {
    // bad type
    evaluate("losses_by_period([10], [10], 10000, {mode: 10})", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(GeneralProblems.get().notAnOption("10", LossesByPeriod.Mode.class,
            Lists.newArrayList(LossesByPeriod.Mode.class.getEnumConstants())))
    )));

    // unknown options
    evaluate("losses_by_period([10], [10], 10000, {foo: 10, bar: 20})", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(GeneralProblems.get().notAnOption("foo", "options", Lists.newArrayList("mode")))
    )));

    // not struct
    evaluate("losses_by_period([10], [10], 10000, 'foo')", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(TypeProblems.get().mismatch(parser.parse("'foo'"), function.getArguments().get("options").getType(),
            Types.TEXT))
    )));
  }

  @Test
  public void cannotRequestRPLargerThanInvestigationTime() {
    // if the investigation period is 10,000 years, you cannot ask for a one in 10,001 year event.
    // The error handling could be better here, but it's also a dumb thing for the user to do
    EvalException ex = assertThrows(EvalException.class,
        () -> evaluate("losses_by_period([1], [10001], 10000)", Tuple.EMPTY_TUPLE));
    assertThat(ex.getCause().getClass(), is(IllegalArgumentException.class));

    ex = assertThrows(EvalException.class,
        () -> evaluate("losses_by_period([1], [10001], 10000, {mode: 'ranked_closest'})", Tuple.EMPTY_TUPLE));
    assertThat(ex.getCause().getClass(), is(IllegalArgumentException.class));
  }
}
