/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.rl;

import static nz.org.riskscape.engine.Assert.*;
import static nz.org.riskscape.engine.Matchers.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.apache.commons.math3.exception.OutOfRangeException;
import org.junit.Before;
import org.junit.Test;

import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Assert;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.IdentifiedFunction;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.types.CoercionException;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.types.WithinRange;
import nz.org.riskscape.problem.ProblemSink;
import nz.org.riskscape.rl.ast.ExpressionProblems;

public class MathsFunctionsTest extends BaseExpressionRealizerTest {

  Struct floating = Types.FLOATING.asStruct();
  Struct integer = Types.INTEGER.asStruct();


  @Before
  public void setup() {
    project.getFunctionSet().addAll(LogicFunctions.LOGIC_FUNCTIONS);
    project.getFunctionSet().insertFirst(DefaultOperators.INSTANCE);
    project.getFunctionSet().addAll(MathsFunctions.FUNCTIONS);
    project.getFunctionSet().addAll(LanguageFunctions.FUNCTIONS);
  }

  @Test
  public void min() throws Exception {
    assertEquals(1L, evaluate("min(1, 2)", null));
    assertEquals(1L, evaluate("min(2, 1)", null));
    assertEquals(1D, evaluate("min(1.0, 2.0)", null));
    assertEquals(1D, evaluate("min(2.0, 1.0)", null));
    assertEquals(1D, evaluate("min(1.0, 2)", null));
    assertEquals(1D, evaluate("min(1, 2.0)", null));
  }

  @Test
  public void minWorksAgainstWrappedTypes() throws Exception {
    Struct inputStruct = Struct.of("range", new WithinRange(Types.INTEGER, 0, 10),
        "nullable", Nullable.FLOATING,
        "null_range", Nullable.of(new WithinRange(Types.FLOATING, 0D, 100.0)),
        "phooey", Nullable.TEXT
        ).add("actual_null", Nullable.INTEGER);
    Tuple input = Tuple.ofValues(inputStruct, 1L, 2D, 3D, "4", null);
    assertEquals(1L, evaluate("min(range, 2)", input));
    assertEquals(2D, evaluate("min(3.0, nullable)", input));
    assertEquals(3D, evaluate("min(null_range, 4.0)", input));
    assertEquals(2D, evaluate("min(nullable, null_range)", input));
    assertEquals(1D, evaluate("min(range, null_range)", input));
    assertEquals(1D, evaluate("min(nullable, range)", input));

    // check it's still safe to call with an actual null value
    assertEquals(null, evaluate("min(actual_null, 1)", input));
    assertThat(realizationProblems, empty());

    // and check we don't unwrap non-numeric types
    IdentifiedFunction minFunc = project.getFunctionSet().get("min", ProblemSink.DEVNULL);
    assertEquals(null, evaluate("min(phooey, 1)", input));
    assertThat(realizationProblems, contains(hasAncestorProblem(is(ArgsProblems.get().realizableDidNotMatch(minFunc,
        Arrays.asList(Types.TEXT, Types.INTEGER))))));
  }

  @Test
  public void minListAggregation() throws Exception {
    assertExprEquals(1L, "min([1, 2])", Nullable.INTEGER);
    assertExprEquals(2L, "min([2])", Nullable.INTEGER);
    assertExprEquals(1L, "min([2, 1, 7, 9, 4])", Nullable.INTEGER);
    assertExprEquals(-12L, "min([-1, -12])", Nullable.INTEGER);
    assertExprEquals(-47.0D, "min([9.0, 1.456, -47.0, 7.9])", Nullable.FLOATING);
    assertExprEquals(7D, "min([square_root(49), 10.0, 20.0])", Nullable.FLOATING);
    assertExprEquals(-0.9D, "min([0.01, 0.002, -0.9, 0.1])", Nullable.FLOATING);

    // the mean of an empty list is null
    Tuple emptyList = Tuple.ofValues(Struct.of("value", RSList.create(Nullable.INTEGER)), Arrays.asList());
    assertExprEquals(null, "min(value)", emptyList, Nullable.INTEGER);

    // nullable elements are ignored
    assertExprEquals(1L, "min([1, 2, null_of('integer')])", Nullable.INTEGER);
    assertExprEquals(-1.0D, "min([null_of('floating'), -0.5, -1.0])", Nullable.FLOATING);
  }

  @Test
  public void minListAggregationOnSingleAttributeTuples() {
    Struct itemType = Struct.of("item", Types.INTEGER);
    Type listType = RSList.create(itemType);
    Struct inputType = Struct.of("items", listType);
    Tuple one = Tuple.ofValues(itemType, 1L);
    Tuple two = Tuple.ofValues(itemType, 2L);
    assertExprEquals(one, "min(items)", Tuple.ofValues(inputType, Arrays.asList(one, two)), Nullable.of(itemType));
  }

  @Test
  public void max() throws Exception {
    assertEquals(2L, evaluate("max(1, 2)", null));
    assertEquals(2L, evaluate("max(2, 1)", null));
    assertEquals(2D, evaluate("max(1.0, 2.0)", null));
    assertEquals(2D, evaluate("max(2.0, 1.0)", null));
    assertEquals(2D, evaluate("max(1.0, 2)", null));
    assertEquals(2D, evaluate("max(1, 2.0)", null));
  }

  @Test
  public void maxWorksAgainstWrappedTypes() throws Exception {
    Struct inputStruct = Struct.of("range", new WithinRange(Types.INTEGER, 0, 10),
        "nullable", Nullable.FLOATING,
        "null_range", Nullable.of(new WithinRange(Types.FLOATING, 0D, 100.0)),
        "phooey", Nullable.TEXT)
        .add("actual_null", Nullable.INTEGER);
    Tuple input = Tuple.ofValues(inputStruct, 1L, 2D, 3D, "4", null);
    assertEquals(1L, evaluate("max(range, 0)", input));
    assertEquals(Types.INTEGER, realized.getResultType());
    assertEquals(2D, evaluate("max(1.0, nullable)", input));
    assertEquals(Nullable.FLOATING, realized.getResultType());
    assertEquals(3D, evaluate("max(null_range, 2.0)", input));
    assertEquals(Nullable.FLOATING, realized.getResultType());
    assertEquals(3D, evaluate("max(nullable, null_range)", input));
    assertEquals(Nullable.FLOATING, realized.getResultType());
    assertEquals(3D, evaluate("max(range, null_range)", input));
    assertEquals(Nullable.FLOATING, realized.getResultType());
    assertEquals(2D, evaluate("max(nullable, range)", input));
    assertEquals(Nullable.FLOATING, realized.getResultType());

    // check it's still safe to call with an actual null value
    assertEquals(null, evaluate("max(actual_null, 1)", input));
    assertThat(realizationProblems, empty());

    // and check we don't unwrap non-numeric types
    IdentifiedFunction maxFunc = project.getFunctionSet().get("max", ProblemSink.DEVNULL);
    assertEquals(null, evaluate("max(phooey, 1)", input));
    assertThat(realizationProblems, contains(hasAncestorProblem(is(ArgsProblems.get().realizableDidNotMatch(maxFunc,
        Arrays.asList(Types.TEXT, Types.INTEGER))))));
  }

  @Test
  public void maxListAggregation() throws Exception {
    assertExprEquals(2L, "max([1, 2])", Nullable.INTEGER);
    assertExprEquals(2L, "max([2])", Nullable.INTEGER);
    assertExprEquals(9L, "max([2, 1, 7, 9, 4])", Nullable.INTEGER);
    assertExprEquals(-1L, "max([-1, -12])", Nullable.INTEGER);
    assertExprEquals(9D, "max([9.0, 1.456, -47.0, 7.9])", Nullable.FLOATING);
    assertExprEquals(7D, "max([square_root(49), -1.0, 2.0])", Nullable.FLOATING);
    assertExprEquals(0.1D, "max([0.01, 0.002, -0.9, 0.1])", Nullable.FLOATING);

    // the mean of an empty list is null
    Tuple emptyList = Tuple.ofValues(Struct.of("value", RSList.create(Nullable.INTEGER)), Arrays.asList());
    assertExprEquals(null, "max(value)", emptyList, Nullable.INTEGER);

    // nullable elements are ignored
    assertExprEquals(2L, "max([1, 2, null_of('integer')])", Nullable.INTEGER);
    assertExprEquals(-0.5D, "max([null_of('floating'), -0.5, -1.0])", Nullable.FLOATING);
  }

  @Test
  public void maxErrorCases() throws Exception {
    // check wrong number args
    evaluate("max()", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(is(ArgsProblems.get().wrongNumber(2, 0)))));
    evaluate("max(1, 2, 3)", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(is(ArgsProblems.get().wrongNumber(2, 3)))));

    // right number, but obviously wrong types
    IdentifiedFunction maxFunc = project.getFunctionSet().get("max");
    evaluate("max([1, 2], 3)", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().realizableDidNotMatch(maxFunc,
            Arrays.asList(RSList.create(Types.INTEGER), Types.INTEGER))))));
    evaluate("max(1, 'foo')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().realizableDidNotMatch(maxFunc,
            Arrays.asList(Types.INTEGER, Types.TEXT))))));
  }

  @Test
  public void maxOfMixedNumberTypes() {
    assertExprEquals(2.0D, "max([1, 2.0])", Nullable.FLOATING);
    assertExprEquals(4.0D, "max([2.0, 4])", Nullable.FLOATING);
  }

  @Test
  public void medianAggregation() {
    // smokes tests for median, more tests in MedianAggregationFunctionTest.java
    assertExprEquals(4.5D, "median([4, 1, 9, 5])", Nullable.FLOATING);
    assertThat(realizationProblems, is(empty()));

    assertExprEquals(5D, "median([7, 1, 5])", Nullable.FLOATING);
    assertThat(realizationProblems, is(empty()));
  }

  @Test
  public void absoluteValue() throws Exception {
    assertEquals(4L, evaluate("abs(4)", null));
    assertEquals(4L, evaluate("abs(-4)", null));
    assertEquals(Types.INTEGER, realized.getResultType());
    assertEquals(4.4D, evaluate("abs(4.4)", null));
    assertEquals(0.3D, evaluate("abs(-0.3)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
  }

  @Test
  public void round() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("round").getArguments();
    assertEquals(2L, evaluate("round(1.9)", null));
    assertEquals(2L, evaluate("round(2.1)", null));
    assertEquals(2L, evaluate("round(2.49)", null));
    assertEquals(3L, evaluate("round(2.5)", null));
    assertEquals(3L, evaluate("round(2.999)", null));
    assertEquals(Types.INTEGER, realized.getResultType());

    // round to given decimal places
    assertEquals(3.0D, evaluate("round(2.999, 2)", null));
    assertEquals(2.99D, evaluate("round(2.991, 2)", null));
    assertEquals(2.10001D, evaluate("round(2.100005, 5)", null));
    assertEquals(2.49D, evaluate("round(2.49, 4)", null));
    assertEquals(Types.FLOATING, realized.getResultType());

    // negative ndigits works like python
    assertEquals(100.0D, evaluate("round(111.11, -2)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    // rounding to zero works like python
    assertEquals(3.0D, evaluate("round(3.333, 0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());

    // handles nullable values
    Struct inputType = Struct.of("value", Nullable.FLOATING, "dp", Nullable.INTEGER);
    assertNull(evaluate("round(value, 3)", Tuple.ofValues(inputType)));
    assertEquals(Nullable.FLOATING, realized.getResultType());

    // ndigits must be constant
    assertThat(realizeOnly("round(12.3, dp)", inputType),
        failedResult(hasAncestorProblem(is(ExpressionProblems.get().constantRequired(expressionParser.parse("dp")))))
    );

    // will do int -> float coercion for arg 0
    assertEquals(2L, evaluate("round(2)", null));
    assertEquals(2.0D, evaluate("round(2, 2)", null));

    // ndigits must be integer
    assertThat(realizeOnly("round(12.3, 0.2)", Struct.EMPTY_STRUCT), failedResult(hasAncestorProblem(
        is(TypeProblems.get().mismatch(expectedArgs.get(1),
            Nullable.INTEGER, Types.FLOATING))
    )));

    // error on wrong number of arguments
    assertThat(realizeOnly("round()", Struct.EMPTY_STRUCT),
        failedResult(hasAncestorProblem(is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 0))))
    );
    assertThat(realizeOnly("round(12.3, 0.2, 'foo')", Struct.EMPTY_STRUCT),
        failedResult(hasAncestorProblem(is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 3))))
    );
  }

  @Test
  public void ceil() throws Exception {
    assertEquals(2L, evaluate("ceil(1.9)", null));
    assertEquals(3L, evaluate("ceil(2.1)", null));
    assertEquals(3L, evaluate("ceil(2.49)", null));
    assertEquals(3L, evaluate("ceil(2.5)", null));
    assertEquals(3L, evaluate("ceil(2.999)", null));
    assertEquals(Types.INTEGER, realized.getResultType());
  }

  @Test
  public void floor() throws Exception {
    assertEquals(1L, evaluate("floor(1.9)", null));
    assertEquals(2L, evaluate("floor(2.1)", null));
    assertEquals(2L, evaluate("floor(2.49)", null));
    assertEquals(2L, evaluate("floor(2.5)", null));
    assertEquals(2L, evaluate("floor(2.999)", null));
    assertEquals(Types.INTEGER, realized.getResultType());
  }

  @Test
  public void test_float() throws Exception {
    //from integer
    assertEquals(1.0D, evaluate("float(1)", null));
    assertEquals(2.0D, evaluate("float(2)", null));
    assertEquals(0D, evaluate("float(0)", null));
    assertEquals(-2D, evaluate("float(-2)", null));

    //from text
    assertEquals(1.0D, evaluate("float('1')", null));
    assertEquals(2.0D, evaluate("float('2')", null));
    assertEquals(0D, evaluate("float('0')", null));
    assertEquals(-2D, evaluate("float('-2')", null));
    assertEquals(Types.FLOATING, realized.getResultType());

    assertThat(assertThrows(EvalException.class, () -> evaluate("float('one')", null)).getCause(),
        instanceOf(CoercionException.class));
  }

  @Test
  public void test_int() throws Exception {
    //from float
    assertEquals(1L, evaluate("int(1.4)", null));
    assertEquals(2L, evaluate("int(1.5)", null));
    assertEquals(0L, evaluate("int(0.49)", null));
    assertEquals(-2L, evaluate("int(-2)", null));

    //from text
    assertEquals(1L, evaluate("int('1')", null));
    assertEquals(2L, evaluate("int('2')", null));
    assertEquals(0L, evaluate("int('0')", null));
    assertEquals(-2L, evaluate("int('-2')", null));
    assertEquals(1L, evaluate("int('1.4')", null));
    assertEquals(2L, evaluate("int('1.5')", null));
    assertEquals(0L, evaluate("int('0.49')", null));
    assertEquals(-2L, evaluate("int('-2')", null));
    assertEquals(Types.INTEGER, realized.getResultType());

    assertThat(assertThrows(EvalException.class, () -> evaluate("int('one')", null)).getCause(),
        instanceOf(CoercionException.class));
  }

  @Test
  public void power() throws Exception {
    assertEquals(1D, evaluate("pow(1, 2)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(1D, evaluate("pow(1.0, 2)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(1D, evaluate("pow(1, 2.0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(4D, evaluate("pow(2.0, 2.0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
  }

  @Test
  public void square_root() throws Exception {
    assertExprEquals(3D, "square_root(9)");
    assertExprEquals(7D, "square_root(49.0)");
    assertExprEquals(1D, "square_root(1)");
    assertExprEquals(1.7720D, "square_root(3.14)");
  }

  @Test
  public void log() throws Exception {
    assertEquals(0.6931471805599453D, evaluate("log(2)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(0.6931471805599453D, evaluate("log(2.0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(0.30102999566398114D, evaluate("log(2.0, 10)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(0.30102999566398114D, evaluate("log(2, 10)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(0.30102999566398114D, evaluate("log(2, 10.0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
  }

  @Test
  public void exp() throws Exception {
    // exp(1) is the same as the e constant
    assertExprEquals(Math.E, "exp(1)");
    assertExprEquals(20.08554, "exp(3)");
    assertExprEquals(2.0, "exp(0.69315)");
    // exp() is the inverse of log()
    assertExprEquals(10.0, "exp(log(10))");
    assertExprEquals(10.0, "log(exp(10))");
  }

  @Test
  public void log10() throws Exception {
    assertEquals(0.3010299956639812D, evaluate("log10(2)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(0.3010299956639812D, evaluate("log10(2.0)", null));
    assertEquals(Types.FLOATING, realized.getResultType());
  }

  private void assertExprEquals(Object expectedValue, String expression, Tuple input, Type expectedType) {
    assertEquals(expectedValue, evaluate(expression, input));
    assertEquals(Collections.emptyList(), realizationProblems);
    assertEquals(expectedType, realized.getResultType());
  }

  private void assertExprEquals(Object expectedValue, String expression, Type expectedType) {
    assertExprEquals(expectedValue, expression, Tuple.EMPTY_TUPLE, expectedType);
  }

  // Double we handle separately because we use a delta/tolerance parameter so as
  // to avoid floating-point precision issues when testing equality
  private void assertExprEquals(double expectedValue, String expression) {
    assertEquals(expectedValue, ((Double) evaluate(expression, null)).doubleValue(), 0.00001);
    assertEquals(Types.FLOATING, realized.getResultType());
    assertEquals(Collections.emptyList(), realizationProblems);
  }

  @Test
  public void normal_pdf() throws Exception {
    // python scipy was used to cross-check these results, i.e.
    // import scipy.stats
    // scipy.stats.norm(50, 1).pdf(50)
    assertExprEquals(0.39894D, "norm_pdf(50, 50, 1)");
    assertExprEquals(0.08567D, "norm_pdf(mean: 0, stddev: 1, x: -1.754)");
    assertExprEquals(0.09879D, "norm_pdf(4.13, 5.48, stddev: 3.79)");
  }

  @Test
  public void lognorm_pdf() throws Exception {
    // unlike the normal curve, log-normal is a right-skewed distribution (so mean
    // is to the right of the curve's peak). We can see this here if we take the PDF
    // the same distance either side of the mean - LHS has higher probability/curve
    // than the RHS
    assertExprEquals(0.33261, String.format("lognorm_pdf(%f, 1.0, 0.5)", Math.exp(0.75)));
    assertExprEquals(0.29353, String.format("lognorm_pdf(%f, 1.0, 0.5)", Math.exp(1.0)));
    assertExprEquals(0.20174, String.format("lognorm_pdf(%f, 1.0, 0.5)", Math.exp(1.25)));

    // Note that equivalent calculation in Python is:
    // scipy.stats.lognorm(s=shape, scale=exp(scale)).pdf(x)
    assertExprEquals(0.22056, "lognorm_pdf(scale: 0.4187, shape: 0.8594, x: 2.0)");
    assertExprEquals(0.41226, "lognorm_pdf(scale: 0.4187, shape: 0.8594, x: 1.0)");
  }

  @Test
  public void normal_cdf() throws Exception {
    // if we use a standard normal curve (mean=0, stddev=1), then the CDF is the
    // zscore, which you can get from any standard zscore table. Or can use python:
    // scipy.stats.norm(0, 1).cdf(0)
    assertExprEquals(0.54379D, "norm_cdf(0.11, 0, 1)");
    assertExprEquals(0.04947D, "norm_cdf(mean: 0.0, stddev: 1.0, x: -1.65)");

    // shifting mean/stddev by a proportionate amount should give the same results
    assertExprEquals(0.54379D, "norm_cdf(200.33, 200, 3)");
    assertExprEquals(0.04947D, "norm_cdf(mean: -50, stddev: 2.0, x: -53.3)");

    // when x == mean, result should always be 0.5
    assertExprEquals(0.5D, "norm_cdf(2.149, 2.149, 1.894)");

    assertExprEquals(0.65542D, "norm_cdf(0.0, -.2, 0.5)");
  }

  @Test
  public void normal_ppf() throws Exception {
    // PPF is the inverse of the CDF. So if we plug in the answers from the CDF test
    // case, then we should get the 'x' value back
    assertExprEquals(0.11D, "norm_ppf(0.543795, 0, 1)");
    assertExprEquals(-1.65D, "norm_ppf(mean: 0.0, stddev: 1.0, x: 0.049471)");

    // shifting mean/stddev by a proportionate amount should give the same results
    assertExprEquals(200.33D, "norm_ppf(0.543795, 200, 3)");
    assertExprEquals(-53.3D, "norm_ppf(mean: -50, stddev: 2.0, x: 0.049471)");

    // 0.5 should always return the mean
    assertExprEquals(2.149D, "norm_ppf(0.5, 2.149, 1.894)");

    assertExprEquals(0.0D, "norm_ppf(0.65542, -.2, 0.5)");

    // should get an error if we use a percentage rather than a probability
    RiskscapeException ex = Assert.assertThrows(RiskscapeException.class,
        () -> evaluate("norm_ppf(50, 0, 1)", null));
    assertThat(ex.getProblem(), isProblem(OutOfRangeException.class));
  }

  @Test
  public void lognormal_cdf() throws Exception {
    // these examples are the same as norm_cdf() - log() has already been applied to
    // the mean (scale) and stddev (shape), we just need to use the exp() of x.
    // Note that python equivalent is:
    // scipy.stats.lognorm(s=shape, scale=exp(scale)).cdf(x)
    assertExprEquals(0.54379D, String.format("lognorm_cdf(%f, 0, 1)", Math.exp(0.11)));
    assertExprEquals(0.04947D, String.format("lognorm_cdf(scale: 0.0, shape: 1.0, x: %f)", Math.exp(-1.65)));

    // when exp(x) == mean, result should always be 0.5
    assertExprEquals(0.5D, String.format("lognorm_cdf(%f, 2.149, 1.894)", Math.exp(2.149)));

    assertExprEquals(0.65542D, String.format("lognorm_cdf(%f, -.2, 0.5)", Math.exp(0.0)));
  }

  @Test
  public void lognormal_ppf() throws Exception {
    // PPF is the inverse of the CDF. So if we plug in the answers from the CDF test
    // case, then we should get the 'x' value back
    assertExprEquals(Math.exp(0.11), String.format("lognorm_ppf(0.543795, 0, 1)"));
    assertExprEquals(Math.exp(-1.65), String.format("lognorm_ppf(scale: 0.0, shape: 1.0, x: 0.049471)"));

    // when exp(x) == mean, result should always be 0.5
    assertExprEquals(Math.exp(2.149), String.format("lognorm_ppf(0.5, 2.149, 1.894)"));

    assertExprEquals(Math.exp(0.0), String.format("lognorm_ppf(0.65542, -.2, 0.5)"));

    // should get an error if we use a percentage rather than a probability
    RiskscapeException ex = Assert.assertThrows(RiskscapeException.class,
        () -> evaluate("lognorm_ppf(50, 0, 1)", null));
    assertThat(ex.getProblem(), isProblem(OutOfRangeException.class));
  }

  @Test
  public void logNormalVsNormalDistribution() throws Exception {
    // In a log normal distribution, the log of the values are normally distributed.
    // This test is mostly to reassure myself that exp(norm_ppf(prob, log(mean), log(stddev)))
    // is equivalent to lognorm_ppf(prob, log(mean), log(stddev))
    assertExprEquals(0.11D, "norm_ppf(0.543795, 0, 1)");
    assertExprEquals(Math.exp(0.11), String.format("lognorm_ppf(0.543795, 0, 1)"));

    // the result here should be 2 x stddev + 1.0, so approx exp(5.0)
    assertExprEquals(148.13951D, String.format("exp(norm_ppf(0.9772, log(%f), log(%f)))",
        Math.exp(1.0D), Math.exp(2.0D)));
    assertExprEquals(148.13951D, String.format("lognorm_ppf(0.9772, 1.0, 2.0)"));
  }

  @Test
  public void polynomial() throws Exception {
    // 4x**2 + 3
    assertEquals(7.0D, evaluate("polynomial(1.0, [4.0, 3.0])", null));
    assertEquals(10.0D, evaluate("polynomial(c: [4.0, 3.0], x: 2.0)", null));
    // x**3 - 2.5x + 7
    assertEquals(26.5D, evaluate("polynomial(c: [7.0, -2.5, 0.0, 1.0], x: 3.0)", null));
  }

  @Test
  public void mean() throws Exception {
    // note that mean always returns floating, even though input list is integer
    assertExprEquals(0D, "mean([-1, 0, 1])", Nullable.FLOATING);
    assertExprEquals(2D, "mean([1.0, 2.0, 3.0])", Nullable.FLOATING);

    // the mean of an empty list is null
    Tuple emptyList = Tuple.ofValues(Struct.of("value", RSList.create(Nullable.INTEGER)), Arrays.asList());
    assertExprEquals(null, "mean(value)", emptyList, Nullable.FLOATING);

    // null elements get omitted from the mean
    assertExprEquals(1.5D, "mean([1, 2, null_of('integer')])", Nullable.FLOATING);
  }

  @Test
  public void meanOfStruct() throws Exception {
    Struct expected = Struct.of("a", Types.FLOATING);
    // mean of two struct
    assertExprEquals(Tuple.ofValues(expected, 1.5D),
        "mean([{a: 1}, {a: 2}])",
        Nullable.of(expected));

    // bit more complicated, as are taking the mean of means
    assertExprEquals(Tuple.ofValues(expected, 2.5D),
        "mean([ mean([{a: 1}, {a: 2}]), mean([{a: 3}, {a: 4}]) ])",
        Nullable.of(expected));
  }

  @Test
  public void meanOfStructList() throws Exception {
    // this test is checking that mean aggregation will work when given a list of struct items.
    // note that the itemTypes struct is not normalized which is in line with bookmarks (they don't
    // normalize the produced structs either.
    Struct itemType = Struct.of("value", Types.INTEGER);
    Struct inputType = Struct.of("a", RSList.create(itemType));
    List<Tuple> values = Arrays.asList(Tuple.ofValues(itemType, 10L), Tuple.ofValues(itemType, 20L));

    Struct expectedType = Struct.of("value", Types.FLOATING);
    assertThat(evaluate("mean(a)", Tuple.ofValues(inputType, values)), is(Tuple.ofValues(expectedType, 15D)));
    assertThat(realized.getResultType(), is(Nullable.of(expectedType)));
  }

  @Test
  public void count() {
    // note that count uses identity, so always returns a non-nullable result
    assertExprEquals(3L, "count([-1, 0, 1])", Types.INTEGER);
    // count is not a maths operation so works with any element type, but null values are not counted
    assertExprEquals(3L, "count([-1, 0, 1.0, null_of('integer')])", Types.INTEGER);
    assertExprEquals(3L, "count(['test', 1, {}, null_of('list(integer)')])", Types.INTEGER);
    assertExprEquals(0L, "count([])", Types.INTEGER);

    // for expressions that yield a boolean result, only true values should be counted
    assertExprEquals(3L, "count([true, false, true, true])", Types.INTEGER);
    assertExprEquals(1L, "count([true, false, false, false])", Types.INTEGER);
    assertExprEquals(0L, "count([false, false, false])", Types.INTEGER);
  }

  @Test
  public void sum() {
    Struct struct = Struct.of("value", RSList.create(Nullable.INTEGER));
    assertExprEquals(null, "sum(value)", Tuple.ofValues(struct, Arrays.asList()), Nullable.INTEGER);
    assertExprEquals(0L, "sum(value)", Tuple.ofValues(struct, Arrays.asList(-1L, 0L, 1L)), Nullable.INTEGER);
    assertExprEquals(6L, "sum(value)", Tuple.ofValues(struct, Arrays.asList(1L, 2L, 3L)), Nullable.INTEGER);
    assertExprEquals(5D, "sum([1.5, 3.5])", Nullable.FLOATING);

    // null elements get omitted from the sum
    assertExprEquals(3L, "sum([1, 2, null_of('integer')])", Nullable.INTEGER);
  }

  @Test
  public void canScaleSimpleTypes() {
    assertExprEquals(1D, "scale(2, 0.5)", Types.FLOATING);
    assertExprEquals(4D, "scale(2, 2)", Types.FLOATING);

    assertExprEquals(1D, "scale(scale_factor: 0.5, input_value: 2)", Types.FLOATING);

    assertExprEquals(4D, "scale(10, 0.4)", Types.FLOATING);
    assertExprEquals(1.25D, "scale(2.5, 0.5)", Types.FLOATING);

    Struct inputType = Struct.of("value", Nullable.INTEGER, "sf", Types.FLOATING);
    assertExprEquals(3D, "scale(value, sf)", Tuple.ofValues(inputType, 12L, 0.25D), Nullable.FLOATING);
    assertExprEquals(null, "scale(value, sf)", Tuple.ofValues(inputType, null, 0.3D), Nullable.FLOATING);

    inputType = Struct.of("value", Types.INTEGER, "sf", Nullable.FLOATING);
    assertExprEquals(3D, "scale(value, sf)", Tuple.ofValues(inputType, 12L, 0.25D), Nullable.FLOATING);
    assertExprEquals(null, "scale(value, sf)", Tuple.ofValues(inputType, 12L, null), Nullable.FLOATING);

    RiskscapeFunction scaleFunction = project.getFunctionSet().get("scale");
    realize(Struct.EMPTY_STRUCT, parse("scale(10, 'cat')"));
    assertNull(realized);
    assertThat(realizationProblems, contains(
        hasAncestorProblem(is(ArgsProblems.mismatch(scaleFunction.getArguments().get(1), Types.TEXT)))
    ));
  }

  @Test
  public void canScaleComplexTypes() {
    Struct resultType = Struct.of("a", Types.TEXT, "b", Types.FLOATING, "c", Types.FLOATING);
    assertExprEquals(Tuple.ofValues(resultType, "cat", 1D, 3D), "scale({a: 'cat', b: 2, c: 6}, 0.5)", resultType);
    assertExprEquals(Tuple.ofValues(resultType, "cat", 1D, 3D), "scale({a: 'cat', b: 2.0, c: 6.0}, 0.5)", resultType);

    // does not scale nested structs
    Struct nestedStruct = Struct.of("x", Types.INTEGER);
    resultType = Struct.of("a", Types.FLOATING, "b", nestedStruct);
    assertExprEquals(Tuple.ofValues(resultType, 5D, Tuple.ofValues(nestedStruct, 3L)),
        "scale({a: 10, b: {x: 3}}, 0.5)", resultType);
  }

  @Test
  public void canScaleComplexTypesWithNullableMembers() {
    Struct complexType = Struct.of("x", Nullable.INTEGER, "foo", Nullable.TEXT, "y", Nullable.FLOATING);
    Struct inputType = Struct.of("value", complexType);

    Struct resultType = Struct.of("x", Nullable.FLOATING, "foo", Nullable.TEXT, "y", Nullable.FLOATING);

    assertExprEquals(Tuple.ofValues(resultType, 3D, "dog", 9D), "scale(value, 0.5)",
        Tuple.ofValues(inputType, Tuple.ofValues(complexType, 6L, "dog", 18L)), resultType);

    assertExprEquals(Tuple.ofValues(resultType, null, "dog", 9D), "scale(value, 0.5)",
        Tuple.ofValues(inputType, Tuple.ofValues(complexType, null, "dog", 18L)), resultType);
  }

  @Test
  public void cannotScaleNonNumericInputs() {
    RiskscapeFunction scaleFunction = project.getFunctionSet().get("scale");
    // you can't scale a cat
    realize(Struct.EMPTY_STRUCT, parse("scale('cat', 0.5)"));
    assertNull(realized);
    assertThat(realizationProblems, contains(
        hasAncestorProblem(is(ArgsProblems.get().notNumeric(scaleFunction.getArguments().get(0), Types.TEXT)))
    ));

    // and you can't scale a struct with no numeric members
    Struct complexType = Struct.of("x", Types.TEXT, "y", Types.TEXT);
    Struct inputType = Struct.of("value", complexType);
    realize(inputType, parse("scale(value, 0.5)"));
    assertNull(realized);
    assertThat(realizationProblems, contains(
        hasAncestorProblem(is(ArgsProblems.get().notNumeric(scaleFunction.getArguments().get(0), complexType)))
    ));

    // even if the struct has nested structs that are scalable
    complexType = Struct.of("x", Types.TEXT, "y", Types.TEXT, "z", Struct.of("a", Types.INTEGER));
    inputType = Struct.of("value", complexType);
    realize(inputType, parse("scale(value, 0.5)"));
    assertNull(realized);
    assertThat(realizationProblems, contains(
        hasAncestorProblem(is(ArgsProblems.get().notNumeric(scaleFunction.getArguments().get(0), complexType)))
    ));
  }


  @Test
  public void canTestNumbersAgainstMagicConstants() throws Exception {
    // integer comparisons
    assertThat(evaluate("0 < maxint()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxint() > 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxint() = maxint()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0 < maxfloat()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxfloat() > 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxfloat() = maxfloat()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0 > minint()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minint() < 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minint() = minint()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0 > minfloat()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minfloat() < 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minfloat() = minfloat()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0 < inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("inf() > 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("inf() = inf()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0 > negative_inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("negative_inf() < 0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("negative_inf() = negative_inf()", Tuple.EMPTY_TUPLE), is(true));

    // floating comparisons
    assertThat(evaluate("0.0 < maxint()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxint() > 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxint() = maxint()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0.0 < maxfloat()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxfloat() > 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxfloat() = maxfloat()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0.0 > minint()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minint() < 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minint() = minint()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0.0 > minfloat()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minfloat() < 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minfloat() = minfloat()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0.0 < inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("inf() > 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("inf() = inf()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("0.0 > negative_inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("negative_inf() < 0.0", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("negative_inf() = negative_inf()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("maxfloat() < inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("maxint() < inf()", Tuple.EMPTY_TUPLE), is(true));

    assertThat(evaluate("minfloat() > negative_inf()", Tuple.EMPTY_TUPLE), is(true));
    assertThat(evaluate("minint() > negative_inf()", Tuple.EMPTY_TUPLE), is(true));
  }

}
