/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.function.expression;

import static nz.org.riskscape.engine.Assert.*;
import static nz.org.riskscape.engine.Matchers.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.net.URI;
import java.util.List;

import org.hamcrest.Matcher;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import nz.org.riskscape.dsl.SourceLocation;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.IdentifiedFunction;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.function.StringFunctions;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.problem.ProblemMatchers;
import nz.org.riskscape.engine.rl.BaseExpressionRealizerTest;
import nz.org.riskscape.engine.rl.DefaultOperators;
import nz.org.riskscape.engine.rl.EvalException;
import nz.org.riskscape.engine.rl.LanguageFunctions;
import nz.org.riskscape.engine.rl.LogicFunctions;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.Lambda;

/**
 * Tests for {@link ExpressionFunctionFramework}, which builds functions from a definition whose {@code source} is a
 * lambda expression, optionally accompanied by declared {@code argument-types} and a {@code return-type}.
 */
public class ExpressionFunctionFrameworkTest extends BaseExpressionRealizerTest {

  /** The function most recently looked up by {@link #addFunction(String, String)} (null when the lookup failed). */
  IdentifiedFunction function;
  /** Problems reported when looking up the function most recently added via {@link #addFunction(String, String)}. */
  List<Problem> functionProblems;

  @Before
  public void setup() {
    engine.getFunctionFrameworks().add(new ExpressionFunctionFramework());
    project.getFunctionSet().insertFirst(DefaultOperators.INSTANCE);
    project.getFunctionSet().addAll(StringFunctions.FUNCTIONS);
    project.getFunctionSet().addAll(LanguageFunctions.FUNCTIONS);
    project.getFunctionSet().addAll(LogicFunctions.LOGIC_FUNCTIONS);
  }

  @Test
  public void canDefineAHelloWorldFunction() throws Exception {
    addFunction("hello", "source = () -> 'Hello, World!'");
    assertThat(function, hasArguments(ArgumentList.create()));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("hello()", "Hello, World!");
    assertThat(realized.getResultType(), equalTo(Types.TEXT));
  }

  @Test
  public void canDefineAMinimalBinaryExpression() throws Exception {
    addFunction("plus", "source = (a, b) -> a + b");
    assertThat(function, hasArguments(ArgumentList.create("a", Nullable.ANYTHING, "b", Nullable.ANYTHING)));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("plus(1, 2)", 3L);
    assertThat(realized.getResultType(), equalTo(Types.INTEGER));
  }

  @Test
  public void canDefineABinaryExpressionWithArgTypes() throws Exception {
    addFunction("plus", """
        source = (a, b) -> a + b
        argument-types = [integer, integer]
        """);

    // check that the argument types were correctly labeled from the lambda expression
    assertThat(function, hasArguments(ArgumentList.create("a", Types.INTEGER, "b", Types.INTEGER)));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("plus(1, 2)", 3L);
    assertThat(realized.getResultType(), equalTo(Types.INTEGER));
  }

  @Test
  public void canDefineABinaryExpressionWithNamedArgTypes() throws Exception {
    addFunction("plus", """
        source = (a, b) -> a + b
        argument-types = [lhs: integer, rhs: integer]
        """);

    // check that the argument types were correctly labeled from the definition, not the lambda
    assertThat(function, hasArguments(ArgumentList.create("lhs", Types.INTEGER, "rhs", Types.INTEGER)));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("plus(1, 2)", 3L);
    assertThat(realized.getResultType(), equalTo(Types.INTEGER));
  }

  @Test
  public void canDefineAnExpressionWithNullableReturnType() throws Exception {
    addFunction("test", "source = (test) -> if(test, 'foo')");
    assertThat(function, hasArguments(ArgumentList.create("test", Nullable.ANYTHING)));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("test(true)", "foo");
    assertThat(realized.getResultType(), equalTo(Nullable.TEXT));
  }

  @Test
  public void declaredReturnTypeIsRespected() throws Exception {
    addFunction("test", """
        return-type = text
        # this expression actually returns a nullable, so it might blow up at runtime
        source = (test) -> if(test, 'foo')
        """);
    assertThat(function, hasArguments(ArgumentList.create("test", Nullable.ANYTHING)));
    assertThat(function, hasReturnType(Types.TEXT));

    assertEval("test(true)", "foo");
    // actual return type is what's declared
    assertThat(realized.getResultType(), equalTo(Types.TEXT));

    // but return type coercion blows up when the expression yields null
    RiskscapeException ex = assertThrows(RiskscapeException.class, () -> assertEval("test(false)", 5L));
    assertThat(ex.getProblem(), hasAncestorProblem(
        isError(ExpressionFunctionFramework.LocalProblems.class, "returnTypeCoercionFailed"))
    );

    // declaring `anything` (as opposed to a nullable anything) also refuses to coerce a null result
    addFunction("test_anything", """
        # surely users mean anything goes, including nulls
        return-type = anything
        source = (test) -> if(test, 'foo')
        """);
    assertThat(function, hasArguments(ArgumentList.create("test", Nullable.ANYTHING)));
    assertThat(function, hasReturnType(Types.ANYTHING));

    assertEval("test_anything(true)", "foo");
    // actual return type is what's declared
    assertThat(realized.getResultType(), equalTo(Types.ANYTHING));

    // but return type coercion blows up when the expression yields null
    ex = assertThrows(RiskscapeException.class, () -> assertEval("test_anything(false)", "foo"));
    assertThat(ex.getProblem(), hasAncestorProblem(
        isError(ExpressionFunctionFramework.LocalProblems.class, "returnTypeCoercionFailed"))
    );
  }

  @Test
  public void acceptsNullInputsIfExpressionAllows() {
    addFunction("accepts_null", """
        source = '''
          (hazard, resource) ->
            if(hazard > 0,
               then: () -> if(is_null(resource), 'Exposed', 'Exposed with ' + str(resource)),
               else: 'Not exposed')
        '''
        """);

    assertThat(function, hasArguments(ArgumentList.create(
        "hazard", Nullable.ANYTHING, "resource", Nullable.ANYTHING
    )));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("accepts_null(1, 'foo')", "Exposed with foo");
    assertThat(realized.getResultType(), equalTo(Types.TEXT));

    assertEval("accepts_null(1, null_of('text'))", "Exposed");
    assertThat(realized.getResultType(), equalTo(Types.TEXT));
  }

  @Test
  public void nullSafeWrappingKicksInIfExpressionNotNullSafe() {
    addFunction("no_nulls", """
        source = '''
          (hazard, value) ->
            if(hazard > 0,
               then: value,
               else: 0)
        '''
        """);

    assertThat(function, hasArguments(ArgumentList.create(
        "hazard", Nullable.ANYTHING, "value", Nullable.ANYTHING
    )));
    assertThat(function, hasReturnType(Nullable.ANYTHING));

    assertEval("no_nulls(1, 10)", 10L);
    assertThat(realized.getResultType(), equalTo(Types.INTEGER));

    assertEval("no_nulls(1, null_of('integer'))", null);
    assertThat(realized.getResultType(), equalTo(Nullable.INTEGER));
  }

  @Test
  public void canDefineAnExpressionWithAReturnType() throws Exception {
    addFunction("append_foo", """
        source = (thing) -> str(thing) + 'foo'
        return-type = text
        """);

    assertThat(function, hasArguments(ArgumentList.create("thing", Nullable.ANYTHING)));
    assertThat(function, hasReturnType(Types.TEXT));

    assertEval("append_foo(1)", "1foo");
    assertThat(realized.getResultType(), equalTo(Types.TEXT));
  }

  @Test
  public void canDefineAnExpressionWithACoercingReturnType() throws Exception {
    // this is actually casting rather than coercing - but we don't have a type-set aware casting method in riskscape
    // at all, so the function is falling back to the early/primitive Type#coerce method
    addFunction("add_zero", """
        source = (thing) -> str(thing) + '0'
        return-type = integer
        """);

    assertThat(function, hasReturnType(Types.INTEGER));

    assertEval("add_zero(1)", 10L);
    assertThat(realized.getResultType(), equalTo(Types.INTEGER));
  }

  @Test
  public void closesAnyChildCloseablesWhenClosed() throws Exception {
    AutoCloseable closeable = Mockito.mock(AutoCloseable.class);
    project.getFunctionSet().add(
        RiskscapeFunction.create(this, List.of(), Types.TEXT, (args) -> "hi", closeable).identified("closeme"));

    addFunction("check_close", "source = () -> closeme()");

    assertEval("check_close()", "hi");
    // not yet...
    Mockito.verify(closeable, Mockito.never()).close();

    realized.close();

    // now!
    Mockito.verify(closeable, Mockito.times(1)).close();
  }

  @Test
  public void doesNotBuildIfExpressionDoesNotParse() throws Exception {
    addFunction("broken", "source = (foo) -> foo foo");

    URI expectedLocation = SourceLocation.index(0).addToUri(functionSource);
    assertBuildFailed(
        ProblemMatchers.isProblem(
            GeneralProblems.class,
            (r, f) -> f.failedResourceLookedUp(r.eq("broken"), r.eq(expectedLocation), r.any())
        )
    );
  }

  // does not build if expression isn't a lambda, and the problem hints at what the format should look like
  @Test
  public void doesNotBuildIfExpressionNotALambda() throws Exception {
    addFunction("broken", "source = foo + foo");

    assertBuildFailed(
        ProblemMatchers.isProblem(
            ExpressionProblems.class,
            (r, f) -> f.mismatch(r.any(), r.eq(Lambda.class), r.any())
        )
    );
  }

  @Test
  public void doesNotBuildIfArgTypesSizeMismatches() throws Exception {
    // make sure there's an error message when the number of argument types does not match the lambda's arg count
    addFunction("broken", """
        argument-types = [integer]
        source = (a, b) -> a + b
        """);

    assertBuildFailed(
        equalTo(ExpressionFunctionFramework.PROBLEMS.argCountMismatch(1, 2))
    );
  }

  @Test
  public void doesNotBuildIfDeclaredTypesBad() throws Exception {
    // does not build if arg types bad
    addFunction("badArgs", """
        argument-types = [intyboy]
        source = (a) -> a * 2
        """);

    // not going to over test these, they are covered in detail elsewhere
    assertBuildFailed(any(Problem.class));

    // does not build if return type bad
    addFunction("badReturn", """
        return-type = intyboy
        source = (a) -> a * 2
        """);
    assertBuildFailed(any(Problem.class));
  }

  // does not realize if given arg types mismatch function defined ones
  @Test
  public void doesNotRealizeIfGivenArgTypesMismatchDefinedOnes() throws Exception {
    addFunction("plus", """
        argument-types = [floating, floating]
        source = (a, b) -> a + b
        """);

    evaluate("plus('1', '2')", Tuple.EMPTY_TUPLE);

    assertThat(
        realizationProblems,
        contains(ProblemMatchers.problemsInTree(
            ProblemMatchers.isProblem(
                TypeProblems.class,
                (r, f) -> f.mismatch(r.any(), r.eq(Types.FLOATING), r.eq(Types.TEXT))
            )
        ))
    );
  }

  @Test
  public void doesNotRealizeIfExpressionWithGivenTypesDoesNotRealize() throws Exception {
    addFunction("div", """
        source = (a, b) -> a / b
        """);

    evaluate("div('1', '2')", Tuple.EMPTY_TUPLE);

    assertThat(
        realizationProblems,
        contains(ProblemMatchers.problemsInTree(
            equalTo(ExpressionProblems.get().failedToRealize(
                parse("(a, b) -> a / b"),
                Struct.of("a", Types.TEXT, "b", Types.TEXT)
            ))
        ))
    );
  }

  // does not evaluate if return type does not coerce (and includes information about where/why it failed)
  @Test
  public void doesNotEvaluateIfReturnTypeDoesNotCoerce() throws Exception {
    addFunction("cast", """
        return-type = integer
        source = (a) -> a
        """);

    EvalException ex = assertThrows(EvalException.class, () -> evaluate("cast('one')", Tuple.EMPTY_TUPLE));
    assertThat(
        ex.getProblem(),
        ProblemMatchers.problemsInTree(
            equalTo(ExpressionFunctionFramework.PROBLEMS.returnTypeCoercionFailed("one", Types.TEXT, Types.INTEGER))
        )
    );
  }

  @Test
  public void missingArgumentTypeIsAProblem() {
    addFunction("missing_arg_type", """
                                   argument-types = [lookup('missing')]
                                   source = (a) -> a
                                   """);
    assertBuildFailed(isError(GeneralProblems.class, "noSuchObjectExists"));
  }

  @Test
  public void missingReturnTypeIsAProblem() {
    addFunction("missing_return_type", """
                                   argument-types = [anything]
                                   return-type = "missing"
                                   source = (a) -> a
                                   """);
    assertBuildFailed(isError(GeneralProblems.class, "noSuchObjectExists"));
  }

  /**
   * Asserts that looking up the most recently added function reported exactly one problem, and that this problem (or
   * one of its children) matches {@code problem}.
   */
  private void assertBuildFailed(Matcher<Problem> problem) {
    assertThat(functionProblems, contains(ProblemMatchers.problemsInTree(problem)));
  }

  /**
   * Registers the function via the base class, then looks it up by name so that {@link #function} and
   * {@link #functionProblems} reflect the outcome.
   */
  @Override
  protected void addFunction(String name, String functionDefinition) {
    super.addFunction(name, functionDefinition);

    ResultOrProblems<IdentifiedFunction> functionOr = project.getFunctionSet().getOr(name);
    function = functionOr.orElse(null);
    functionProblems = functionOr.getProblems();
  }

  /** Matches a function whose {@code arguments} property equals {@code argumentList}. */
  private Matcher<IdentifiedFunction> hasArguments(ArgumentList argumentList) {
    return hasProperty("arguments", equalTo(argumentList));
  }

  /** Matches a function whose {@code returnType} property equals {@code type}. */
  private Matcher<IdentifiedFunction> hasReturnType(Type type) {
    return hasProperty("returnType", equalTo(type));
  }

  /**
   * Evaluates {@code expression} against an empty tuple, then asserts that it realized without problems and that
   * the result equals {@code value}.
   */
  private void assertEval(String expression, Object value) {
    evaluate(expression, Tuple.EMPTY_TUPLE);

    assertThat(realizationProblems, empty());
    assertThat(evaluated, equalTo(value));
  }
}
