/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.raster;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Transparency;
import java.awt.color.ColorSpace;
import java.awt.image.ComponentColorModel;
import java.awt.image.DataBuffer;
import java.awt.image.SampleModel;
import java.awt.image.WritableRaster;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import javax.media.jai.PlanarImage;
import javax.media.jai.RasterFactory;
import javax.media.jai.TiledImage;

import nz.org.riskscape.problem.ResultOrProblems;
import org.geotools.coverage.Category;
import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridCoverageFactory;
import org.geotools.coverage.grid.GridEnvelope2D;
import org.geotools.coverage.grid.GridGeometry2D;
import org.geotools.geometry.jts.Geometries;
import org.geotools.geometry.jts.ReferencedEnvelope;
import org.geotools.metadata.i18n.Vocabulary;
import org.geotools.metadata.i18n.VocabularyKeys;
import org.geotools.util.NumberRange;
import org.geotools.api.referencing.operation.TransformException;
import org.geotools.coverage.grid.GridCoordinates2D;
import org.geotools.referencing.CRS;
import org.geotools.referencing.CRS.AxisOrder;

import com.google.common.collect.ImmutableMap;

import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;

import it.geosolutions.jaiext.range.NoDataContainer;
import lombok.extern.slf4j.Slf4j;
import nz.org.riskscape.engine.RiskscapeException;

import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.grid.FeatureGrid;
import nz.org.riskscape.engine.grid.FeatureGridCell;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.relation.TupleIterator;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.types.Struct.StructMember;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;

/**
 * Utility for producing a {@link GridCoverage2D} from a {@link Relation}, adapted from some geotools code in an
 * unsupported module.
 *
 * Support for integer encoding has been removed in favour of floating point only.  Note that only the 4-byte wide
 * float type is supported here.  Values may be supplied as 8-byte doubles, but they are narrowed to a 4-byte float
 * when written to the raster, so some loss of precision is possible when rasterizing.
 *
 * Usage is either {@link #convert(Relation, RealizedExpression, double, ReferencedEnvelope, String)}, or
 * {@link #initialize(ReferencedEnvelope, double, PixelStrategy)} followed by calls to
 * {@link #drawFeature(Number, Geometry)} and finally {@link #constructCoverage(CharSequence)}.
 *
 * Instances hold mutable drawing state and perform no synchronization.
 *
 * TODO This class should probably be split in to a couple of different classes, particularly one for drawing jts
 * shapes, and maybe another for iterating and producing the coverage
 */
@Slf4j
public class VectorToRaster {

  /**
   * Outcome of attempting to draw a single feature to the raster.
   */
  public enum DrawFeatureResult {
    SKIPPED_NO_VALUE_OR_GEOMETRY,
    OUT_OF_BOUNDS,
    DRAWN;
  }

  public interface LocalProblems extends ProblemFactory {

    /**
     * When some exception has bubbled up from {@link VectorToRaster#constructCoverage(CharSequence)}.
     *
     * It's unlikely that we will know the exact cause but most likely it will be due to there not being
     * enough memory available to build it.
     */
    Problem couldNotConstructCoverage(Problem... children);

    /**
     * The grid resolution (width * height, in pixels) is larger than {@link Integer#MAX_VALUE}.
     */
    Problem gridDimensionsTooBig(long width, long height, int max);
  }
  public static final LocalProblems PROBLEMS = Problems.get(LocalProblems.class);

  /**
   * Returns the height of the bounds, taking the CRS axis order into account.
   *
   * @return the height of the raster bounds in CRS units
   */
  public static double getHeightCrsUnits(ReferencedEnvelope bounds) {
    // we need to take account of CRS axis ordering here.
    // We expect width/height to be x/y but this is not how all CRSs roll.
    AxisOrder axisOrder = CRS.getAxisOrder(bounds.getCoordinateReferenceSystem());
    return axisOrder == AxisOrder.EAST_NORTH ? bounds.getHeight() : bounds.getWidth();
  }

  /**
   * Returns the width of the bounds, taking the CRS axis order into account.
   *
   * @return the width of the raster bounds in CRS units
   */
  public static double getWidthCrsUnits(ReferencedEnvelope bounds) {
    AxisOrder axisOrder = CRS.getAxisOrder(bounds.getCoordinateReferenceSystem());
    return axisOrder == AxisOrder.EAST_NORTH ? bounds.getWidth() : bounds.getHeight();
  }

  /**
   * Computes the pixel dimensions of a raster covering the given bounds at the given scale.
   *
   * @param bounds the spatial extent of the raster
   * @param scale  converts CRS units into pixels, i.e. 1 / grid-resolution
   * @return the dimensions that the raster will have, or a Problem if either dimension, or the total number of
   *         pixels, would exceed {@link Integer#MAX_VALUE}
   */
  public static ResultOrProblems<Dimension> getDimensions(ReferencedEnvelope bounds, double scale) {
    // apply a scale factor to the height/width to get the number of pixels in the grid resolution we want.
    // These are computed as longs so an oversized dimension is detected and reported accurately, rather than
    // being silently clamped to Integer.MAX_VALUE by an int cast.
    long width = (long) Math.ceil(scale * getWidthCrsUnits(bounds));
    long height = (long) Math.ceil(scale * getHeightCrsUnits(bounds));

    // each dimension is checked first, which guarantees the multiplication below cannot overflow a long
    if (width > Integer.MAX_VALUE || height > Integer.MAX_VALUE || width * height > Integer.MAX_VALUE) {
      return ResultOrProblems.failed(PROBLEMS.gridDimensionsTooBig(width, height, Integer.MAX_VALUE));
    }
    return ResultOrProblems.of(new Dimension((int) width, (int) height));
  }

  @FunctionalInterface
  private interface PixelSetter {
    void setPixel(TiledImage image, GridCoordinates2D gridPosition, float value);
  }

  /**
   * @param mergeFunc function that accepts (value, existing) and returns the new value
   * @return a pixel setter that applies the merge function to band 0
   */
  private static PixelSetter pixelSetter(BiFunction<Float, Float, Float> mergeFunc) {
    return pixelSetter(0, mergeFunc);
  }

  /**
   * @param band the image band to apply pixel setter to
   * @param mergeFunc function that accepts (value, existing) and returns the new value
   * @return a pixel setter that applies merge function to the band. The merge function is skipped (the new value
   *         is written as-is) when the existing sample is NaN.
   */
  private static PixelSetter pixelSetter(int band, BiFunction<Float, Float, Float> mergeFunc) {
    return (theImage, grid, value) -> {
      float existing = theImage.getSampleFloat(grid.x, grid.y, band);
      if (! Float.isNaN(existing)) {
        value = mergeFunc.apply(value, existing);
      }
      theImage.setSample(grid.x, grid.y, band, value);
    };
  }

  /**
   * @return a pixel setter that will apply all of the children setters which are expected to operate
   * on different raster bands.
   */
  private static PixelSetter compositeSetter(PixelSetter... children) {
    return (theImage, grid, value) -> {
      for (PixelSetter child: children) {
        child.setPixel(theImage, grid, value);
      }
    };
  }

  /**
   * Strategy that determines how new and existing pixel values are merged.
   *
   * Also determines how many bands are created in the image, but note that only the first band will be
   * included in the produced raster.
   */
  public enum PixelStrategy {
    OVERWRITE((theImage, grid, value) -> theImage.setSample(grid.x, grid.y, 0, value)),
    MIN(pixelSetter((newValue, existing) -> Math.min(newValue, existing))),
    MAX(pixelSetter((newValue, existing) -> Math.max(newValue, existing))),
    SUM(pixelSetter((newValue, existing) -> newValue + existing)),
    MEAN(
        new float[] {Float.NaN, 0},
        compositeSetter(
            // sum values in band 0
            pixelSetter(0, (newValue, existing) -> newValue + existing),
            // count values in band 1
            pixelSetter(1, (newValue, existing) -> 1 + existing)
        ),
        (samples) -> samples.get(0) / samples.get(1)
    );

    /**
     * Determines how pixel values are merged.
     */
    final PixelSetter pixelSetter;

    /**
     * An array of the default values that should be set to each pixel, one per band.
     *
     * The length of this array implies how many bands are created in the raster.
     *
     * Note, the first band should always be defaulted to NaN (the NoData value).
     */
    final float[] bandDefaults;

    /**
     * An optional post processor function to merge the values from each pixel band into the first
     * band (which is written out).
     *
     * This function will only be called for pixels that do not contain NaN in the first band.
     *
     * The function will be passed a list of floats, one for each band (implied from {@link #bandDefaults}).
     */
    final Function<List<Float>, Float> postProcessor;

    PixelStrategy(PixelSetter pixelSetter) {
      this.pixelSetter = pixelSetter;
      this.bandDefaults = new float[] {Float.NaN};
      this.postProcessor = null;
    }

    PixelStrategy(final float[] bandDefaults, PixelSetter pixelSetter,
        Function<List<Float>, Float> postProcessor) {
      this.pixelSetter = pixelSetter;
      this.postProcessor = postProcessor;
      this.bandDefaults = bandDefaults;
    }
  }

  // smallest/largest feature value drawn so far (NaN until the first in-bounds feature is drawn)
  private float minAttValue = Float.NaN;
  private float maxAttValue = Float.NaN;

  private ReferencedEnvelope extent;
  private AxisOrder axisOrder;
  private Geometry extentGeometry;
  private GridGeometry2D gridGeom;
  private PixelStrategy pixelStrategy;

  TiledImage image;
  // the NaN float bit pattern packed into an ARGB colour; constructCoverage converts it back into the float
  // no data value
  private final Color noDataColor = new Color(Float.floatToIntBits(Float.NaN), true);

  /**
   * Rasterizes every feature in the relation using the {@link PixelStrategy#OVERWRITE} strategy.
   *
   * @param relation   the features to draw; must have spatial metadata identifying the geometry member
   * @param expression evaluated against each feature to get the (numeric) value to draw
   * @param scale      converts CRS units into pixels, i.e. 1 / grid-resolution
   * @param bounds     the spatial extent of the raster to produce
   * @param covName    name given to the produced coverage
   * @return the rasterized coverage
   * @throws ProblemException if the raster dimensions are too big
   */
  public GridCoverage2D convert(
          Relation relation,
          RealizedExpression expression,
          double scale,
          ReferencedEnvelope bounds,
          String covName) throws ProblemException {

    initialize(bounds, scale, PixelStrategy.OVERWRITE);
    drawAll(relation, expression);
    return constructCoverage(covName);
  }

  /**
   * Draws a single tuple, using the geometry member for the shape and the value expression for the value.
   * The expression result is cast to {@link Number}.
   */
  private DrawFeatureResult drawFeature(Tuple feature, RealizedExpression valueExpression, StructMember geomMember) {
    Geometry geometry = feature.fetch(geomMember);
    return drawFeature((Number)valueExpression.evaluate(feature), geometry);
  }

  /**
   * Draw the value to the raster for all pixels that intersect the geometry, but only if the feature intersects with
   * the raster bounds.
   *
   * @param value the feature value to draw; it is narrowed to a float
   * @param geometry the feature geometry
   * @return {@link DrawFeatureResult#DRAWN} if the feature was drawn,
   *         {@link DrawFeatureResult#SKIPPED_NO_VALUE_OR_GEOMETRY} if the value or geometry is null or the value is
   *         NaN, or {@link DrawFeatureResult#OUT_OF_BOUNDS} if the geometry is outside the initialized bounds.
   * @throws UnsupportedOperationException if the geometry is not a (multi) polygon, line string or point
   */
  public DrawFeatureResult drawFeature(Number value, Geometry geometry) {
    if (value == null || geometry == null || Float.isNaN(value.floatValue())) {
      return DrawFeatureResult.SKIPPED_NO_VALUE_OR_GEOMETRY;
    }
    if (!geometry.intersects(extentGeometry)) {
      return DrawFeatureResult.OUT_OF_BOUNDS;
    }

    float featureValue = value.floatValue();

    if (Float.isNaN(minAttValue)) {
      minAttValue = featureValue;
      maxAttValue = featureValue;
    } else if (Float.compare(featureValue, minAttValue) < 0) {
      minAttValue = featureValue;
    } else if (Float.compare(featureValue, maxAttValue) > 0) {
      maxAttValue = featureValue;
    }

    Geometries geomType = Geometries.get(geometry);

    switch (geomType) {
      case MULTIPOLYGON:
      case MULTILINESTRING:
      case MULTIPOINT:
        final int numGeom = geometry.getNumGeometries();
        for (int i = 0; i < numGeom; i++) {
          drawGeometry(geometry.getGeometryN(i), featureValue);
        }
        break;

      case POLYGON:
      case LINESTRING:
      case POINT:
        drawGeometry(geometry, featureValue);
        break;

      default:
        throw new UnsupportedOperationException(
                "Unsupported geometry type: " + geomType.getName());
    }
    return DrawFeatureResult.DRAWN;
  }

  /**
   * Construct a {@link GridCoverage2D} containing all the features that have been drawn to it.
   *
   * NOTE: for pixel strategies with a post processor (e.g. {@link PixelStrategy#MEAN}) the post processing is
   * applied to the image in place, so this should only be called once per {@link #initialize}.
   *
   * @param covName name given to the produced coverage
   * @return grid coverage
   * @throws RiskscapeException wrapping {@link LocalProblems#couldNotConstructCoverage} if anything goes wrong
   */
  public GridCoverage2D constructCoverage(CharSequence covName) {
    try {
      GridCoverageFactory gcf = new GridCoverageFactory();

      // set kludgy no-data
      float noDataValue = Float.intBitsToFloat(noDataColor.getRGB());
      Map<?, ?> properties = ImmutableMap.of(NoDataContainer.GC_NODATA, new NoDataContainer(noDataValue));

      Category category = new Category(
          Vocabulary.formatInternational(VocabularyKeys.NODATA),
          new Color[] {
              new Color(0, 0, 0, 0)
          },
          NumberRange.create(noDataValue, noDataValue)
      );

      GridSampleDimension[] bands = new GridSampleDimension[] {
          new GridSampleDimension("Rasterized", new Category[] {category}, null)
      };

      TiledImage imageToWrite = postProcessImageIfNecessary();

      return gcf.create(covName, imageToWrite, gridGeom, bands, null, properties);
    } catch (Throwable t) {
      // deliberately broad - most likely failure is running out of memory while building the coverage
      throw new RiskscapeException(PROBLEMS.couldNotConstructCoverage(Problems.caught(t)));
    }
  }

  /**
   * Post process the image if necessary
   *
   * Only required for pixel strategies that accumulate data into multiple raster bands. This then
   * requires a post process step to merge the data from multiple bands into the first band which gets
   * saved.
   *
   * @return the image unchanged if the strategy has no post processor, otherwise a single band (band 0) view of the
   *         image after band 0 has been overwritten in place with the post processed values
   */
  private TiledImage postProcessImageIfNecessary() {
    if (pixelStrategy.postProcessor == null) {
      // nothing to do
      return image;
    }
    final int numBands = pixelStrategy.bandDefaults.length;
    for (int yt = 0; yt < image.getNumYTiles(); yt++) {
      for (int xt = 0; xt < image.getNumXTiles(); xt++) {
        // work directly on the tile raster rather than going through the image, which would look up the tile
        // again for every sample
        WritableRaster r = image.getWritableTile(xt, yt);
        try {
          for (int y = r.getMinY(); y < (r.getMinY() + r.getHeight()); y++) {
            for (int x = r.getMinX(); x < (r.getMinX() + r.getWidth()); x++) {
              float f0 = r.getSampleFloat(x, y, 0);
              // NaN in band 0 means nothing was drawn to this pixel - leave it as no data
              if (! Float.isNaN(f0)) {
                List<Float> samples = new ArrayList<>(numBands);
                samples.add(f0);
                for (int band = 1; band < numBands; band++) {
                  samples.add(r.getSampleFloat(x, y, band));
                }
                r.setSample(x, y, 0, pixelStrategy.postProcessor.apply(samples));
              }
            }
          }
        } finally {
          image.releaseWritableTile(xt, yt);
        }
      }
    }
    // only the first band (now holding the merged values) is included in the produced coverage
    return image.getSubImage(
            new int[] {0},
            new ComponentColorModel(
                ColorSpace.getInstance(ColorSpace.CS_GRAY),
                false,
                false,
                Transparency.OPAQUE,
                DataBuffer.TYPE_FLOAT
            )
        );
  }

  /**
   * Draws every tuple from the relation to the (already initialized) image.
   */
  private void drawAll(Relation relation, RealizedExpression expression) {
    StructMember geometryStructMember = relation.getSpatialMetadata().get().getGeometryStructMember();
    try (TupleIterator fi = relation.iterator()) {
      log.info("Rasterizing {}...", relation);
      while (fi.hasNext()) {
        try {
          drawFeature(fi.next(), expression, geometryStructMember);
        } catch (RuntimeException e) {
          // let unchecked exceptions (e.g. a RiskscapeException carrying a Problem) propagate as-is rather than
          // burying them inside a generic wrapper
          throw e;
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      }
    }
  }

  /**
   * Initialize the raster to the given bounds and scale.
   *
   * @param bounds the spatial extent of the raster to produce. Must be in the CRS of the features that will be
   *               drawn to the raster
   * @param scale  Converts the bounds width/height (in CRS units) into the number of pixels, i.e. 1 / grid-resolution.
   *               e.g. a 50m grid uses a 1/50 scale, so for bounds 10km wide this is 10000 * 0.02 = 200 pixels wide.
   * @param newPixelStrategy determines how to set values to pixels that have already been set
   * @throws ProblemException if the resulting raster dimensions would be too big
   */
  public void initialize(ReferencedEnvelope bounds, double scale, PixelStrategy newPixelStrategy)
      throws ProblemException {
    // validate the dimensions before changing any state
    Dimension gridDim = getDimensions(bounds, scale).getOrThrow();
    this.pixelStrategy = newPixelStrategy;

    setBounds(bounds);
    this.axisOrder = CRS.getAxisOrder(bounds.getCoordinateReferenceSystem());

    final float[] bandDefaults = newPixelStrategy.bandDefaults;

    // TODO catch out of memory exception, return something more meaningful
    this.image = createImage(gridDim, bandDefaults.length);
    log.debug("Allocated image for rasterization {}", this.image);
    // we need to fill in the no data value (and any other band defaults)
    for (int yt = 0; yt < image.getNumYTiles(); yt++) {
      for (int xt = 0; xt < image.getNumXTiles(); xt++) {
        WritableRaster r = image.getWritableTile(xt, yt);
        try {
          for (int y = r.getMinY(); y < (r.getMinY() + r.getHeight()); y++) {
            for (int x = r.getMinX(); x < (r.getMinX() + r.getWidth()); x++) {
              for (int band = 0; band < bandDefaults.length; band++) {
                r.setSample(x, y, band, bandDefaults[band]);
              }
            }
          }
        } finally {
          image.releaseWritableTile(xt, yt);
        }
      }
    }
    gridGeom = new GridGeometry2D(new GridEnvelope2D(0, 0, gridDim.width, gridDim.height), extent);
  }

  /**
   * Records the raster extent, plus a JTS geometry of it for intersection tests.
   */
  private void setBounds(ReferencedEnvelope bounds) {
    extent = bounds;
    extentGeometry = (new GeometryFactory()).toGeometry(extent);
  }

  /**
   * Creates a float, pixel interleaved image with the given dimensions and number of bands.
   */
  private TiledImage createImage(Dimension gridDim, int bands) {
    SampleModel sm = RasterFactory.createPixelInterleavedSampleModel(
        DataBuffer.TYPE_FLOAT,
        gridDim.width,
        gridDim.height,
        bands
    );

    return new TiledImage(0, 0, gridDim.width, gridDim.height, 0, 0, sm, PlanarImage.createColorModel(sm));
  }

  /**
   * Sets the value on every pixel whose cell intersects the (single part) geometry, using the pixel strategy.
   */
  @SuppressWarnings("unchecked")
  private void drawGeometry(Geometry geometry, float featureValue) {
    if (Float.isNaN(featureValue)) {
      // NaN is the no data value so we've got nothing to do.
      return;
    }
    FeatureGrid featureGrid;
    try {
      featureGrid = new FeatureGrid(geometry, axisOrder, gridGeom);
    } catch (TransformException e) {
      throw new RiskscapeException(Problems.caught(e));
    }

    // Using the feature grid to carve up the feature in to grid cells trades off speed for accuracy.
    // This is only really a problem if the feature is very detailed as that slows down the cell intersection
    // test.
    Iterator<FeatureGridCell> cells = featureGrid.cellIterator();
    while (cells.hasNext()) {
      FeatureGridCell cell = cells.next();
      if (geometry.intersects(cell.getCellPolygon())) {
        GridCoordinates2D gridPosition = cell.getGridPosition();
        pixelStrategy.pixelSetter.setPixel(image, gridPosition, featureValue);
      }
    }

  }

}
