/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.tiff;

import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.RenderedImage;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.BitSet;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import javax.imageio.ImageReadParam;
import javax.imageio.ImageReader;
import javax.imageio.spi.ImageReaderSpi;

import org.geotools.gce.geotiff.GeoTiffReader;

import it.geosolutions.imageio.plugins.tiff.BaselineTIFFTagSet;
import it.geosolutions.imageio.plugins.tiff.TIFFField;
import it.geosolutions.imageioimpl.plugins.tiff.TIFFImageReader;
import lombok.extern.slf4j.Slf4j;

/**
 * Extends the default {@link TIFFImageReader} to add support for querying empty tiles, so that sampling operations
 * can avoid initializing and indexing tiles (float[]s) full of zeroes.
 *
 * Some hazard rasters, particularly water based, will be very fine resolution but very sparse.  The raw amount of RAM
 * required to represent these images is huge, but small on disk.  Being able to skip generating image tiles for these
 * empty tiles makes a big difference to performance.
 *
 * Can only be used with a single image from a TIFF, which *should* be fine as I can't see any code in the
 * {@link GeoTiffReader} that would do anything but decide on a single image from the container to use.  See
 * {@link #read(int, ImageReadParam)} for information on why it is this way.
 *
 * @see <a href="https://gdal.org/en/stable/drivers/raster/gtiff.html#sparse-files">GDAL GTiff driver: sparse files</a>
 */
@Slf4j
public class SparseTIFFImageReader extends TIFFImageReader {

  /**
   * Maximum number of candidate (smallest) tiles that get decoded while looking for an exemplar no-data tile before
   * the scan gives up.
   */
  private static final int MAX_FIND_NODATA_TILE_ATTEMPTS = 3;

  /**
   * Only consider tiles whose on-disk size is less than or equal to this many bytes when scanning for no-data
   */
  private static final int MAX_MIN_TILE_SIZE = 1024 * 10;

  /**
   * Initialize and return the {@link SparseTIFFImageReader} that is being used by the given coverage so that it
   * is correctly setup with the right image index from the TIFF.
   *
   * @param coverage the coverage whose rendered image is expected to be backed by a {@link SparseTIFFImageReader}
   * @return the initialized reader
   * @throws IllegalArgumentException if the coverage's image reader is not a {@link SparseTIFFImageReader}
   * @throws AssertionError if sampling a pixel did not cause the image index to be recorded
   */
  public static SparseTIFFImageReader initialize(SparseTiffCoverage coverage) {
    RenderedImage renderedImage = coverage.getRenderedImage();

    ImageReader reader = (ImageReader) renderedImage.getProperty("JAI.ImageReader");

    if (reader instanceof SparseTIFFImageReader tiffReader) {
      // sampling a pixel is expected to call through to our read method, which is where the image index is recorded.
      // initNoDataInfo verifies that this actually happened - if it didn't, it might be a sign that the call chain for
      // GridCoverage2D#evaluate through to this class hasn't worked as expected
      tiffReader.initNoDataInfo(() -> coverage.sampleAPixelForInit());
      return tiffReader;
    } else {
      throw new IllegalArgumentException("Wrong ImageReader class, can not initialize " + reader);
    }
  }

  // records which tiles are all no-data, indexed row-major (tileY * tilesAcross + tileX)
  private final BitSet noDataTiles = new BitSet();

  // remembers the index of the image we are optimized for.  Attempts to use any image other than this are going to
  // fail.
  private int firstImageIndex = -1;

  // track whether the initNoDataInfo is being called - used to alter the behaviour of the read method.
  private boolean runningInit = false;

  public SparseTIFFImageReader(ImageReaderSpi originatingProvider) {
    super(originatingProvider);
  }

  /**
   * Unit-Test friendly method for setting the imageIndex when we are not being used by a coverage
   */
  void initialize(int imageIndex) {
    initNoDataInfo(() -> {
      firstImageIndex = imageIndex;
    });
  }

  /**
   * Runs the callback with {@code runningInit} set, so that the first {@link #read(int, ImageReadParam)} it triggers
   * records the image index, then works out which tiles are empty.
   *
   * @param callback expected to cause {@code firstImageIndex} to be set
   * @throws AssertionError if the callback did not cause the image index to be recorded
   * @throws UnsupportedOperationException if the image uses a planar configuration
   * @throws UncheckedIOException if reading tile metadata or tile data fails
   */
  private void initNoDataInfo(Runnable callback) {
    this.runningInit = true;
    try {
      callback.run();
    } finally {
      runningInit = false;
    }

    // check this before touching any image-indexed methods, which would otherwise fail with a less helpful
    // index-out-of-bounds error
    if (firstImageIndex == -1) {
      throw new AssertionError("Image index was not set during initialization!");
    }

    // not sure how we'd support this, need to see some examples of it first.  Removing this would require
    // that we change isEmptyTile to support looking in all of the bands for empty tile information.  Possible, but
    // I haven't seen any planar geotiffs (yet) to understand and test it.  Maybe we can generate some using gdal and
    // test it?
    if (planarConfiguration == BaselineTIFFTagSet.PLANAR_CONFIGURATION_PLANAR) {
      throw new UnsupportedOperationException("SparseTiff NotSupported with planar images");
    }

    try {
      recordNoDataTiles();
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /**
   * @return true if the given tile is empty.  Querying pixels from this tile will always result in
   * no_data being returned
   */
  public boolean isEmptyTile(int tileX, int tileY) {
    int tileIndex = tileY * tilesAcross + tileX;

    return noDataTiles.get(tileIndex);
  }

  /**
   * @return true if this image reader has any knowledge of any empty (no-data) tiles
   */
  public boolean hasEmptyTiles() {
    return !noDataTiles.isEmpty();
  }

  /**
   * Overrides the default method to aid in our kludge for supporting the isEmptyTile method *without* an image
   * index, e.g. isEmptyTile(imageIndex, tileX, tileY) vs isEmptyTile(tileX, tileY).
   *
   * As far as I can tell, there's no API way to access the index of the image that a coverage is actually sampling.
   * The JAI Image API doesn't surface it, and neither does the GridCoverage2D API.  That left me in a bind when
   * implementing this:  Do I use reflection to access the index from the JAI Image class, or monkey patch the very
   * large GeoTiffReader#read method to store the index that was picked?  In the end I decided neither.  In practice,
   * it looks like it's always going to be one image in a TIFF that gets used, so this mild kludge seemed the lesser
   * of all the evils.
   *
   * @throws IllegalStateException if called with a different image index to the one this reader was initialized with
   */
  @Override
  public BufferedImage read(int imageIndex, ImageReadParam param) throws IOException {

    if (firstImageIndex != imageIndex) {
      if (firstImageIndex == -1) {
        if (runningInit) {
          // on first run, we record the imageIndex for use with isEmptyTile later and store our offset info
          this.firstImageIndex = imageIndex;
        }
        // otherwise this reader is being used but hasn't been initialized by the SparseTiffCoverage - no mind, the
        // extra isEmptyTile method isn't going to be used (and would do nothing if it were)
      } else {
        // safety check:  If for some reason we see a call to read with a different imageIndex to the one we've been
        // initialized with, we must bail. We can't support multiple imageIndexes being used, otherwise isEmptyTile is
        // going to be wrong
        throw new IllegalStateException("Saw wrong image index - " + imageIndex + " is not " +  firstImageIndex);
      }
    }

    return super.read(imageIndex, param);
  }

  /**
   * Attempts to record which tiles in this tiff are all no-data so that we can skip constructing tiles for these when
   * using the {@link SparseTiffCoverage}
   */
  private void recordNoDataTiles() throws IOException {
    long[] byteCounts = getByteCounts();

    // no tile/strip byte counts in the metadata, give up
    if (byteCounts.length == 0) {
      return;
    }

    // isEmptyTile depends on the tile grid dimensions, so make sure they are set whichever detection path we take
    initTileGrid();

    // first look and see if this is a sparse tiff - a sparse tiff records empty (0 byte) tiles.  We know these are all
    // no-data by convention.
    for (int idx = 0; idx < byteCounts.length; idx++) {
      if (byteCounts[idx] == 0) {
        noDataTiles.set(idx);
      }
    }

    if (!noDataTiles.isEmpty()) {
      log.debug("Sparse TIFF has {}/{} empty tiles", noDataTiles.cardinality(), byteCounts.length);
      return;
    }

    // not a sparse tiff, let's try this the slow(er) way.
    scanForNoDataTiles(byteCounts);
  }

  /**
   * Computes the super class's {@code tilesAcross} and {@code tilesDown} fields.  These don't get computed until a tile
   * is read and there's no simple way to set them, so this mirrors the super class's own initialization.  Relies on
   * the image metadata having already been loaded (see {@link #getByteCounts()}).
   */
  private void initTileGrid() {
    tilesAcross = (width + tileOrStripWidth - 1) / tileOrStripWidth;
    tilesDown = (height + tileOrStripHeight - 1) / tileOrStripHeight;
  }

  /**
   * Return the tile byte counts from metadata, falling back to strip byte counts and then the JPEG interchange
   * format length.
   *
   * @return the byte count of each tile/strip, or an empty array if the metadata has none of these fields
   */
  private long[] getByteCounts() throws IOException {

    // KLUDGE! The metadata that is returned from getMetadata does not include tile offsets. We have to resort
    // to extracting it via protected field that is initialized any time an image-indexed method is called
    getWidth(this.firstImageIndex);

    // This code was scraped from the base implementation - it's not written in a way that can be exposed
    TIFFField f = imageMetadata.getTIFFField(BaselineTIFFTagSet.TAG_TILE_BYTE_COUNTS);
    if (f == null) {
      f = imageMetadata.getTIFFField(BaselineTIFFTagSet.TAG_STRIP_BYTE_COUNTS);
    }
    if (f == null) {
      f = imageMetadata.getTIFFField(BaselineTIFFTagSet.TAG_JPEG_INTERCHANGE_FORMAT_LENGTH);
    }

    if (f == null) {
      return new long[] {};
    }

    return getAsLongs(f);
  }

  /**
   * @return every value of the given field, converted to a long
   */
  private long[] getAsLongs(TIFFField field) {
    return IntStream.range(0, Array.getLength(field.getData()))
            .mapToLong(field::getAsLong)
            .toArray();
  }

  /**
   * Iterates through the smallest tiles in the TIFF to see if they are all no-data, recording their indices in
   * noDataTiles to support empty tile lookups.
   *
   * Relies on the fact that a no-data tile will always be stored with the same bytes on disk (even when compressed) to
   * do a fast scan of tiles for no-data, based on the first example of a no-data tile we find.  Expects the tile grid
   * dimensions to have been computed already.
   */
  private void scanForNoDataTiles(long[] tileByteCounts) throws IOException {

    if (noData == null) {
      log.debug("TIFF is missing no-data metadata, can not apply no-data optimizations");
      return;
    }
    // NB if it were easy to look up, it would be good to bail if the file is not compressed with a known sane format
    // like LZW.  I *think* this method is fine, as it is fairly conservative, but hard to know without having full
    // knowledge of each and every compression format.

    // find the smallest tile size.  Assuming we are using LZW compression, a tile that's all the same value is always
    // going to be the smallest tile size.  There might also be runs of other values as well, but less likely.  Either
    // way, we are going to read candidate tiles and check they are all no data.
    // NB keep this as a long - narrowing before the size check could wrap a huge count to a small/negative value
    final long minLong = LongStream.of(tileByteCounts).min().orElse(Long.MAX_VALUE);

    // there's no tiles here or the tiles are too big for consideration.  Either way we can't proceed
    if (minLong > MAX_MIN_TILE_SIZE) {
      log.debug("Skipping scanning for no-data tiles, no tiles are less than {} bytes", MAX_MIN_TILE_SIZE);
      return;
    }

    // safe to narrow now that we know it is <= MAX_MIN_TILE_SIZE
    final int min = (int) minLong;

    // NB we could apply another heuristic here like "if the proportion of tiles that are min sized is < 5% don't
    // bother scanning - it's likely the image isn't sparse in this case and so not worth scanning

    log.debug("Scanning {} tiles looking for size {} for no-data", tileByteCounts.length, min);

    // find the raw bytes (as it is on disk) of the first tile that is all no-data
    // NB could also remember the index we found it at and skip those tiles - probably not worth the extra code?
    final byte[] noDataTileBytes = findNoDataTile(tileByteCounts, min);
    if (noDataTileBytes == null) {
      log.debug("No-data ({}) tile not found", noData);
      return;
    }

    byte[] tileBuffer = new byte[min];

    // run through the rest of the min tiles and see if their raw bytes match as well.
    // NB comparing the raw bytes is waaaaay faster than comparing image tiles.  A compressed no-data tile is likely
    // to be a few hundred bytes, whereas a decompressed tile is a) much bigger in RAM and b) in a structured format and
    // so not as easily comparable.  Arrays.equals is heavily optimized, so this goes fast.
    for (int tileIndex = 0; tileIndex < tileByteCounts.length; tileIndex++) {
      // only a min sized tile could match noDataTileBytes
      if (tileByteCounts[tileIndex] == min) {
        readRawTile(tileIndex, tileBuffer);

        if (Arrays.equals(noDataTileBytes, tileBuffer)) {
          noDataTiles.set(tileIndex);
        }
      }
    }

    log.info("Found {} no-data tiles after scanning {}", noDataTiles.cardinality(), stream);
  }

  /**
   * Iterates through the tile data looking for the first tile of min size that is all no-data.  This is used as an
   * exemplar for doing a much quicker (than this method) raw-byte comparison to find other no-data tiles.
   *
   * @return the raw (on disk) bytes of the first min sized tile found to be all no-data, or null if none was found
   * within {@link #MAX_FIND_NODATA_TILE_ATTEMPTS} decoded candidates
   */
  private byte[] findNoDataTile(long[] tileByteCounts, int min) throws IOException {
    final double noDataValue = noData.doubleValue();
    int attempts = 0;

    // look for the smallest tiles and decode them to see if they are no data
    for (int tileIndex = 0; tileIndex < tileByteCounts.length; tileIndex++) {
      if (tileByteCounts[tileIndex] != min) {
        continue;
      }

      // don't look at all the tiles - this method is relatively slow, so if we don't find a no-data tile quickly using
      // our smallest-tiles-are-probably-no-data heuristic then give up
      if (attempts >= MAX_FIND_NODATA_TILE_ATTEMPTS) {
        log.debug("Exceeded number of attempts ({}) to find no data tile, giving up", MAX_FIND_NODATA_TILE_ATTEMPTS);
        return null;
      }

      // tiles are indexed row-major, so both coordinates derive from the number of tiles across
      int tileX = tileIndex % tilesAcross;
      int tileY = tileIndex / tilesAcross;

      // this is the expensive bit - we need to render the tile to see if it's actually no-data.  Once we find one
      // that is, we can avoid rendering any of the other small tiles and instead compare them to this byte-for-byte
      BufferedImage tile = readTile(firstImageIndex, tileX, tileY);

      if (isAllNoData(tile.getData().getDataBuffer(), noDataValue)) {
        // all no-data, so the tile's raw bytes become the exemplar
        byte[] tileBytes = new byte[min];
        readRawTile(tileIndex, tileBytes);
        return tileBytes;
      }

      attempts++;
    }

    return null;
  }

  /**
   * @return true if every element of the buffer's default bank equals the no-data value.  NaN is treated as equal
   * to NaN, as a plain {@code ==} comparison would never match a NaN no-data value.
   */
  private static boolean isAllNoData(DataBuffer buffer, double noDataValue) {
    final boolean noDataIsNaN = Double.isNaN(noDataValue);

    for (int i = 0; i < buffer.getSize(); i++) {
      double value = buffer.getElemDouble(i);
      if (value != noDataValue && !(noDataIsNaN && Double.isNaN(value))) {
        return false;
      }
    }
    return true;
  }

  /**
   * Reads the raw (still compressed) bytes of the given tile from the image stream, filling the whole buffer.
   *
   * @throws IOException if the buffer can not be completely filled, e.g. the stream ends early
   */
  private void readRawTile(int tileIndex, byte[] buffer) throws IOException {
    long offset = getTileOrStripOffset(tileIndex);
    try {
      stream.seek(offset);
      // readFully, not read - read is allowed to return fewer bytes than requested before the end of the stream
      stream.readFully(buffer);
    } catch (IOException e) {
      throw new IOException(
          "Failed to read %d bytes from image stream %s - corrupt tiff?".formatted(buffer.length, stream), e);
    }
  }
}
