/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.rl;

import static nz.org.riskscape.engine.Matchers.*;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.List;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import com.google.common.collect.Lists;

import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Assert;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.test.RetryRule;

/**
 * Tests for the random functions available in the RiskScape expression language:
 * {@code random_choice}, {@code random_uniform} and {@code random_norm}.
 *
 * <p>Many assertions here are statistical: they check that a sampled value falls within a
 * plausible range, so a sufficiently unlucky sample can make a test fail.
 * NOTE(review): the {@link RetryRule} is presumably there to re-run such tests rather than
 * report a spurious failure - confirm against RetryRule's implementation.
 */
public class RandomFunctionsTest extends BaseExpressionRealizerTest {

  // NOTE(review): presumably retries a failed test to tolerate an occasional unlucky random sample
  @Rule
  public RetryRule retry = new RetryRule();

  /**
   * Registers the function sets the expressions under test rely on (logic, maths and language
   * functions), with the default operators inserted ahead of everything else.
   */
  @Before
  public void setup() {
    project.getFunctionSet().addAll(LogicFunctions.LOGIC_FUNCTIONS);
    project.getFunctionSet().insertFirst(DefaultOperators.INSTANCE);
    project.getFunctionSet().addAll(MathsFunctions.FUNCTIONS);
    project.getFunctionSet().addAll(LanguageFunctions.FUNCTIONS);
  }

  /**
   * An unweighted {@code random_choice} returns one of the given items, and its result type is
   * the list's member type.
   */
  @Test
  public void randomChoice() throws Exception {
    // just check we can call the function and get an item back
    List<String> items = Lists.newArrayList("a", "b", "c");
    Object result = evaluate("random_choice(['a', 'b', 'c'])", null);
    assertTrue(items.contains(result));
    assertEquals(Types.TEXT, realized.getResultType());

    // items can also be given by keyword; integer items give an integer result
    List<Long> numbers = Lists.newArrayList(1L, 2L, 3L);
    result = evaluate("random_choice(items: [1, 2, 3])", null);
    assertTrue(numbers.contains(result));
    assertEquals(Types.INTEGER, realized.getResultType());
  }

  /**
   * A weighted {@code random_choice} honours its weights. It returns null for an empty item
   * list and rejects item/weight lists of different lengths.
   */
  @Test
  public void weightedRandomChoice() throws Exception {
    List<String> items = Lists.newArrayList("a", "b", "c");
    // weights can be given positionally or by keyword, in any order
    Object result = evaluate("random_choice(['a', 'b', 'c'], [0.5, 0.3, 0.2])", null);
    assertTrue(items.contains(result));
    result = evaluate("random_choice(weights: [0.5, 0.3, 0.2], items: ['a', 'b', 'c'])", null);
    assertTrue(items.contains(result));

    // now let's force its hand and make it give us the first item
    result = evaluate("random_choice(weights: [1.0, 0.0, 0.0], items: ['a', 'b', 'c'])", null);
    assertEquals("a", result);

    // sanity-check corner-case: choosing from an empty list gives null rather than an error
    assertNull(evaluate("random_choice([])", null));

    // also check it rejects a mismatch in list size. NB the expression must be syntactically
    // complete, otherwise the exception would come from the parser rather than random_choice
    Assert.assertThrows(RiskscapeException.class,
        () -> evaluate("random_choice(['a', 'b', 'c'], [0.5, 0.5])", null));
  }

  /**
   * A given seed makes {@code random_choice} repeatable, and different seeds give different
   * picks. This is checked for both the unweighted and weighted forms.
   */
  @Test
  public void seededRandomChoice() throws Exception {
    // check that using a random seed produces consistent results
    String seeded123 = "random_choice(range(0, 1000), seed: 123)";
    String seeded789 = "random_choice(range(0, 1000), seed: 789)";
    Object result1 = evaluate(seeded123, null);
    Object result2 = evaluate(seeded789, null);
    assertNotEquals(result1, result2);
    for (int i = 0; i < 10; i++) {
      assertEquals(result1, evaluate(seeded123, null));
      assertEquals(result2, evaluate(seeded789, null));
    }

    // same again for the weighted form, with the keyword arguments in a different order
    String weighted123 =
        "random_choice(weights: map(range(0, 1000), x -> 1 /1000), items: range(0, 1000), seed: 123)";
    String weighted789 =
        "random_choice(seed: 789, items: range(0, 1000), weights: map(range(0, 1000), x -> 1 /1000))";
    result1 = evaluate(weighted123, Tuple.EMPTY_TUPLE);
    result2 = evaluate(weighted789, Tuple.EMPTY_TUPLE);
    assertNotEquals(result1, result2);

    for (int i = 0; i < 10; i++) {
      assertEquals(result1, evaluate(weighted123, Tuple.EMPTY_TUPLE));
      assertEquals(result2, evaluate(weighted789, Tuple.EMPTY_TUPLE));
    }
  }

  /**
   * Calling {@code random_choice} with too few or too many arguments is reported as a
   * wrong-number-of-arguments realization problem.
   */
  @Test
  public void randomChoiceExpectsCorrectArgumentNumber() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_choice").getArguments();
    evaluate("random_choice()", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 0))
    )));

    evaluate("random_choice(['foo'], [0.1], 123, 'baz')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 4))
    )));
  }

  /**
   * Each {@code random_choice} argument (items, weights, seed) reports a type-mismatch
   * realization problem when given the wrong type.
   */
  @Test
  public void randomChoiceExpectsCorrectArgumentTypes() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_choice").getArguments();
    evaluate("random_choice('foo')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(0), Types.TEXT))
    )));

    evaluate("random_choice(['foo'], 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(1), Types.TEXT))
    )));

    // weights must be a list of numbers, not just any list
    evaluate("random_choice(['foo'], ['bar'])", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(1), RSList.create(Types.TEXT)))
    )));

    evaluate("random_choice(['foo'], [0.1], 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(2), Types.TEXT))
    )));
  }

  /**
   * {@code random_uniform} returns a value within [start, stop]. The bounds can be given
   * positionally or by keyword in either order.
   */
  @Test
  public void randomUniform() throws Exception {
    // check we can generate random numbers in the range given
    assertThat((Double) evaluate("random_uniform(0.5, 0.9)", null),
        allOf(greaterThanOrEqualTo(0.5), lessThanOrEqualTo(0.9)));
    assertThat((Double) evaluate("random_uniform(start: -1000, stop: -100)", null),
        allOf(greaterThanOrEqualTo(-1000.0), lessThanOrEqualTo(-100.0)));
    assertThat((Double) evaluate("random_uniform(stop: 999, start: 111)", null),
        allOf(greaterThanOrEqualTo(111.0), lessThanOrEqualTo(999.0)));
  }

  /** A given seed makes {@code random_uniform} repeatable; different seeds give different values. */
  @Test
  public void seededRandomUniformProducesConsistentResults() throws Exception {
    Object result1 = evaluate("random_uniform(0, 10, seed: 123)", null);
    Object result2 = evaluate("random_uniform(0, 10, seed: 456)", null);
    assertNotEquals(result1, result2);
    for (int i = 0; i < 10; i++) {
      assertEquals(result1, evaluate("random_uniform(0, 10, seed: 123)", null));
      assertEquals(result2, evaluate("random_uniform(0, 10, seed: 456)", null));
    }
  }

  /**
   * With {@code samples}, {@code random_uniform} returns a list of that many values. The list has
   * a plausible mean/min/max, and non-positive sample counts give an empty list.
   */
  @Test
  public void randomUniformCanReturnMultipleSamples() throws Exception {
    assertThat(evaluate("length(random_uniform(0, 10, samples: 100))", null), is(100L));
    // overall mean between 0 and 100 should be about 50
    assertThat((Double) evaluate("mean(random_uniform(0, 100, samples: 1000))", null),
        allOf(greaterThanOrEqualTo(48.0), lessThanOrEqualTo(52.0)));
    assertThat((Double) evaluate("min(random_uniform(0, 100, samples: 1000))", null),
        allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(2.0)));
    assertThat((Double) evaluate("max(random_uniform(0, 100, samples: 1000))", null),
        allOf(greaterThanOrEqualTo(98.0), lessThanOrEqualTo(100.0)));

    // NB: 1 sample is still a list - samples could be a model parameter, and so
    // changing the return type based on size could break the pipeline
    assertThat(evaluate("length(random_uniform(0, 10, samples: 1))", null), is(1L));

    // check we don't crash if the user specifies dumb args
    assertThat(evaluate("length(random_uniform(0, 10, samples: 0))", null), is(0L));
    assertThat(evaluate("length(random_uniform(0, 10, samples: -1))", null), is(0L));

    // using a seed should produce the same samples, regardless of keyword argument order
    Object result1 = evaluate("max(random_uniform(0, 100, samples: 1000, seed: 123))", null);
    assertEquals(result1, evaluate("max(random_uniform(seed: 123, stop: 100, samples: 1000, start: 0))", null));
    // the min of the same seeded samples should differ from their max
    assertNotEquals(result1, evaluate("min(random_uniform(0, 100, seed: 123, samples: 1000))", null));
  }

  /**
   * Calling {@code random_uniform} with too few or too many arguments is reported as a
   * wrong-number-of-arguments realization problem.
   */
  @Test
  public void randomUniformExpectsCorrectArgumentNumber() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_uniform").getArguments();
    evaluate("random_uniform()", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 0))
    )));

    evaluate("random_uniform(0, 1, 123, 100, 'baz')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 5))
    )));
  }

  /**
   * Each {@code random_uniform} argument (start, stop, seed, samples) reports a type-mismatch
   * realization problem when given text.
   */
  @Test
  public void randomUniformExpectsCorrectArgumentTypes() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_uniform").getArguments();
    evaluate("random_uniform('foo', 1)", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(0), Types.TEXT))
    )));

    evaluate("random_uniform(0, 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(1), Types.TEXT))
    )));

    evaluate("random_uniform(0, 1, seed: 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get("seed"), Types.TEXT))
    )));

    evaluate("random_uniform(0, 1, samples: 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get("samples"), Types.TEXT))
    )));
  }

  /**
   * {@code random_norm} returns a value close to the mean. The mean and stddev can be given
   * positionally or by keyword in either order.
   */
  @Test
  public void randomNorm() throws Exception {
    // normal distribution is a bit harder to check as we can't *guarantee*
    // that it falls in any given range. However, 99.73% of results will
    // fall within 3 stddevs of the mean, so let's just check the result is
    // within 5 stddevs, just to be safe
    assertThat((Double) evaluate("random_norm(1000, 5)", null),
        allOf(greaterThanOrEqualTo(975.0), lessThanOrEqualTo(1025.0)));
    assertThat((Double) evaluate("random_norm(mean: 10.0, stddev: 1)", null),
        allOf(greaterThanOrEqualTo(5.0), lessThanOrEqualTo(15.0)));
    assertThat((Double) evaluate("random_norm(stddev: 1.9, mean: -50)", null),
        allOf(greaterThanOrEqualTo(-60.0), lessThanOrEqualTo(-40.0)));
    assertThat((Double) evaluate("random_norm(0, 1.0)", null),
        allOf(greaterThanOrEqualTo(-5.0), lessThanOrEqualTo(5.0)));
  }

  /** A given seed makes {@code random_norm} repeatable; different seeds give different values. */
  @Test
  public void seededRandomNormProducesConsistentResults() throws Exception {
    Object result1 = evaluate("random_norm(0, 10, seed: 1)", null);
    Object result2 = evaluate("random_norm(0, 10, seed: 2)", null);
    assertNotEquals(result1, result2);
    for (int i = 0; i < 10; i++) {
      assertEquals(result1, evaluate("random_norm(0, 10, seed: 1)", null));
      assertEquals(result2, evaluate("random_norm(0, 10, seed: 2)", null));
    }
  }

  /**
   * With {@code samples}, {@code random_norm} returns a list of that many values. The list has a
   * plausible min/max, and non-positive sample counts give an empty list.
   */
  @Test
  public void randomNormCanReturnMultipleSamples() throws Exception {
    assertThat(evaluate("length(random_norm(0, 10, samples: 100))", null), is(100L));
    // we should expect the min/max to fall between 2 and 8 stddevs
    assertThat((Double) evaluate("min(random_norm(1000, 5, samples: 1000))", null),
        allOf(greaterThanOrEqualTo(960.0), lessThanOrEqualTo(990.0)));
    assertThat((Double) evaluate("max(random_norm(1000, 5, samples: 1000))", null),
        allOf(greaterThanOrEqualTo(1010.0), lessThanOrEqualTo(1040.0)));

    // NB: 1 sample is still a list - samples could be a model parameter, and so
    // changing the return type based on size could break the pipeline
    assertThat(evaluate("length(random_norm(0, 10, samples: 1))", null), is(1L));

    // check we don't crash if the user specifies dumb args
    assertThat(evaluate("length(random_norm(0, 10, samples: 0))", null), is(0L));
    assertThat(evaluate("length(random_norm(0, 10, samples: -1))", null), is(0L));

    // using a seed should produce the same samples, regardless of keyword argument order
    Object result1 = evaluate("max(random_norm(1000, 5, samples: 1000, seed: 123))", null);
    assertEquals(result1, evaluate("max(random_norm(seed: 123, stddev: 5, samples: 1000, mean: 1000))", null));
    // the min of the same seeded samples should differ from their max
    assertNotEquals(result1, evaluate("min(random_norm(1000, 5, seed: 123, samples: 1000))", null));
  }

  /**
   * Calling {@code random_norm} with too few or too many arguments is reported as a
   * wrong-number-of-arguments realization problem.
   */
  @Test
  public void randomNormExpectsCorrectArgumentNumber() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_norm").getArguments();
    evaluate("random_norm()", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 0))
    )));

    evaluate("random_norm(0, 1, 123, 100, 'baz')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.get().wrongNumber(expectedArgs.getArity(), 5))
    )));
  }

  /**
   * Each {@code random_norm} argument (mean, stddev, seed, samples) reports a type-mismatch
   * realization problem when given text.
   */
  @Test
  public void randomNormExpectsCorrectArgumentTypes() throws Exception {
    ArgumentList expectedArgs = project.getFunctionSet().get("random_norm").getArguments();
    evaluate("random_norm('foo', 1)", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(0), Types.TEXT))
    )));

    evaluate("random_norm(0, 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get(1), Types.TEXT))
    )));

    evaluate("random_norm(0, 1, seed: 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get("seed"), Types.TEXT))
    )));

    evaluate("random_norm(0, 1, samples: 'bar')", null);
    assertThat(realizationProblems, contains(hasAncestorProblem(
        is(ArgsProblems.mismatch(expectedArgs.get("samples"), Types.TEXT))
    )));
  }
}
