/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.interp;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.junit.Before;
import org.junit.Test;

import com.google.common.collect.Lists;

import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Matchers;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.FunctionArgument;
import nz.org.riskscape.engine.rl.BaseExpressionRealizerTest;
import nz.org.riskscape.engine.rl.DefaultOperators;
import nz.org.riskscape.engine.rl.LanguageFunctions;
import nz.org.riskscape.engine.rl.MathsFunctions;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.ScopedLambdaType;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.engine.util.RandomUtils;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.FunctionCall;

/**
 * Tests for the {@code create_continuous}, {@code apply_continuous} and {@code stack_continuous} functions.
 *
 * <p>Covers building continuous functions from a list of x-values plus a lambda, and applying them to
 * scalar, struct and nullable curves. It also checks the lazily-memoized y/m/c values and the problems
 * reported when the functions are given bad arguments.</p>
 */
public class ContinuousFunctionsTest extends BaseExpressionRealizerTest {

  ArgumentList applyArguments = new ApplyContinuousFunction().asFunction().getArguments();

  /**
   * A curve type that {@code stack_continuous} does not accept. It is used to check that only supported
   * curve types can be stacked.
   */
  static class UnsupportedCurve extends ContinuousFunctionType {
    UnsupportedCurve(int dimensions, RealizedExpression valueExpression, Type returnType) {
      super(dimensions, valueExpression, returnType);
    }
  }
  static final UnsupportedCurve UNSUPPORTED_CURVE_TYPE = new UnsupportedCurve(-99, null, Types.FLOATING);

  /**
   * Registers the continuous-function functions, plus the operators, maths and language functions that the
   * test expressions rely on.
   */
  @Before
  public void setup() {
    project.getFunctionSet().add(new ApplyContinuousFunction().asFunction().identified("apply_continuous"));
    project.getFunctionSet().add(new CreateContinuousFunction().asFunction().identified("create_continuous"));
    project.getFunctionSet().insertFirst(new DefaultOperators());
    project.getFunctionSet().addAll(MathsFunctions.FUNCTIONS);
    project.getFunctionSet().addAll(LanguageFunctions.FUNCTIONS);
  }

  @Test
  public void canBuildAndApplyABasicFunction() throws Exception {
    evaluate("create_continuous([0, 1, 2, 3, 5, 10], x1 -> x1 * 2)", tuple("{}"));
    assertThat(evaluated, isA(LinearContinuousFunction.class));
    LinearContinuousFunction function = (LinearContinuousFunction) evaluated;
    LinearContinuousFunctionType functionType = (LinearContinuousFunctionType) realized.getResultType();
    // it's all empty to start with
    for (int i = 0; i < functionType.getSize(); i++) {
      assertNull(functionType.getValue(function, i));
    }
    for (int i = 0; i < function.mValues.length; i++) {
      assertNull(function.mValues[i]);
      assertNull(function.cValues[i]);
    }

    // JUnit's contract is assertEquals(expected, actual) - keep that order so failure messages make sense
    assertEquals(6, functionType.getSize());
    // there's one less of these as they sit "between" the declared points on the x-axis
    assertEquals(5, function.mValues.length);
    assertEquals(5, function.cValues.length);

    assertThat(realized.getResultType(), isA(ContinuousFunctionType.class));

    assertThat(functionType.getReturnType(), equalTo(Types.FLOATING));
    assertThat(functionType.getXValues(), equalTo(new double[] {0, 1, 2, 3, 5, 10}));


    Tuple tuple = Tuple.ofValues(functionType.asStruct(), function);
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
    assertThat(realized.getResultType(), equalTo(Types.FLOATING));
    // the yvalues should now be populated, but only for the segment that was needed
    assertThat(functionType.getValue(function, 0), equalTo(0D));
    assertThat(functionType.getValue(function, 1), equalTo(2D));
    assertThat(functionType.getValue(function, 2), nullValue());
    assertThat(function.mValues[0], equalTo(2D));
    assertThat(function.cValues[0], equalTo(0D));


    // test that memoizing doesn't do anything weird
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
  }

  @Test
  public void canBuildAndApplyABasicStructFunction() throws Exception {
    evaluate("create_continuous([0, 1, 2, 3, 5, 10], x1 -> {a: x1 * 2, b: x1 * 3})", tuple("{}"));
    assertThat(evaluated, isA(ContinuousFunction.class));
    ContinuousFunction function = (ContinuousFunction) evaluated;

    Tuple tuple = Tuple.ofValues(realized.getResultType().asStruct(), function);
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(tuple("{a: 1.0, b: 1.5}")));
    assertThat(evaluate("apply_continuous(value, 1.0)", tuple), equalTo(tuple("{a: 2.0, b: 3.0}")));
    assertThat(realized.getResultType(), equalTo(Struct.of("a", Types.FLOATING, "b", Types.FLOATING)));
  }

  @Test
  public void canBuildAndApplyFromNullableCurve() throws Exception {
    // First we need to build a curve
    evaluate("create_continuous([0, 1, 2, 3, 5, 10], x1 -> x1 * 2)", tuple("{}"));
    assertThat(evaluated, isA(LinearContinuousFunction.class));
    LinearContinuousFunction function = (LinearContinuousFunction) evaluated;
    assertThat(realized.getResultType(), isA(ContinuousFunctionType.class));
    LinearContinuousFunctionType functionType = (LinearContinuousFunctionType) realized.getResultType();
    assertThat(functionType.getReturnType(), equalTo(Types.FLOATING));

    // now apply it using the nullable curve type - a null curve gives a null result
    Struct applyType = Struct.of("value", Nullable.of(functionType));
    Tuple nullApply = Tuple.ofValues(applyType);
    assertThat(evaluate("apply_continuous(value, 0.5)", nullApply), nullValue());
    assertThat(realized.getResultType(), equalTo(Nullable.of(Types.FLOATING)));

    Tuple functionApply = Tuple.ofValues(applyType, function);
    assertThat(evaluate("apply_continuous(value, 0.5)", functionApply), equalTo(1.0D));
    assertThat(realized.getResultType(), equalTo(Nullable.of(Types.FLOATING)));
  }

  @Test
  public void canBuildAndApplyABasicFunctionWithMixedNumberTypes() throws Exception {
    evaluate("create_continuous([0, 1.0, 2.0, 3, 5, 10], x1 -> x1 * 2)", tuple("{}"));
    assertThat(evaluated, isA(LinearContinuousFunction.class));
    LinearContinuousFunction function = (LinearContinuousFunction) evaluated;

    assertThat(realized.getResultType(), isA(ContinuousFunctionType.class));
    LinearContinuousFunctionType functionType = (LinearContinuousFunctionType) realized.getResultType();

    assertThat(functionType.getReturnType(), equalTo(Types.FLOATING));
    // integer and floating x-values are all coerced to doubles
    assertThat(functionType.getXValues(), equalTo(new double[] {0, 1, 2, 3, 5, 10}));


    Tuple tuple = Tuple.ofValues(functionType.asStruct(), function);
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
    // the yvalues should now be populated
    assertThat(functionType.getValue(function, 0), equalTo(0D));
    assertThat(functionType.getValue(function, 1), equalTo(2D));
    assertThat(functionType.getValue(function, 2), nullValue());
    assertThat(function.mValues[0], equalTo(2D));
    assertThat(function.cValues[0], equalTo(0D));


    // test that memoizing doesn't do anything weird
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
  }

  @Test
  public void createFailsIfTheLambdaFails() throws Exception {
    evaluate("create_continuous([0,1], x -> call_me(x))", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(Problems.foundWith("apply-to"))
    )));
  }

  @Test
  public void createFailsIfTheXValuesAreNotNumbers() throws Exception {
    evaluate("create_continuous(['a', 'b'], x -> 1)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(TypeProblems.get().mismatch("x-values", Number.class, String.class))
    )));
  }

  @Test
  public void createFailsIfTheYValuesCanNotBeAddedAndMultiplied() throws Exception {
    evaluate("create_continuous([1, 2, 3], x -> 'a')", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(
        CreateContinuousFunction.PROBLEMS.couldNotCreateContinuousFunctionFromYValue(parse("'a'"), Types.TEXT)
      )
    )));
  }

  @Test
  public void createFailsIfThereAreExtraArguments() throws Exception {
    evaluate("create_continuous([1, 2, 3], x -> x, false, 'bob')", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(
          ArgsProblems.get().wrongNumber(3, 4)
      )
    )));
  }

  @Test
  public void createFailsIfThereAreNotEnoughArguments() throws Exception {
    evaluate("create_continuous([1, 2, 3])", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(
          ArgsProblems.get().wrongNumber(2, 1)
      )
    )));
  }

  @Test
  public void createFailsIfLambdaHasWrongArguments() throws Exception {
    FunctionArgument applyToArg = new CreateContinuousFunction().asFunction().getArguments().get("apply-to");

    // too many
    evaluate("create_continuous([1, 2, 3], (x, y) -> x * x + 1)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(
          ArgsProblems.mismatch(applyToArg, new ScopedLambdaType(Struct.EMPTY_STRUCT, "x", "y"))
      )
    )));

    // not enough
    evaluate("create_continuous([1, 2, 3], () -> 1)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(
          ArgsProblems.mismatch(applyToArg, new ScopedLambdaType(Struct.EMPTY_STRUCT))
      )
    )));
  }

  @Test
  public void theLambdaExpressionClosesOverInputScope() throws Exception {
    Object function = evaluate("create_continuous([0, 1, 2, 3, 5, 10], x -> x * foo)", tuple("{foo: 2}"));
    Tuple tuple = Tuple.ofValues(realized.getResultType().asStruct(), function);
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.0D));
  }

  @Test
  public void multipleThreadsUsingTheMemoizedValuesDoNotAffectTheResults() throws Exception {
    Object func = evaluate("create_continuous([0, 1, 2, 3, 5, 10], x -> x * 2)", tuple("{}"));

    Struct inputType = Struct.of("function", realized.getResultType(), "xvalue", Types.FLOATING);
    realize(inputType, parse("apply_continuous(function, xvalue)"));
    // any exception/assertion error raised on a worker thread, paired with the x value that triggered it
    List<Pair<Throwable, Double>> failures = Collections.synchronizedList(new LinkedList<>());

    List<Thread> threads = IntStream.range(0, 4).mapToObj(threadIndex -> new Thread(() -> {
      // each thread has its own tuple, but they all share the same function instance
      Tuple tuple = new Tuple(inputType);
      tuple.set(0, func);
      for (int i = 0; i < 10000; i++) {
        double x = RandomUtils.getFromRange(0, 10);
        try {
          tuple.set(1, x);

          Object result = realized.evaluate(tuple);
          assertThat(result, equalTo(x * 2D));
        } catch (Throwable e) {
          // errors thrown off the test thread would otherwise be lost, so record them for checking below
          failures.add(Pair.of(e, x));
        }
      }
    })).collect(Collectors.toList());

    threads.forEach(Thread::start);

    // Wait for every worker before checking for failures. An interrupt propagates and fails the test,
    // rather than being swallowed and leaving us to assert on an incomplete list of failures.
    for (Thread thread : threads) {
      thread.join();
    }

    assertThat(failures, empty());
  }

  @Test
  public void canBuildTheFunctionWithLogScalingOnTheXAxis() throws Exception {
    Object function = evaluate("create_continuous([0.1, 1.0, 2.0, 3.0, 5.0, 10.0], x -> x * 2, true)", tuple("{}"));
    assertThat(evaluated, isA(ContinuousFunction.class));

    Tuple tuple = Tuple.ofValues(realized.getResultType().asStruct(), function);
    assertThat(evaluate("apply_continuous(value, 0.5)", tuple), equalTo(1.4581460078048338D));
  }

  @Test
  public void applyFailsIfXValueIsNotANumber() throws Exception {
    evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 'llama')", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
        Matchers.equalIgnoringChildren(ArgsProblems.mismatch(applyArguments.get(1), Types.TEXT))
    )));
  }

  @Test
  public void applyFailsIfXValueIsNullable() throws Exception {
    evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), null_of('floating'))", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
        Matchers.equalIgnoringChildren(ArgsProblems.mismatch(applyArguments.get(1), Nullable.FLOATING))
    )));
  }

  @Test
  public void applyFailsIfExtraArguments() throws Exception {
    evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 1.5, 'foo')", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(ArgsProblems.get().wrongNumber(2, 3))
    )));

    // and again with a numeric extra arg
    evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 1.5, 2.5)", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(ArgsProblems.get().wrongNumber(2, 3))
    )));
  }

  @Test
  public void applyFailsIfMissingArguments() throws Exception {
    evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1))", Tuple.EMPTY_TUPLE);
    assertThat(realizationProblems, contains(Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(ArgsProblems.get().wrongNumberRange(2, 3, 1))
    )));
  }

  @Test
  public void applyWithAnOutOfRangeXValueReturnsTheMinOrMax() throws Exception {
    // below the first x value, the y value of the first point is returned
    assertEquals(
      2D,
      evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 0)", Tuple.EMPTY_TUPLE)
    );

    // above the last x value, the y value of the last point is returned
    assertEquals(
      10D,
      evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 100)", Tuple.EMPTY_TUPLE)
    );

    // sanity-check *exactly* the last x value is handled correctly
    assertEquals(
        10D,
        evaluate("apply_continuous(create_continuous([1, 2, 3], x -> x * x + 1), 3)", Tuple.EMPTY_TUPLE)
    );
  }

  @Test
  public void functionsCanBeStackedViaAggregateFunction() throws Exception {
    StackContinuousFunction aggregateFunction = new StackContinuousFunction();
    Struct inputType = Struct.of("alpha", Types.FLOATING);

    FunctionCall fc = parse("stack_continuous(create_continuous([0.1, 0.5, 1.0], x -> x * alpha))")
        .isA(FunctionCall.class).get();


    ResultOrProblems<RealizedAggregateExpression> exprOr = aggregateFunction.realize(realizationContext, inputType, fc);
    assertThat(exprOr, Matchers.result(not(nullValue())));

    RealizedAggregateExpression agg = exprOr.get();
    Accumulator acc = agg.newAccumulator();
    acc.accumulate(tuple("{alpha: 0.1}"));
    acc.accumulate(tuple("{alpha: 0.2}"));
    acc.accumulate(tuple("{alpha: 0.3}"));

    LinearContinuousFunction function = (LinearContinuousFunction) acc.process();

    // stacked y = x * (0.1 + 0.2 + 0.3) = x * 0.6 at each of the declared x values
    assertThat(
      Arrays.asList((Object[]) function.values),
      Matchers.isCollectionOf(Double.class, contains(
          closeTo(0.06D, 0.00000001), closeTo(0.3D, 0.00000001), closeTo(0.6D, 0.00000001)
      ))
    );

    // for good measure, let's call the function on it and prove it still works
    evaluate("apply_continuous(function, 0.4)", Tuple.ofValues(Struct.of("function", agg.getResultType()), function));
    assertThat(evaluated, Matchers.instanceOfAnd(Double.class, closeTo(0.24D, 0.00001)));
  }

  @Test
  public void stackingFailsIfUnsupportedCurveType() throws Exception {
    // curve types have to override some methods to support stacking, so stack_continuous only lets
    // supported curves through the door.
    StackContinuousFunction aggregateFunction = new StackContinuousFunction();
    Struct inputType = Struct.of("curve", UNSUPPORTED_CURVE_TYPE);

    FunctionCall fc = parse("stack_continuous(curve)")
        .isA(FunctionCall.class).get();

    ResultOrProblems<RealizedAggregateExpression> exprOr = aggregateFunction.realize(realizationContext, inputType, fc);
    assertThat(exprOr, Matchers.failedResult(
      Matchers.hasAncestorProblem(equalTo(
        TypeProblems.get().requiresOneOf(
            "value",
            Lists.newArrayList(LinearContinuousFunctionType.ANY, BilinearContinuousFunctionType.ANY_BILINEAR),
            UNSUPPORTED_CURVE_TYPE
        )
      ))
    ));
  }

  @Test
  public void stackingFailsIfValueIsNotAFunction() throws Exception {
    StackContinuousFunction aggregateFunction = new StackContinuousFunction();
    Struct inputType = Struct.of("alpha", Types.FLOATING);

    FunctionCall fc = parse("stack_continuous([0.1, 0.5, 1.0])")
        .isA(FunctionCall.class).get();

    ResultOrProblems<RealizedAggregateExpression> exprOr = aggregateFunction.realize(realizationContext, inputType, fc);
    assertThat(exprOr, Matchers.failedResult(
      Matchers.hasAncestorProblem(equalTo(
        TypeProblems.get().mismatch("value", ContinuousFunctionType.ANY, RSList.create(Types.FLOATING))
      ))
    ));
  }

  @Test
  public void stackingFailsIfThereAreExtraArguments() throws Exception {
    StackContinuousFunction aggregateFunction = new StackContinuousFunction();
    Struct inputType = Struct.of("alpha", Types.FLOATING);

    FunctionCall fc = parse("stack_continuous(create_continuous([0.1, 0.5, 1.0], x -> x * alpha), 'foo')")
        .isA(FunctionCall.class).get();


    ResultOrProblems<RealizedAggregateExpression> exprOr = aggregateFunction.realize(realizationContext, inputType, fc);
    assertThat(exprOr.getAsSingleProblem(), Matchers.hasAncestorProblem(
      Matchers.equalIgnoringChildren(ArgsProblems.get().wrongNumber(1, 2))
    ));
  }

}
