/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.cpython;

import static nz.org.riskscape.engine.Assert.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.junit.Test;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.geom.PrecisionModel;
import org.locationtech.jts.io.WKBWriter;

import nz.org.riskscape.engine.GeometryMatchers;
import nz.org.riskscape.engine.Matchers;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.TupleMatchers;
import nz.org.riskscape.engine.function.IdentifiedFunction;
import nz.org.riskscape.engine.types.CoercionException;
import nz.org.riskscape.engine.types.Enumeration;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.types.WithinRange;
import nz.org.riskscape.engine.types.WithinSet;

/**
 * Calls CPython functions with a variety of different argument/return types,
 * to check that serialization between RiskScape and Python is working correctly
 * in both directions (arguments sent to Python, and results sent back).
 */
public class CPythonSerializationTest extends CPythonBaseTest {

  /**
   * A single integer argument is passed to Python and a text result comes back.
   */
  @Test
  public void canCallASimpleFunction() throws Exception {
    String script = ""
        + "def function(a):\n"
        + "  return 'cool ' + str(a)\n"
        + "\n";

    IdentifiedFunction function = makeFunction(script,
        Arrays.asList(Types.INTEGER),
        Types.TEXT);
    assertEquals("cool 1", call(function, Arrays.asList(1L)));
  }

  /**
   * Integer, floating and text arguments are passed positionally. The floating
   * argument arrives in Python as a float, so it is formatted as {@code 2.0}.
   */
  @Test
  public void canCallAMultiArgsFunction() throws Exception {
    String script = ""
        + "def function(a, b, c):\n"
        + "  return 'easy as {0}, {1}, {2}'.format(a, b, c)\n"
        + "\n";

    IdentifiedFunction function = makeFunction(script,
        Arrays.asList(Types.INTEGER, Types.FLOATING, Types.TEXT),
        Types.TEXT);
    assertEquals("easy as 1, 2.0, three", call(function, Arrays.asList(1L, 2D, "three")));
  }

  /**
   * Struct arguments are exposed to Python as dict-like objects (accessed via
   * {@code .get()}), and a Python dict result is turned back into a {@link Tuple}.
   */
  @Test
  public void canCallAFunctionWithStructs() throws Exception {
    String script = ""
        + "def function(multiply, divide):\n"
        + "  sum = multiply.get('x') * multiply.get('y') + divide.get('a') / divide.get('b') \n"
        + "  return { 'total': sum, 'rounded': round(sum) }\n"
        + "\n";

    Struct args1 = Struct.of("x", Types.FLOATING, "y", Types.INTEGER);
    Struct args2 = Struct.of("a", Types.INTEGER, "b", Types.FLOATING);
    Struct returnType = Struct.of("total", Types.FLOATING, "rounded", Types.INTEGER);
    IdentifiedFunction function = makeFunction(script, Arrays.asList(args1, args2), returnType);

    Tuple args1Value = Tuple.ofValues(args1, 2.5, 3L);
    Tuple args2Value = Tuple.ofValues(args2, 5L, 4.0);
    // 2.5 * 3 + 5 / 4.0 = 8.75, which python's round() takes to 9
    Tuple expected = Tuple.ofValues(returnType, 8.75, 9L);
    assertEquals(expected, call(function, Arrays.asList(args1Value, args2Value)));
  }

  /**
   * A list argument is sent to Python, and the result of Python's {@code range()}
   * is serialized back as a list of integers.
   */
  @Test
  public void canCallAFunctionWithLists() throws Exception {
    String script = ""
        + "def function(myList):\n"
        + "  return range(0, len(myList))\n"
        + "\n";

    IdentifiedFunction function = makeFunction(script,
        Arrays.asList(RSList.create(Types.TEXT)),
        RSList.create(Types.INTEGER));

    List<String> myList = Arrays.asList("a", "b", "c", "d");
    assertEquals(Arrays.asList(0L, 1L, 2L, 3L), call(function, Arrays.asList(myList)));
    myList = Arrays.asList("a", "b");
    assertEquals(Arrays.asList(0L, 1L), call(function, Arrays.asList(myList)));
  }

  /**
   * Null values are passed to Python as {@code None}, and a {@code None} result
   * comes back as a Java null for a nullable return type.
   */
  @Test
  public void canCallAFunctionWithNullableTypes() throws Exception {
    String script = ""
        + "def function(toCheck):\n"
        + "  if toCheck == None or toCheck.get('foo') < 5.0:\n"
        + "    return None\n"
        + "  else:\n"
        + "    return True\n"
        + "\n";

    Struct args1 = Struct.of("foo", Types.FLOATING);
    IdentifiedFunction function = makeFunction(script,
        Arrays.asList(Nullable.of(args1)),
        Nullable.of(Types.BOOLEAN));

    // Arrays.asList((Object) null) would also work, but an explicit list makes the
    // single null argument obvious
    List<Object> nullArgs = new ArrayList<>();
    nullArgs.add(null);
    assertEquals(null, call(function, nullArgs));
    assertEquals(null, call(function, Arrays.asList(Tuple.ofValues(args1, 4.9))));
    assertEquals(true, call(function, Arrays.asList(Tuple.ofValues(args1, 5.1))));
  }

  /**
   * Enumeration values travel to Python as their (one-based) integer ordinal and
   * are mapped back to the enum's text value on return.
   */
  @Test
  public void canCallAFunctionWithEnums() throws Exception {
    String script = ""
        + "def function(myEnum):\n"
        + "  return myEnum + 1\n";

    Enumeration testEnum = Enumeration.oneBased("foo", "bar", "baz");

    IdentifiedFunction function = makeFunction(script, Arrays.asList(testEnum), testEnum);

    assertEquals("bar", call(function, Arrays.asList("foo")));
    assertEquals("baz", call(function, Arrays.asList("bar")));
    assertThrows(CoercionException.class, // baz + 1 is out of range
        () -> call(function, Arrays.asList("baz")));
  }

  /**
   * Wrapping types ({@link WithinRange} argument, {@link WithinSet} return) are
   * serialized as their underlying integer/text values.
   */
  @Test
  public void canCallAFunctionWithWrappedTypes() throws Exception {
    String script = ""
        + "def function(time):\n"
        + "  if time < 10:\n"
        + "    return 'breakfast'\n"
        + "  elif time < 14:\n"
        + "    return 'lunch'\n"
        + "  else:\n"
        + "    return 'dinner'\n";

    WithinSet testSet = new WithinSet(Types.TEXT, "breakfast", "lunch", "dinner");
    WithinRange testRange = new WithinRange(Types.INTEGER, 0, 24);

    IdentifiedFunction function = makeFunction(script, Arrays.asList(testRange), testSet);

    assertEquals("breakfast", call(function, Arrays.asList(7L)));
    assertEquals("lunch", call(function, Arrays.asList(12L)));
    assertEquals("dinner", call(function, Arrays.asList(18L)));
  }

  /**
   * A struct containing a nullable nested struct of lists round-trips through
   * Python, both with and without the nested value present.
   */
  @Test
  public void canCallAFunctionWithNestedTypes() throws Exception {
    String script = ""
        + "def function(foo):\n"
        + "  nested = foo.get('nested')\n"
        + "  if nested is not None:\n"
        + "    nested = { 'list': [ x / 2.0 for x in nested.get('list') ] }\n"
        + "  return { 'bar': foo.get('bar') * 2, 'nested': nested }\n"
        + "\n";

    Struct nested = Struct.of("list", RSList.create(Types.INTEGER));
    Struct testStruct = Struct.of("nested", Nullable.of(nested), "bar", Types.FLOATING);
    IdentifiedFunction function = makeFunction(script, Arrays.asList(testStruct), testStruct);

    // nested is null, so we just expect 'bar' to be doubled
    Tuple testArgs = Tuple.ofValues(testStruct, null, 3.5);
    assertEquals(Tuple.ofValues(testStruct, null, 7D), call(function, Arrays.asList(testArgs)));

    // call with a nested list - the list items should get halved.
    // Python produces floats (0.5, 1.5, 2.5), which serialization coerces back to
    // the declared integer type by truncating them (0, 1, 2)
    testArgs = Tuple.ofValues(testStruct, Tuple.ofValues(nested, Arrays.asList(1L, 3L, 5L)), 2D);
    assertEquals(Tuple.ofValues(testStruct, Tuple.ofValues(nested, Arrays.asList(0L, 1L, 2L)), 4D),
        call(function, Arrays.asList(testArgs)));
  }

  /**
   * Python results whose runtime type differs from the declared return type are
   * coerced rather than rejected.
   */
  @Test
  public void canCallCoercePythonReturnValues() throws Exception {
    // we shouldn't choke if the user tells us a number is floating, but then the
    // function gives us an integer to serialize, etc
    String script = ""
        + "def function(a):\n"
        + "  return {'floaty': int(a), 'inty': a / 2.0, 'stringy': a}\n";

    Struct returnType = Struct.of("floaty", Types.FLOATING, "inty", Types.INTEGER, "stringy", Types.TEXT);
    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.FLOATING), returnType);

    assertEquals(Tuple.ofValues(returnType, 5.0, 2L, "5.0"), call(function, Arrays.asList(5D)));
  }

  /**
   * A SMALLFLOAT argument is sent to Python and a SMALLFLOAT result is read back.
   */
  @Test
  public void canReadAndWriteSmallfloat() throws Exception {
    // a Java float should round-trip through python arithmetic and come back as a float
    String script = """
        def function(a):
          return a * 2
        """;

    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.SMALLFLOAT), Types.SMALLFLOAT);
    assertEquals(10F, call(function, Arrays.asList(5F)));
  }

  /**
   * Python int, float and numeric-string results are all coerced to SMALLFLOAT
   * when that is the declared member type.
   */
  @Test
  public void canCoerceNumbersToSmallfloats() throws Exception {
    // the function ignores its argument - only the coercion of the returned values matters
    String script = """
        def function(a):
          return {
            "wasInteger": 1,
            "wasFloating": 45.5,
            "wasText": "3.5"
          }
        """;

    Struct returnType =
        Struct.of("wasInteger", Types.SMALLFLOAT, "wasFloating", Types.SMALLFLOAT, "wasText", Types.SMALLFLOAT);
    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.SMALLFLOAT), returnType);
    assertEquals(Tuple.ofValues(returnType, 1F, 45.5F, 3.5F), call(function, Arrays.asList(5F)));
  }

  /**
   * Non-ASCII text arrives in Python as a string of characters, not bytes, so
   * {@code len()} counts the 5 characters.
   */
  @Test
  public void canSendNonWesternCharactersToFunction() throws Exception {
    String script = ""
        + "def function(text):\n"
        + "  return len(text)\n";

    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.TEXT), Types.INTEGER);
    assertEquals(5L, call(function, Arrays.asList("すみません")));
  }

  /**
   * Non-ASCII text produced by Python is decoded correctly on the Java side.
   */
  @Test
  public void canReceiveNonWesternCharactersToFunction() throws Exception {
    String script = ""
        + "def function(num):\n"
        + "  return 'すみません'\n";

    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.INTEGER), Types.TEXT);
    assertEquals("すみません", call(function, Arrays.asList(1L)));
  }

  /**
   * A selection of non-ASCII strings (including macrons and emoji) round-trip
   * unchanged through an identity function.
   */
  @Test
  public void canSendAndReceiveNonWesternCharactersToFunction() throws Exception {

    List<String> examples = Arrays.asList(
        "為",
        "すみません",
        "Te ao Māori",
        "🏚️🔥"
    );
    String script = ""
        + "def function(text):\n"
        + "  return text\n";

    IdentifiedFunction function = makeFunction(script, Arrays.asList(Types.TEXT), Types.TEXT);

    for (String exampleText : examples) {
      try {
        assertEquals(exampleText, call(function, Arrays.asList(exampleText)));
      } catch (Exception | AssertionError e) {
        // AssertionError is an Error, not an Exception, so it must be caught explicitly
        // for a round-trip mismatch to also report which example failed
        throw new AssertionError("failed on example " + exampleText, e);
      }
    }
  }

  /**
   * Geometry is sent to Python as a (wkb bytes, srid) pair, and returning that same
   * pair gives back an equal geometry with its SRID preserved.
   */
  @Test
  public void canSendGeometryToAFunction() throws Exception {
    String script = ""
        + "def function(geometry):\n"
        + "  srid = geometry[1]\n"
        + "  bytes = geometry[0]\n"
        + "  return {'srid': srid, 'len': len(bytes), 'original': geometry}";

    IdentifiedFunction function = makeFunction(
        script,
        Arrays.asList(Types.GEOMETRY),
        Struct.of("srid", Types.INTEGER, "len", Types.INTEGER, "original", Types.GEOMETRY)
    );

    GeometryFactory gf = new GeometryFactory(new PrecisionModel(), 420);
    Point point = gf.createPoint(new Coordinate(1, 1));

    // the python side should see exactly as many bytes as a default WKBWriter produces
    WKBWriter writer = new WKBWriter();
    long numBytes = writer.write(point).length;

    Object returned = call(function, Arrays.asList(point));
    assertThat(
      returned,
      Matchers.instanceOfAnd(
        Tuple.class,
        allOf(
          TupleMatchers.tupleWithValue("srid", equalTo(420L)),
          TupleMatchers.tupleWithValue("len", equalTo(numBytes)),
          TupleMatchers.tupleWithValue("original", allOf(
              equalTo(point),
              GeometryMatchers.geometryWithSrid(420)
          ))
        )
      )
    );
  }

  /**
   * Returning only the WKB bytes (without the srid) is accepted as a geometry,
   * and the resulting geometry has an SRID of 0.
   */
  @Test
  public void canReturnJustBytesForGeometry() throws Exception {
    // there's some special case logic in the geometry processing that allows vanilla wkb to be returned, sans srid
    String script = ""
        + "def function(geometry):\n"
        + "  return geometry[0]\n";

    IdentifiedFunction function = makeFunction(
        script,
        Arrays.asList(Types.GEOMETRY),
        Types.GEOMETRY
    );

    GeometryFactory gf = new GeometryFactory(new PrecisionModel(), 420);
    Point point = gf.createPoint(new Coordinate(1, 1));

    Object returned = call(function, Arrays.asList(point));
    assertThat(
      returned,
      Matchers.instanceOfAnd(Geometry.class,
      allOf(
        equalTo(point),
        GeometryMatchers.geometryWithSrid(0)
      ))
    );
  }
}
