/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.output;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import lombok.AllArgsConstructor;
import nz.org.riskscape.engine.bind.BindingContext;
import nz.org.riskscape.engine.bind.Parameter;
import nz.org.riskscape.engine.data.coverage.GridTypedCoverage;
import org.geotools.api.referencing.crs.CoordinateReferenceSystem;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridEnvelope2D;
import org.geotools.gce.geotiff.GeoTiffWriter;
import org.geotools.geometry.jts.ReferencedEnvelope;
import org.locationtech.jts.geom.Geometry;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.GeometryProblems;
import nz.org.riskscape.engine.SRIDSet;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.ParameterField;
import nz.org.riskscape.engine.geo.GeometryUtils;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.problem.ProblemPlaceholder;
import nz.org.riskscape.engine.problem.SeverityLevel;
import nz.org.riskscape.engine.raster.VectorToRaster;
import nz.org.riskscape.engine.resource.CreateHandle;
import nz.org.riskscape.engine.rl.ExpressionRealizer;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.types.Geom;
import nz.org.riskscape.engine.types.Referenced;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Struct.StructMember;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.ProblemSink;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.PropertyAccess;

public class GeoTiffFormat extends BaseFormat {

  /**
   * Problems specific to the GeoTIFF output format.
   * <p>
   * NOTE(review): instances are obtained reflectively via {@link Problems#get(Class)}, so method names and
   * parameters presumably drive the problem messages - change them with care.
   */
  public interface LocalProblems extends ProblemFactory {

    /**
     * Tip for when the grid dimensions are invalid (attached as a child of the dimensions problem).
     */
    Problem dimensionsTip();

    /**
     * When features have been ignored because they are outside of the GeoTIFF bounds.
     */
    @SeverityLevel(Problem.Severity.WARNING)
    Problem skippedFeaturesOutOfBounds(int count, URI tiffLocation);

    /**
     * When features have not been written to the GeoTIFF because their value is either null or NaN
     * (presumably also when their geometry is missing - see
     * {@code VectorToRaster.DrawFeatureResult.SKIPPED_NO_VALUE_OR_GEOMETRY}).
     */
    @SeverityLevel(Problem.Severity.WARNING)
    Problem skippedFeaturesNullOrNan(int count, URI tiffLocation);
  }

  /**
   * Shared factory for this format's {@link LocalProblems}.
   */
  public static final LocalProblems PROBLEMS = Problems.get(LocalProblems.class);

  /**
   * Writer that draws each written tuple into a {@link VectorToRaster}, then encodes the accumulated
   * coverage as a GeoTIFF when closed.
   * <p>
   * Features that could not be drawn are counted as they are written and reported as warnings to the
   * {@link ProblemSink} once the GeoTIFF has been stored.
   */
  @RequiredArgsConstructor
  private static class Writer extends RiskscapeWriter {

    // NB: declaration order matters - @RequiredArgsConstructor builds the constructor from these final fields
    private final Function<Tuple, Number> valueExtractor;
    private final Function<Tuple, Geometry> geometryExtractor;
    private final VectorToRaster v2r;
    private final CreateHandle handle;
    private final ProblemSink problemSink;
    // counts of features that v2r declined to draw, reported as warnings in close()
    private int skippedOutOfBounds = 0;
    private int skippedNullOrNaN = 0;

    /**
     * Where the GeoTIFF was stored, or null until {@link #close()} has stored it.
     */
    @Getter
    private URI storedAt = null;

    /**
     * Draws the tuple's value and geometry into the raster, counting it if it was skipped.
     */
    @Override
    public void write(Tuple value) {
      VectorToRaster.DrawFeatureResult drawResult = v2r.drawFeature(
          valueExtractor.apply(value),
          geometryExtractor.apply(value)
      );
      if (drawResult == VectorToRaster.DrawFeatureResult.OUT_OF_BOUNDS) {
        ++skippedOutOfBounds;
      } else if (drawResult == VectorToRaster.DrawFeatureResult.SKIPPED_NO_VALUE_OR_GEOMETRY) {
        ++skippedNullOrNaN;
      }
    }

    /**
     * Encodes the accumulated raster as a GeoTIFF, stores it via the {@link CreateHandle}, then warns
     * about any features that were not drawn.
     *
     * @throws IOException if writing or storing the GeoTIFF fails
     */
    @Override
    public void close() throws IOException {
      GridCoverage2D coverage = v2r.constructCoverage("layer");

      GeoTiffWriter writer = new GeoTiffWriter(handle.getOutputStream());
      try {
        writer.write(coverage, null);
      } finally {
        // always release the GeoTIFF writer's resources, even if encoding fails part-way through
        writer.dispose();
      }
      storedAt = handle.store();

      // warn the user if features were not drawn
      if (skippedOutOfBounds > 0) {
        problemSink.accept(PROBLEMS.skippedFeaturesOutOfBounds(skippedOutOfBounds, storedAt));
      }
      if (skippedNullOrNaN > 0) {
        problemSink.accept(PROBLEMS.skippedFeaturesNullOrNan(skippedNullOrNaN, storedAt));
      }
    }

  }

  /**
   * User-supplied options for the GeoTIFF format.
   * <p>
   * NOTE(review): each {@link ParameterField} is presumably bound reflectively and exposed in kebab-case
   * (e.g. {@code gridResolution} -> {@code grid-resolution}, as getRequiredOptions() suggests), so field
   * names and order should not be changed casually.
   */
  public static class Options extends FormatOptions {

    /**
     * Pixel size. Converted to target-CRS units with GeometryUtils.toCrsUnits (presumably from metres -
     * confirm). When absent, it is inferred from the template's envelope height divided by its pixel height.
     */
    @ParameterField
    Optional<Double> gridResolution = Optional.empty();
    /**
     * Constant expression that must yield a geometry; its envelope, reprojected to the target CRS, is the
     * raster extent. When absent, the template's envelope is used.
     */
    @ParameterField
    Optional<Expression> bounds =  Optional.empty();
    /**
     * Numeric expression giving each feature's pixel value. When absent, the first numeric attribute of the
     * input (searching nested structs) is used.
     */
    @ParameterField
    Optional<Expression> value = Optional.empty();
    /**
     * Expression giving each feature's geometry. When absent, the first geometry attribute of the input
     * (searching nested structs) is used.
     */
    @ParameterField
    Optional<Expression> geometry = Optional.empty();
    /**
     * Strategy handed to VectorToRaster for computing pixel values (presumably how overlapping features are
     * combined). Defaults to MEAN.
     */
    @ParameterField
    VectorToRaster.PixelStrategy pixelStatistic = VectorToRaster.PixelStrategy.MEAN;
    /**
     * Optional existing grid coverage. When given, its CRS is always the output CRS, and it supplies the
     * bounds and grid resolution when those options are absent.
     */
    @ParameterField
    Optional<GridTypedCoverage> template = Optional.empty();
  }

  /**
   * Creates the GeoTIFF output format. The arguments are presumably the format id, the file extension
   * and the MIME type - see BaseFormat.
   */
  public GeoTiffFormat() {
    super("geotiff", "tif", "image/tiff");
  }

  /**
   * @return the GeoTIFF-specific {@link Options} class used to bind this format's writer parameters
   */
  @Override
  public Class<? extends FormatOptions> getWriterOptionsClass() {
    return GeoTiffFormat.Options.class;
  }

  /**
   * Supplies a constructor that turns GeoTIFF {@link Options} into a {@link Writer}.
   * <p>
   * The constructor fails with a "required" problem for {@code options} when no options are given, or
   * when they are not GeoTIFF {@link Options}. Any other problem comes from the {@link Builder}.
   */
  @Override
  public Optional<WriterConstructor> getWriterConstructor() {
    WriterConstructor constructor = (ctx, inputType, target, supplied) -> ProblemException.catching(() -> {
      Options geoTiffOptions = supplied
          .filter(Options.class::isInstance)
          .map(Options.class::cast)
          .orElseThrow(() -> new ProblemException(GeneralProblems.get().required("options")));

      return new Builder(
          geoTiffOptions,
          inputType,
          ctx.getExpressionRealizer(),
          ctx.getProject().getSridSet()
      ).build(target, ctx.getProject().getProblemSink());
    });
    return Optional.of(constructor);
  }

  /**
   * Binds the GeoTIFF options and then checks that they are usable for the given input type.
   * <p>
   * Validation is delegated to {@link Builder#validate()}, so it agrees with what the writer will later do.
   * Any resulting problems are wrapped via {@code Problems.foundWith(...)} for this format's options.
   * If binding itself reports problems, those are returned untouched.
   */
  @Override
  public ResultOrProblems<? extends FormatOptions> buildOptions(Map<String, List<?>> paramMap,
                                                                BindingContext context,
                                                                Struct input) {
    ResultOrProblems<? extends FormatOptions> bound = super.buildOptions(paramMap, context, input);
    if (!bound.hasProblems()) {
      Options geoTiffOptions = (Options) bound.get();
      Builder checker = new Builder(
          geoTiffOptions,
          input,
          context.getRealizationContext().getExpressionRealizer(),
          context.getProject().getSridSet());
      return ProblemException.catching(() -> {
        checker.validate();
        return geoTiffOptions;
      }).composeProblems(Problems.foundWith(ProblemPlaceholder.of(Format.FormatOptions.class, getId())));
    }
    // binding already reported problems - hand them back as-is
    return bound;
  }

  /**
   * Lists the options a user is expected to supply for this format.
   * <p>
   * NOTE(review): Builder can also take the bounds and grid resolution from a {@code template} coverage,
   * so these are not strictly required when a template is given - confirm how callers use this list.
   */
  @Override
  public List<String> getRequiredOptions(BindingContext context) {
    String[] required = {"grid-resolution", "bounds"};
    return Arrays.asList(required);
  }

  /**
   * Helper to make sense of the various GeoTIFF options, check they are valid, and turn them into a Writer.
   * <p>
   * Where each piece of raster configuration comes from:
   * <ul>
   *   <li>target CRS: the template's CRS if given, otherwise the CRS of the geometry expression</li>
   *   <li>bounds: the {@code bounds} option if given, otherwise the template's envelope</li>
   *   <li>grid resolution: the {@code grid-resolution} option if given, otherwise inferred from the template</li>
   * </ul>
   */
  @AllArgsConstructor
  private static class Builder {
    // NB: declaration order defines the @AllArgsConstructor parameter order
    private final Options opts;
    private final Struct input;
    private final ExpressionRealizer realizer;
    private final SRIDSet sridSet;

    /**
     * Realizes the expression that gives each feature's pixel value: the {@code value} option if given,
     * otherwise the first numeric attribute of the input (searching nested structs).
     *
     * @throws ProblemException if the expression cannot be realized, or its result is not numeric
     */
    public RealizedExpression getValueExpression() throws ProblemException {
      RealizedExpression valueExpr = opts.value
          .map(value -> realizer.realize(input, value))
          .orElseGet(() -> expressionToFirstOfType(false))
          .getOrThrow();
      if (!valueExpr.getResultType().isNumeric()) {
        throw new ProblemException(TypeProblems.get()
            .requiresOneOf("value", List.of(Types.FLOATING, Types.INTEGER), valueExpr.getResultType())
        );
      }
      return valueExpr;
    }

    /**
     * Realizes the expression that gives each feature's geometry: the {@code geometry} option if given,
     * otherwise the first geometry attribute of the input (searching nested structs).
     *
     * @throws ProblemException if the expression cannot be realized, or its result is not a geometry
     */
    public RealizedExpression getGeometryExpression() throws ProblemException {
      RealizedExpression geomExpr = opts.geometry
          .map(value -> realizer.realize(input, value))
          .orElseGet(() -> expressionToFirstOfType(true))
          .getOrThrow();
      if (geomExpr.getResultType().findAllowNull(Geom.class).isEmpty()) {
        throw new ProblemException(
            TypeProblems.get().mismatch("geometry", Types.GEOMETRY, geomExpr.getResultType())
        );
      }
      return geomExpr;
    }

    /**
     * @return the CRS of the output raster: the template's CRS if given, otherwise the CRS referenced by
     *         the geometry expression's type
     * @throws ProblemException if there is no template and the geometry type is not referenced
     */
    public CoordinateReferenceSystem getTargetCrs() throws ProblemException {
      if (opts.template.isPresent()) {
        return opts.template.get().getCoordinateReferenceSystem();
      }
      Type geomType = getGeometryExpression().getResultType();
      Referenced referenced = geomType.findAllowNull(Referenced.class)
          .orElseThrow(() -> new ProblemException(GeometryProblems.get().notReferenced(geomType)));
      return referenced.getCrs();
    }

    /**
     * @return the raster extent in the target CRS: the envelope of the (reprojected) {@code bounds} geometry
     *         if given, otherwise the template's envelope
     * @throws ProblemException if bounds is not a geometry (or evaluates to null), or neither bounds nor a
     *         template was given
     */
    public ReferencedEnvelope getBounds() throws ProblemException {
      if (opts.bounds.isPresent()) {
        CoordinateReferenceSystem targetCrs = getTargetCrs();
        Geometry boundsGeom = realizer.realizeConstant(opts.bounds.get())
            .flatMap(re -> {
              Type boundsType = re.getResultType();
              if (boundsType.findAllowNull(Geom.class).isEmpty()) {
                return ResultOrProblems.failed(TypeProblems.get().mismatch("bounds", Types.GEOMETRY, boundsType));
              }
              Geometry evaluated = (Geometry) re.evaluate(Tuple.EMPTY_TUPLE);
              if (evaluated == null) {
                // a nullable geometry type can evaluate to null, and reproject() isn't null-safe
                return ResultOrProblems.failed(TypeProblems.get().mismatch("bounds", Types.GEOMETRY, boundsType));
              }
              return ResultOrProblems.of(evaluated);
            })
            // reproject the bounds to match the referenced CRS
            .map(bounds -> sridSet.reproject(bounds, sridSet.get(targetCrs)))
            .getOrThrow();

        return new ReferencedEnvelope(boundsGeom.getEnvelopeInternal(), targetCrs);
      } else if (opts.template.isPresent()) {
        return opts.template.get().getEnvelope().get();
      } else {
        throw new ProblemException(GeneralProblems.required("bounds", Parameter.class));
      }
    }

    /**
     * @return the raster scale, i.e. 1 / grid resolution, with the resolution in target-CRS units
     * @throws ProblemException if neither grid-resolution nor a template was given
     */
    public double getScale() throws ProblemException {
      if (opts.gridResolution.isPresent()) {
        return 1D / GeometryUtils.toCrsUnits(opts.gridResolution.get(), getTargetCrs());
      } else if (opts.template.isPresent()) {
        // we infer the grid-resolution from the template GeoTIFF's bounds / pixels height.
        // This will only give us the same grid resolution if the template has square pixels
        // (it is possible to make rectangular pixels though)
        GridEnvelope2D range = opts.template.get().getCoverage().getGridGeometry().getGridRange2D();
        ReferencedEnvelope templateBounds = opts.template.get().getEnvelope().get();
        double gridResolution = VectorToRaster.getHeightCrsUnits(templateBounds) / range.height;
        // note that this value is already in the correct CRS units
        return 1D / gridResolution;
      } else {
        throw new ProblemException(GeneralProblems.required("grid-resolution", Parameter.class));
      }
    }

    /**
     * Checks the options are usable for the input type without building a writer.
     *
     * @throws ProblemException describing the first problem found
     */
    public void validate() throws ProblemException {
      // sanity-check expressions realize OK
      getValueExpression();
      getGeometryExpression();

      // sanity-check dimensions are OK
      checkDimensions(getBounds(), getScale());
    }

    /**
     * Checks that the bounds and scale give usable raster dimensions, attaching the dimensions tip to any
     * problem.
     */
    private void checkDimensions(ReferencedEnvelope bounds, double scale) throws ProblemException {
      VectorToRaster.getDimensions(bounds, scale)
          .map(d -> d, p -> p.withChildren(PROBLEMS.dimensionsTip()))
          .getOrThrow();
    }

    /**
     * Validates the options and builds a {@link Writer} that rasterizes tuples and stores the GeoTIFF via the
     * given handle.
     *
     * @throws ProblemException if the options are not valid (same problems as {@link #validate()})
     */
    public Writer build(CreateHandle handle, ProblemSink problemSink) throws ProblemException {
      // realize the expressions and compute bounds/scale once, checking them in the same order as
      // validate() so that the same problem is reported first
      RealizedExpression valueExpr = getValueExpression();
      RealizedExpression geomExpr = getGeometryExpression();
      ReferencedEnvelope bounds = getBounds();
      double scale = getScale();
      checkDimensions(bounds, scale);

      VectorToRaster v2r = new VectorToRaster();
      v2r.initialize(bounds, scale, opts.pixelStatistic);
      // make sure we reproject to match the target CRS for the output
      int srid = sridSet.get(getTargetCrs());
      return new Writer(
          (t) -> (Number) valueExpr.evaluate(t),
          (t) -> {
            // NB: reproject() isn't null-safe
            Geometry geom = (Geometry) geomExpr.evaluate(t);
            if (geom == null) {
              return geom;
            }
            return sridSet.reproject(geom, srid);
          },
          v2r,
          handle,
          problemSink
      );
    }

    /**
     * Builds an expression to access the first found attribute of struct of the desired type.
     * <p>
     * Desired type is determined by isGeometry (Types.GEOMETRY when set, any numeric type otherwise).
     */
    private ResultOrProblems<RealizedExpression> expressionToFirstOfType(boolean isGeometry) {
      List<String> members = findSegmentsToFirstMemberOfType(input, isGeometry);
      if (members.isEmpty()) {
        return ResultOrProblems.failed(TypeProblems.get()
            .structMustHaveMemberType(isGeometry ? Types.GEOMETRY : Types.FLOATING, input));
      }
      return realizer.realize(input, PropertyAccess.of(members));
    }

    /**
     * Depth-first search of {@code struct} (descending into nested structs) for the first member of the
     * desired type.
     *
     * @return the member names leading to the match, outermost first, or an empty list when there is no
     *         match. A non-empty result is mutable, because the recursion prepends parent names to it.
     */
    private List<String> findSegmentsToFirstMemberOfType(Struct struct, boolean isGeometry) {
      for (StructMember member : struct.getMembers()) {
        boolean hasType = isGeometry
            ? member.getType().findAllowNull(Geom.class).isPresent()
            : member.getType().isNumeric();
        if (hasType) {
          List<String> found = new ArrayList<>();
          found.add(member.getKey());
          return found;
        }
        Optional<Struct> structMember = member.getType().findAllowNull(Struct.class);
        if (structMember.isPresent()) {
          List<String> nested = findSegmentsToFirstMemberOfType(structMember.get(), isGeometry);
          if (!nested.isEmpty()) {
            nested.add(0, member.getKey());
            return nested;
          }
        }
      }
      return Collections.emptyList();
    }
  }
}
