/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.usgs;

import static nz.org.riskscape.engine.Matchers.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;

import org.hamcrest.Matcher;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.Point;
import org.geotools.api.referencing.crs.CoordinateReferenceSystem;

import com.google.common.collect.ImmutableMap;

import hdf.hdf5lib.H5;
import nz.org.riskscape.engine.PluginProjectTest;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.ParamProblems;
import nz.org.riskscape.engine.coverage.TypedCoverage;
import nz.org.riskscape.engine.data.Bookmark;
import nz.org.riskscape.engine.data.ResolvedBookmark;
import nz.org.riskscape.engine.data.coverage.Dataset2D;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.rl.DefaultOperators;
import nz.org.riskscape.engine.rl.MathsFunctions;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.hdf5.H5Dataset;
import nz.org.riskscape.hdf5.H5DatasetPath;
import nz.org.riskscape.hdf5.H5File;
import nz.org.riskscape.hdf5.H5TestMarker;
import nz.org.riskscape.problem.ResultOrProblems;

@Category(H5TestMarker.class)
public class UsgsShakeMapTests extends PluginProjectTest {

  /**
   * Number of columns (x-indices) in the PGA grids of the test shakemap file.
   */
  static final int GRID_WIDTH = 303;

  /**
   * Number of rows (y-indices) in the PGA grids of the test shakemap file.
   */
  static final int GRID_HEIGHT = 217;

  /**
   * Points the HDF5 java bindings at a native library location, unless a location has already
   * been supplied via the {@link H5#H5PATH_PROPERTY_KEY} system property.
   */
  @BeforeClass
  public static void setupH5Library() {
    if (System.getProperty(H5.H5PATH_PROPERTY_KEY) == null) {
      System.setProperty(H5.H5PATH_PROPERTY_KEY, "/opt/hdf5-1.12.2/lib/libhdf5_java.so");
    }
  }

  TypedCoverage coverage;
  UsgsShakeMapResolver resolver;

  // NB: each test opens (and closes) its own H5File/H5Dataset handles with try-with-resources.
  // Don't open them as fields - JUnit creates a new test instance per test method, so
  // field-level handles would be leaked once per test.
  Path shakemapPath = Paths.get("src", "test", "resources", "shake_result_milford2019.hdf");
  H5DatasetPath pgaMeanPath = H5DatasetPath.parse("/arrays/imts/GREATER_OF_TWO_HORIZONTAL/PGA/mean");
  H5DatasetPath pgaStdDevPath = H5DatasetPath.parse("/arrays/imts/GREATER_OF_TWO_HORIZONTAL/PGA/std");

  // these are bounds of the HDF5 file's dataset
  final double xMax = 169.117;
  final double xMin = 166.6;
  final double yMax = -43.6333;
  final double yMin = -45.4333;
  // each index in the grid is 0.0083 x 0.0083 degrees square
  final double gridElementSize = 0.00833333;

  // some index to coverage point mappings that are easy to spot-check
  Map<Integer, Point> samplePoints = ImmutableMap.of(
      // grid should start in upper left, so check index zero maps to xMin, yMax
      index(0, 0), point(xMin, yMax),
      // check the last index maps to xMax, yMin
      index(GRID_WIDTH - 1, GRID_HEIGHT - 1), point(xMax, yMin),
      // bottom-left corner: x=0, y=216
      index(0, GRID_HEIGHT - 1), point(xMin, yMin),
      // top-right corner: x=302, y=0
      index(GRID_WIDTH - 1, 0), point(xMax, yMax),
      // off-centre, lower-right: x=250, y=200
      index(250, 200), point(xMin + (gridElementSize * 250), yMax - (gridElementSize * 200))
  );

  @Before
  public void setup() {
    project.getFunctionSet().insertFirst(DefaultOperators.INSTANCE);
    project.getFunctionSet().addAll(MathsFunctions.FUNCTIONS);
    resolver = new UsgsShakeMapResolver(engine);
  }

  /**
   * Converts a 2-D grid position into the index of the flattened HDF5 array, which holds the
   * grid row by row, i.e. y-index * grid-width + x-index.
   */
  private static int index(int x, int y) {
    return y * GRID_WIDTH + x;
  }

  /**
   * Resolves a shakemap bookmark, failing the test if the resolver does not handle it.
   */
  private ResolvedBookmark resolve(Bookmark bookmark) {
    Optional<ResolvedBookmark> resolved = resolver.resolve(bookmark, project.newBindingContext());
    assertTrue("usgs-shakemap bookmark did not resolve", resolved.isPresent());
    return resolved.get();
  }

  /**
   * Creates a point at the given x/y in the project's default CRS.
   */
  private Point point(double x, double y) {
    CoordinateReferenceSystem crs = project.getDefaultCrs();
    GeometryFactory factory = project.getSridSet().getGeometryFactory(crs);
    return factory.createPoint(new Coordinate(x, y));
  }

  /**
   * Evaluates {@link #coverage} at the given point, failing with a descriptive message (rather
   * than an unboxing NullPointerException) when the coverage has no value there.
   */
  private double evaluate(Point p) {
    Object value = coverage.evaluate(p);
    assertNotNull("coverage returned no value for " + p, value);
    return (Double) value;
  }

  /*
   * checks the DatasetCoverage will correctly report whether a
   * lat/long co-ordinate is in bounds
   */
  @Test
  public void canResolvePointsWithinBounds() {
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of("percentile", Arrays.asList("75.0")));
    ResultOrProblems<TypedCoverage> dataOr = resolve(bookmark).getData(TypedCoverage.class);
    coverage = dataOr.get();

    Point[] pointsInBounds = new Point[] {
        // check all the 4 corners are within bounds
        point(xMin, yMin),
        point(xMax, yMax),
        point(xMin, yMax),
        point(xMax, yMin),
        // check some arbitrary points within bounds
        point(167.1, -44.8),
        point(168.0, -43.9),
        point(168.5, -45.1),
        point(169.0, -44.2)
    };

    for (int i = 0; i < pointsInBounds.length; i++) {
      Double evaluated = (Double) coverage.evaluate(pointsInBounds[i]);
      String errorStr = String.format("pointsInBounds[%d] evaluated null", i);
      assertNotNull(errorStr, evaluated);
    }

    Point[] pointsOutOfBounds = new Point[] {
        // check just outside each of the four edges
        point(xMax + 0.01, yMax),
        point(xMin - 0.01, yMax),
        point(xMax, yMax + 0.01),
        point(xMax, yMin - 0.01),
        // try flipping signedness of valid co-ordinates
        point(168.0, 43.9),
        point(-168.0, -43.9),
        // or valid co-ordinates that are specified the wrong way round
        point(yMin, xMin),
        point(yMax, xMax),
        // check min/max lat/long values just for kicks
        point(0.0, 0.0),
        point(180.0, 90.0),
        point(-180.0, -90.0)
    };

    for (int i = 0; i < pointsOutOfBounds.length; i++) {
      Double evaluated = (Double) coverage.evaluate(pointsOutOfBounds[i]);
      String errorStr = String.format("pointsOutOfBounds[%d] evaluated non-null", i);
      assertNull(errorStr, evaluated);
    }
  }

  @Test
  public void canReadUsgsDataset() {
    // sanity-checks we can read through the data in the expected order, i.e. we're
    // not doing anything dumb with the x,y ordering
    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset dataset = file.openDataset(pgaMeanPath)
    ) {
      // get the raw mean values from the HDF5 file
      double[] meanPgaValues = dataset.getElementsAsDouble();
      UsgsShakeMapReader reader = new UsgsShakeMapReader(file, pgaMeanPath.getPath());
      Dataset2D results = reader.getData().get();
      assertEquals(GRID_HEIGHT, results.getHeight());
      assertEquals(GRID_WIDTH, results.getWidth());
      assertThat(results.getType(), is(Types.FLOATING));

      // traverse the dataset in the order the data appears in the HDF5 file
      int index = 0;
      for (long y = 0; y < results.getHeight(); y++) {
        for (long x = 0; x < results.getWidth(); x++) {
          assertThat((Double) results.getValue(x, y), is(meanPgaValues[index++]));
        }
      }
    }
  }

  @Test
  public void canReadMultipleUsgsDatasets() {
    // sanity-check multiple datasets get rolled up into a struct
    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset meanDataset = file.openDataset(pgaMeanPath);
        H5Dataset stddevDataset = file.openDataset(pgaStdDevPath)
    ) {
      // get the raw mean and stddev values from the HDF5 file
      double[] meanPgaValues = meanDataset.getElementsAsDouble();
      double[] stddevPgaValues = stddevDataset.getElementsAsDouble();
      UsgsShakeMapReader reader = new UsgsShakeMapMultiDatasetReader(file,
          Arrays.asList(pgaMeanPath.getPath(), pgaStdDevPath.getPath()));
      Dataset2D results = reader.getData().get();
      assertEquals(GRID_HEIGHT, results.getHeight());
      assertEquals(GRID_WIDTH, results.getWidth());
      Struct expectedType = Struct.of("mean", Types.FLOATING, "std", Types.FLOATING);
      assertThat(results.getType(), is(expectedType));

      // traverse the dataset in the order the data appears in the HDF5 file
      int index = 0;
      for (long y = 0; y < results.getHeight(); y++) {
        for (long x = 0; x < results.getWidth(); x++) {
          assertThat(results.getValue(x, y), is(
              Tuple.ofValues(expectedType, meanPgaValues[index], stddevPgaValues[index])
          ));
          index++;
        }
      }
    }
  }

  /*
   * sanity checks the underlying calculation - 50th percentile should match mean
   */
  @Test
  public void canCalculate50thPercentile() {
    // no percentile, so should default to 50th, which is also the mean
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(), Collections.emptyMap());
    coverage = resolve(bookmark).getData(TypedCoverage.class).get();

    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset meanDataset = file.openDataset(pgaMeanPath);
    ) {
      // get the raw values from the HDF5 file
      double[] mean = meanDataset.getElementsAsDouble();

      for (Map.Entry<Integer, Point> entry : samplePoints.entrySet()) {
        int index = entry.getKey();
        Point point = entry.getValue();
        assertThat(evaluate(point), isPga(mean[index]));
      }
    }
  }

  /*
   * sanity checks the underlying percentile calculation for something other than 50
   */
  @Test
  public void canCalculateDifferentPercentile() {
    // here we use percentile corresponding to mean + 1 stdDev
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of("percentile", Arrays.asList("84.13447")));
    coverage = resolve(bookmark).getData(TypedCoverage.class).get();

    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset meanDataset = file.openDataset(pgaMeanPath);
        H5Dataset stdDevDataset = file.openDataset(pgaStdDevPath);
    ) {
      // get the raw values from the HDF5 file
      double[] mean = meanDataset.getElementsAsDouble();
      double[] stddev = stdDevDataset.getElementsAsDouble();

      for (Map.Entry<Integer, Point> entry : samplePoints.entrySet()) {
        int index = entry.getKey();
        Point point = entry.getValue();
        assertThat(evaluate(point), isPga(mean[index] + stddev[index]));
      }
    }
  }

  /**
   * Matches a coverage value against a raw HDF5 value. The coverage is expected to report
   * {@code exp()} of the raw (log-unit) value.
   */
  private Matcher<Double> isPga(double expectedValueLogUnits) {
    return closeTo(Math.exp(expectedValueLogUnits), 0.00000001);
  }

  /**
   * Asserts the coverage has a value at {@code point} equal to {@code exp(expectedValueLogUnits)}.
   */
  private void assertExpectedCoverage(TypedCoverage shakemapCoverage,
      double expectedValueLogUnits, Point point) {

    Object value = shakemapCoverage.evaluate(point);
    assertNotNull("coverage returned no value for " + point, value);
    double evaluated = (Double) value;
    double expected = Math.exp(expectedValueLogUnits);
    assertEquals("unexpected coverage value at " + point, expected, evaluated, 0.00000001);
  }

  /*
   * checks the DatasetCoverage maps a lat/long co-ordinate to the correct
   * position in the 2-D grid
   */
  @Test
  public void canResolvePointsToCorrectArrayIndex() {
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of("percentile", Arrays.asList("50.0")));
    ResultOrProblems<TypedCoverage> dataOr = resolve(bookmark).getData(TypedCoverage.class);
    coverage = dataOr.get();

    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset meanDataset = file.openDataset(pgaMeanPath);
    ) {
      // get the raw mean values from the HDF5 file
      double[] meanPgaValues = meanDataset.getElementsAsDouble();

      // a nudge small enough that the point should stay within the same grid cell
      double slush = gridElementSize / 3;

      // the assumption here is that the grid starts in the upper left.
      // So check index zero maps to xMin, yMax
      assertExpectedCoverage(coverage, meanPgaValues[0], point(xMin, yMax));
      assertExpectedCoverage(coverage, meanPgaValues[0], point(xMin + slush, yMax));
      assertExpectedCoverage(coverage, meanPgaValues[0], point(xMin, yMax - slush));
      assertExpectedCoverage(coverage, meanPgaValues[0], point(xMin + slush, yMax - slush));

      // check the last index maps to xMax, yMin
      int lastIndex = meanPgaValues.length - 1;
      assertExpectedCoverage(coverage, meanPgaValues[lastIndex], point(xMax, yMin));
      assertExpectedCoverage(coverage, meanPgaValues[lastIndex], point(xMax - slush, yMin));
      assertExpectedCoverage(coverage, meanPgaValues[lastIndex], point(xMax, yMin + slush));
      assertExpectedCoverage(coverage, meanPgaValues[lastIndex], point(xMax - slush, yMin + slush));

      // check index somewhere near the middle: x=200, y=100
      double xCoord = xMin + (gridElementSize * 200);
      double yCoord = yMax - (gridElementSize * 100);
      // round to 3 decimal places as too much precision can throw the
      // conversion off slightly
      xCoord = Math.round(xCoord * 1000d) / 1000d;
      yCoord = Math.round(yCoord * 1000d) / 1000d;
      int specificIndex = index(200, 100);  // 30500

      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord + slush, yCoord));
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord - slush));
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex],
          point(xCoord + slush, yCoord - slush));

      // bottom-left corner: x=0, y=216
      specificIndex = index(0, GRID_HEIGHT - 1);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xMin, yMin));

      // top-right corner: x=302, y=0
      specificIndex = index(GRID_WIDTH - 1, 0);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xMax, yMax));

      // check index right in the centre: x=151, y=108
      xCoord = xMin + (gridElementSize * 151.4);
      yCoord = yMax - (gridElementSize * 108.4);
      specificIndex = index(151, 108);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));

      // off-centre, lower-right: x=250, y=200
      xCoord = xMin + (gridElementSize * 250);
      yCoord = yMax - (gridElementSize * 200);
      specificIndex = index(250, 200);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));

      // off-centre, upper-left: x=100, y=75
      xCoord = xMin + (gridElementSize * 100.25);
      yCoord = yMax - (gridElementSize * 75.25);
      specificIndex = index(100, 75);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));

      // off-centre, lower-left: x=80, y=135
      xCoord = xMin + (gridElementSize * 80.3);
      yCoord = yMax - (gridElementSize * 135.3);
      specificIndex = index(80, 135);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));

      // off-centre, upper-right: x=250, y=50
      xCoord = xMin + (gridElementSize * 250.1);
      yCoord = yMax - (gridElementSize * 50.3);
      specificIndex = index(250, 50);
      assertExpectedCoverage(coverage, meanPgaValues[specificIndex], point(xCoord, yCoord));
    }
  }

  @Test
  public void canReadShakemapDatasetDirectly() {
    // check we can read the stddev values on their own
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of("dataset", Arrays.asList(pgaStdDevPath.getPath())));
    coverage = resolve(bookmark).getData(TypedCoverage.class).get();

    try (
        H5File file = new H5File(shakemapPath);
        H5Dataset dataset = file.openDataset(pgaStdDevPath);
    ) {
      // get the raw stddev values from the HDF5 file. These are read as-is, i.e. no exp()
      double[] datasetValues = dataset.getElementsAsDouble();

      for (Map.Entry<Integer, Point> entry : samplePoints.entrySet()) {
        int index = entry.getKey();
        Point point = entry.getValue();
        assertThat(evaluate(point), is(datasetValues[index]));
      }
    }
  }

  @Test
  public void cannotUseBothDatasetAndPercentileParams() {
    // 'dataset' and 'percentile' are mutually exclusive, so this should fail validation
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of(
            "dataset", Arrays.asList(pgaStdDevPath.getPath()),
            "percentile", Arrays.asList("75.0")
        ));
    ResolvedBookmark resolved = resolve(bookmark);
    assertThat(resolved.validate(), contains(
        equalTo(ParamProblems.get().mutuallyExclusive("percentile", "dataset"))
    ));
    assertTrue(resolved.hasValidationErrors());
  }

  @Test
  public void canReadDatasetsWithSimilarNames() {
    // it should be possible for user to read both PGA mean and MMI mean
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of(
            "dataset", Arrays.asList(
                "/arrays/imts/GREATER_OF_TWO_HORIZONTAL/PGA/mean",
                "/arrays/imts/GREATER_OF_TWO_HORIZONTAL/MMI/mean")
        ));
    coverage = resolve(bookmark).getData(TypedCoverage.class).get();
    Struct expected = Struct.of("mean", Types.FLOATING, "MMI_mean", Types.FLOATING);
    assertThat(coverage.getType(), is(expected));
  }

  @Test
  public void canGetSensibleErrorForBadDataset() {
    String badDataset = "/nope/bad/dataset/name/sorry";
    Bookmark bookmark = new Bookmark("id", "desc", "usgs-shakemap", shakemapPath.toUri(),
        ImmutableMap.of("dataset", Arrays.asList(badDataset)));
    assertThat(resolve(bookmark).getData(TypedCoverage.class), failedResult(hasAncestorProblem(
        is(GeneralProblems.get().noSuchObjectExists(badDataset, H5Dataset.class))
    )));
  }
}
