/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine;

import static nz.org.riskscape.engine.Assert.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.geotools.referencing.CRS;
import org.geotools.referencing.CRS.AxisOrder;
import org.junit.Test;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.MultiPoint;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.io.WKTReader;
import org.mockito.Mockito;
import org.geotools.api.referencing.crs.CoordinateReferenceSystem;

import nz.org.riskscape.engine.geo.GeometryFixer;
import nz.org.riskscape.engine.geo.GeometryValidation;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problem.Severity;

/**
 * Tests for {@link SRIDSet}: allocating and looking up SRIDs for CRSs, detecting equivalent CRSs, reprojecting
 * geometries (including datum-shift warnings and post-reprojection validation/fixing) and concurrent access.
 */
@SuppressWarnings("unchecked")
public class SRIDSetTest implements CrsHelper {

  // the geometry most recently passed to the fixer, so tests can check whether fixing was attempted
  AtomicReference<Geometry> fixInput = new AtomicReference<>();
  // when set, the fixer returns this geometry instead of delegating to GeometryFixer.DEFAULT
  AtomicReference<Geometry> fixedResponse = new AtomicReference<>();
  GeometryFixer fixer = (geom) -> {
    fixInput.set(geom);
    if (fixedResponse.get() != null) {
      return fixedResponse.get();
    }
    // fallback to default implementation
    return GeometryFixer.DEFAULT.fix(geom);
  };

  // collects every problem reported by the SRIDSet under test
  List<Problem> problemSink = new ArrayList<>();
  SRIDSet sridSet = new SRIDSet(p -> problemSink.add(p), fixer);
  CoordinateReferenceSystem nzTransverseMercator = nzTransverseMercator();
  CoordinateReferenceSystem longLat = longLat();

  GeometryFactory nzFactory = sridSet.getGeometryFactory(nzTransverseMercator);
  GeometryFactory longLatFactory = sridSet.getGeometryFactory(longLat);

  WKTReader longLatReader = new WKTReader(longLatFactory);
  WKTReader nzReader = new WKTReader(nzFactory);

  Point nztmPoint = nzFactory.createPoint(new Coordinate(1807565, 5596775));
  Point llPoint = longLatFactory.createPoint(new Coordinate(175, -40));

  final int nztmId = sridSet.get(nzTransverseMercator);
  final int longLatId = sridSet.get(longLat);

  @Test
  public void willAssignAndRememberSridsForLaterRetrieval() {
    assertNotEquals(nztmId, longLatId);

    assertSame(longLat, sridSet.get(longLatId));
    assertSame(nzTransverseMercator, sridSet.get(nztmId));
  }

  @Test
  public void willAssignSameSridToEquivalentCRSs() {
    CoordinateReferenceSystem crs1 = crsFromWkt("EPSG32702.wkt");
    CoordinateReferenceSystem crs2 = crsFromWkt("EPSG32702-GCS_WGS.wkt");

    // let's sanity check the input crs. They should not be equal objects, but should be equal ignoring
    // metadata
    assertNotEquals(crs1, crs2);
    assertTrue(CRS.equalsIgnoreMetadata(crs1, crs2));

    assertEquals(sridSet.get(crs1), sridSet.get(crs2));

    // this wkt has changed the names from EPSG32702.wkt so it should be seen as the same
    CoordinateReferenceSystem crs3 = crsFromWkt("EPSG32702-changed-names.wkt");
    // let's sanity check the input crs. They should not be equal objects, but should be equal ignoring
    // metadata
    assertNotEquals(crs1, crs3);
    assertTrue(CRS.equalsIgnoreMetadata(crs1, crs3));

    assertEquals(sridSet.get(crs1), sridSet.get(crs3));
  }

  @Test
  public void willDetectNonMetaDataChanges() {
    // in these tests we check EPSG32702.wkt against other variants of that file where non metadata changes
    // have been made. This is to give us confidence that these changes to the CRS do get noticed and
    // certainly not swept under the carpet.
    CoordinateReferenceSystem crs1 = crsFromWkt("EPSG32702.wkt");

    // a variant of EPSG32702.wkt where the numbers in the SPHEROID have been changed (by just a little)
    CoordinateReferenceSystem altered = crsFromWkt("EPSG32702-altered-spheriod.wkt");
    assertFalse(CRS.equalsIgnoreMetadata(crs1, altered));
    assertNotEquals(sridSet.get(crs1), sridSet.get(altered));

    // a variant of EPSG32702.wkt where the central meridian has been changed (by just a little)
    altered = crsFromWkt("EPSG32702-shifted-meridian.wkt");
    assertFalse(CRS.equalsIgnoreMetadata(crs1, altered));
    assertNotEquals(sridSet.get(crs1), sridSet.get(altered));

    // a variant of EPSG32702.wkt where the scale-factor has been changed (by just a little)
    altered = crsFromWkt("EPSG32702-altered-scale-factor.wkt");
    assertFalse(CRS.equalsIgnoreMetadata(crs1, altered));
    assertNotEquals(sridSet.get(crs1), sridSet.get(altered));
  }

  @Test
  public void willAssignDifferentSridIfAxisOrderNotSame() throws Exception {
    // the same WGS84 CRS, decoded once as lat/long (YX) and once forced to long/lat (XY)
    CoordinateReferenceSystem wgs84YX = CRS.decode("EPSG:4326", false);
    CoordinateReferenceSystem wgs84XY = CRS.decode("EPSG:4326", true);

    assertNotEquals(sridSet.get(wgs84XY), sridSet.get(wgs84YX));
  }

  @Test
  public void detectsEquivalentCRSs() throws Exception {
    CoordinateReferenceSystem crs1 = crsFromWkt("EPSG32702.wkt");
    CoordinateReferenceSystem crs2 = crsFromWkt("EPSG32702-GCS_WGS.wkt");

    assertFalse(sridSet.requiresReprojection(crs1, crs2));
    assertTrue(sridSet.requiresReprojection(crs1, nzTransverseMercator));
    assertTrue(sridSet.requiresReprojection(CRS.decode("EPSG:4326", false), CRS.decode("EPSG:4326", true)));
  }

  @Test(expected=UnknownSRIDException.class)
  public void willThrowAnExceptionIfAnSRIDNotRemembered() {
    sridSet.get(3);
  }

  @Test
  public void canReprojectGeometry() {
    assertEquals(0, sridSet.cachedTransforms.size());
    Geometry reprojected = sridSet.reproject(nztmPoint, sridSet.get(longLat));
    assertEquals(longLat, sridSet.get(reprojected.getSRID()));
    assertEquals(longLatFactory, reprojected.getFactory());
    assertEquals(1, sridSet.cachedTransforms.size());
  }

  @Test
  public void canReprojectMultiGeometry() {
    MultiPoint mPoint = longLatFactory.createMultiPoint(new Point[] {llPoint});
    MultiPoint reprojected = (MultiPoint)sridSet.reproject(mPoint, nztmId);

    assertEquals(nzFactory, reprojected.getFactory());
    assertEquals(nztmId, reprojected.getSRID());
    // Check that the srid is updated in the component parts
    assertEquals(nztmId, reprojected.getGeometryN(0).getSRID());
    assertEquals(nzFactory, reprojected.getGeometryN(0).getFactory());
  }

  @Test
  public void reprojectionWillWarnIfDatumShiftIsIgnored() {
    // we expect llPoint to reproject to NZTM but there should be a warning about the missing
    // datum shift parameters
    sridSet.reproject(llPoint, nztmId);
    assertThat(problemSink, contains(
        GeometryProblems.get().reprojectionIgnoringDatumShift(longLat, nzTransverseMercator())
    ));
  }

  @Test
  public void reprojectionWillNotWarnIfDatumShiftIsIgnoredWithinSameCrs() throws Exception {
    // should be no warning reprojecting a lat/long point to long/lat
    GeometryFactory latLongFactory = sridSet.getGeometryFactory(SRIDSet.EPSG4326_LATLON);
    Point llPoint2 = latLongFactory.createPoint(new Coordinate(-40, 175));
    sridSet.reproject(llPoint2, sridSet.get(longLat()));
    assertThat(problemSink, empty());

    // this is an unusually formatted (but valid) WGS84 CRS WKT we encountered in the field
    CoordinateReferenceSystem weirdWgs84Crs = crsFromWkt("EPSG4326_unusual.wkt");

    // both points are essentially in long/lat WGS84, even though they appear as
    // different CRSs, therefore no real reprojection should be occurring
    sridSet.reproject(llPoint, sridSet.get(weirdWgs84Crs));
    assertThat(problemSink, empty());

    // we do however still get a warning if the axis needs to be flipped for this weird CRS.
    // this is probably sensible, as it is still a non-identity transform operation (if we were
    // simply transposing the coordinates, then it'd be safe to ignore this too)
    sridSet.reproject(llPoint2, sridSet.get(weirdWgs84Crs));
    assertThat(problemSink, contains(
        GeometryProblems.get().reprojectionIgnoringDatumShift(SRIDSet.EPSG4326_LATLON, weirdWgs84Crs)
    ));
  }

  @Test(expected=UnknownSRIDException.class)
  public void reprojectFailsIfCrsUnknown() {
    sridSet.reproject(nztmPoint, 3);
  }

  @Test
  public void reprojectFailsIfReprojectionNotPossible() {
    Point notInNZ = longLatFactory.createPoint(new Coordinate(10, 10));
    assertThrows(GeometryReprojectionException.class, () -> sridSet.reproject(notInNZ, nztmId));
  }

  @Test
  public void changingTheDatumGivesNewSRID() {
    CoordinateReferenceSystem nztm = nzTransverseMercator();

    // this CRS has WKT that is copied from nztm, but the GEOGCS is swapped out for WGS84. NZTM should
    // be using NZGD2000.
    // the purpose of the test is to ensure that parsing this WKT does not give us something we still
    // consider to be NZTM, which it cannot be because of the difference in GEOGCS.
    CoordinateReferenceSystem nztmWrongDatum = crsFromWkt("EPSG2193-on-wgs84.wkt");

    assertNotEquals(sridSet.get(nztm), sridSet.get(nztmWrongDatum));
    assertTrue(sridSet.requiresReprojection(nztm, nztmWrongDatum));
  }

  @Test
  public void reprojectFailsIfReprojectedGeomIsInvalid() throws Exception {
    // This shape looks a little like a flag on a pole. The bottom of the pole is a point at
    // 165.1229670992294 -46.318004677055015
    Geometry geom = longLatReader.read("POLYGON ((165.14629469705895 -46.337645271124124,"
        + " 165.14880262533865 -46.31974453180639, 165.1229670992294 -46.318004677055015,"
        + " 165.13588463563147 -46.31887531839142, 165.133372725189 -46.33677499295321,"
        + " 165.14629469705895 -46.337645271124124))");
    // sanity check that this lat/long geom is valid
    assertTrue(geom.isValid());

    CoordinateReferenceSystem nztm = nzTransverseMercator();
    int targetSrid = sridSet.get(nztm);

    Geometry reprojected = sridSet.reproject(geom, targetSrid);
    // sanity check that it is actually invalid
    assertFalse(reprojected.isValid());
    // check that we didn't try to fix it, mode is off
    assertThat(fixInput.get(), nullValue());
    // geometry validation defaults to off, so we don't get any invalid geometry problems
    assertThat(problemSink, contains(
        // but there is a warning about reprojection with no datum shift params
        GeometryProblems.get().reprojectionIgnoringDatumShift(longLat, nzTransverseMercator())
    ));
  }

  @Test
  public void reprojectWillLogWhenInvalidGeometryIsFixed() throws Exception {
    // This shape looks a little like a flag on a pole. The bottom of the pole is a point at
    // 165.1229670992294 -46.318004677055015
    Geometry geom = longLatReader.read("POLYGON ((165.14629469705895 -46.337645271124124,"
        + " 165.14880262533865 -46.31974453180639, 165.1229670992294 -46.318004677055015,"
        + " 165.13588463563147 -46.31887531839142, 165.133372725189 -46.33677499295321,"
        + " 165.14629469705895 -46.337645271124124))");
    // sanity check that this lat/long geom is valid
    assertTrue(geom.isValid());

    CoordinateReferenceSystem nztm = nzTransverseMercator();
    int targetSrid = sridSet.get(nztm);
    // turn on warn geometry validation
    sridSet.setValidationPostReproject(GeometryValidation.WARN);
    // let's load the geometry fixer with a precanned good geom
    fixedResponse.set(nztmPoint);
    assertThat(sridSet.reproject(geom, targetSrid), sameInstance(nztmPoint));
    assertThat(fixInput.get(), notNullValue());
    // hasItem rather than contains: this reprojection also reports the datum shift warning (see
    // reprojectFailsIfReprojectedGeomIsInvalid)
    assertThat(problemSink, hasItem(
        Matchers.isProblem(Severity.INFO, GeometryProblems.class, "fixedInvalidPostReprojection")
    ));

    // now turn on error geometry validation and rinse/repeat
    sridSet.setValidationPostReproject(GeometryValidation.ERROR);
    problemSink.clear();
    fixInput.set(null);
    assertThat(sridSet.reproject(geom, targetSrid), sameInstance(nztmPoint));
    assertThat(fixInput.get(), notNullValue());
    assertThat(problemSink, hasItem(
        Matchers.isProblem(Severity.INFO, GeometryProblems.class, "fixedInvalidPostReprojection")
    ));
  }

  @Test
  public void canCreateCrsWithForceXY() {
    // EPSG:2193 defaults to YX (north east)
    assertThat(CRS.getAxisOrder(SRIDSet.epsgToCrs("EPSG:2193")), is(AxisOrder.NORTH_EAST));
    assertThat(CRS.getAxisOrder(SRIDSet.epsgToCrsWithForceXY("EPSG:2193")), is(AxisOrder.EAST_NORTH));

    // EPSG:4326 defaults to YX (north east)
    assertThat(CRS.getAxisOrder(SRIDSet.epsgToCrs("EPSG:4326")), is(AxisOrder.NORTH_EAST));
    assertThat(CRS.getAxisOrder(SRIDSet.epsgToCrsWithForceXY("EPSG:4326")), is(AxisOrder.EAST_NORTH));
  }

  @Test(timeout = 1000)
  public void returnsSridForKnownCrsWithoutDelay_GL601() {
    // we have two variants of the same CRS (only metadata differences)
    CoordinateReferenceSystem crs1 = crsFromWkt("EPSG32702.wkt");
    CoordinateReferenceSystem crs2 = crsFromWkt("EPSG32702-GCS_WGS.wkt");

    // now we add both to the SRIDSet and check they have the same SRID allocated
    int srid1 = sridSet.get(crs1);
    assertThat(sridSet.get(crs2), is(srid1));

    // now we fetch the SRID lots of times to check the speed performance. This is to prevent #601 from
    // regressing.
    for (int i = 0; i < 10000; i++) {
      assertThat(sridSet.get(crs1), is(srid1));
      assertThat(sridSet.get(crs2), is(srid1));
    }
  }

  /**
   * Exercises {@link SRIDSet} first serially (SRIDs are allocated sequentially from 1 and map back to the same
   * CRS), then from many threads at once: half the threads register mock CRSs while the other half look up
   * CRSs that have already been registered and check they get the same SRID back.
   */
  @Test(timeout = 1000)
  public void threadSafetyTest() throws Exception {
    final int numCrs = 10000;
    final int numThreads = 50;
    List<CoordinateReferenceSystem> mocks = IntStream.range(0, numCrs)
        .mapToObj(num -> Mockito.mock(CoordinateReferenceSystem.class))
        .collect(Collectors.toList());

    // in serial

    sridSet = new SRIDSet() {
      // we need to mess with the SRIDSet implementation so it can handle our mocked CRSs.
      @Override
      int crsHashCode(CoordinateReferenceSystem crs) {
        return crs.hashCode();
      }

    };
    int lastId = 0;
    for (CoordinateReferenceSystem crs : mocks) {
      int id = sridSet.get(crs);
      assertEquals(lastId + 1, id);
      lastId++;

      assertSame(crs, sridSet.get(id));
    }

    // Throwable rather than Exception, so failed assertions (AssertionError) in worker threads get recorded
    List<Throwable> failures = Collections.synchronizedList(new ArrayList<>());

    sridSet.clear();
    // run through all the CRSs, looking them up
    AtomicInteger numCrsAdded = new AtomicInteger(0);
    AtomicInteger numCrsGets = new AtomicInteger(0);

    List<Pair<CoordinateReferenceSystem, Integer>> added = Collections.synchronizedList(new ArrayList<>());
    List<Thread> threads = IntStream.range(0, numThreads).mapToObj(i -> new Thread(() -> wrapped(failures, () -> {
      // half of the threads add our mock CRSs from the list, the other half randomly select a CRS that we know has
      // been added and look it up and confirm it comes back with the SRID we saw when it was added
      if (i % 2 == 0) {
        for (int next = numCrsAdded.getAndIncrement(); next < mocks.size(); next = numCrsAdded.getAndIncrement()) {
          CoordinateReferenceSystem crs = mocks.get(next);
          int srid = sridSet.get(crs);
          numCrsGets.incrementAndGet();
          assertSame(crs, sridSet.get(srid));
          added.add(Pair.of(crs, srid));
        }
      } else {
        // keep asserting while we are adding
        while (numCrsAdded.get() < mocks.size()) {
          // read the size once; the list only ever grows, so any index below it stays valid
          int addedSoFar = added.size();
          if (addedSoFar > 0) {
            // get a crs we have definitely added already
            Pair<CoordinateReferenceSystem, Integer> pair =
                added.get(ThreadLocalRandom.current().nextInt(addedSoFar));
            Integer srid = sridSet.get(pair.getLeft());
            assertEquals(pair.getRight(), srid);
            numCrsGets.incrementAndGet();
          }
        }
      }
    }))).collect(Collectors.toList());

    threads.forEach(Thread::start);
    // let an InterruptedException propagate rather than swallowing it and asserting on incomplete results
    for (Thread thread : threads) {
      thread.join();
    }

    assertThat(failures, empty());
    assertThat(numCrsGets.get(), greaterThan(0));
    System.err.println(String.format(
        "Added %d crs and asserted %d gets in %d threads", added.size(), numCrsGets.get(), numThreads));

    // one more sanity check
    for (Pair<CoordinateReferenceSystem,Integer> pair : added) {
      int fetchedSrid = sridSet.get(pair.getLeft());
      CoordinateReferenceSystem fetchedCrs = sridSet.get(pair.getRight());

      assertSame(fetchedCrs, pair.getLeft());
      assertThat(fetchedSrid, equalTo(pair.getRight()));
    }
  }

  /**
   * A {@link Runnable} that may throw a checked exception.
   */
  @FunctionalInterface
  interface ExceptionRunnable {
    void run() throws Exception;
  }

  /**
   * Runs {@code callback}, recording anything it throws in {@code failures} before rethrowing it, so that
   * problems raised on a worker thread are visible to the test thread.
   *
   * @param failures where thrown exceptions and errors are recorded
   * @param callback the work to run
   */
  private void wrapped(List<Throwable> failures, ExceptionRunnable callback) {
    try {
      callback.run();
    } catch (Throwable t) {
      // must catch Throwable: JUnit assertion failures are AssertionErrors, not Exceptions
      failures.add(t);
      if (t instanceof Error) {
        throw (Error) t;
      }
      throw new RuntimeException(t);
    }
  }


}
