/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.function.geometry;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

import org.locationtech.jts.geom.Geometry;

import com.google.common.collect.Maps;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.SRIDSet;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.ExpensiveResource;
import nz.org.riskscape.engine.function.FunctionArgument;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.geo.GeometryUtils;
import nz.org.riskscape.engine.geo.IntersectionIndex;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.rl.RealizableFunction;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.types.Geom;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.RelationType;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Struct.StructBuilder;
import nz.org.riskscape.engine.types.Struct.StructMember;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.TokenTypes;
import nz.org.riskscape.rl.ast.Constant;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.FunctionCall.Argument;
import nz.org.riskscape.rl.ast.StructDeclaration;

/**
 * Calculates the intersection between a feature and all the other features in a vector layer.
 *
 * <p>The function takes these arguments:
 * <ul>
 *   <li>{@code feature} - a struct whose first geometry member is cut by the layer</li>
 *   <li>{@code layer} - a constant relation, i.e. the vector layer to cut the feature by. It must have a geometry
 *   member</li>
 *   <li>{@code merge_attributes} (optional) - a struct declaration of {@code new_name: 'layer_attribute'} pairs.
 *   Each named layer attribute is copied from the intersecting layer feature into the result, either replacing
 *   the feature's member of the same name or being appended after the feature's members</li>
 *   <li>{@code return_difference} (optional) - a constant boolean. When true, the parts of the feature that do not
 *   intersect the layer are also returned, with any merged layer attributes set to null</li>
 * </ul>
 *
 * <p>The realized result is a list of structs, one per cut piece. Each struct has the feature's members (with the
 * feature's geometry member holding the cut piece), plus any appended layer attributes.
 *
 * <p>This function must be realized before use; calling it unrealized throws an
 * {@link UnsupportedOperationException}.
 */
public class LayerIntersections implements RiskscapeFunction, RealizableFunction {

  /**
   * Stores memoized information about how to set a value in the result
   */
  @RequiredArgsConstructor
  private static class IndexInfo {
    // if true, the value comes from the indexed rhs tuples, not the lhs
    final boolean fromRhs;
    // index in either lhs or rhs's unsafeValues of our source value
    final int sourceIndex;
    // index in the result's unsafeValues for the value tracked by this IndexInfo object
    final int targetIndex;
  }

  protected static final int NO_SRID_SET = -1;

  @Getter
  private final ArgumentList arguments = ArgumentList.fromArray(
      new FunctionArgument("feature", Types.ANYTHING),
      // This should be relation type, but coercion is annoying me
      new FunctionArgument("layer", Types.ANYTHING),
      new FunctionArgument("merge_attributes", Nullable.ANYTHING),
      new FunctionArgument("return_difference", Nullable.BOOLEAN)
  );

  @Getter
  private final List<Type> argumentTypes = arguments.getArgumentTypes();

  @Getter
  private final Type returnType = RSList.create(Struct.EMPTY_STRUCT);

  @Override
  public Object call(List<Object> args) {
    throw new UnsupportedOperationException();
  }

  /**
   * Prepares an array of IndexInfo objects that encodes where to get each member's values from. Each IndexInfo
   * object's `targetIndex` is the index of the member it populates in the result, i.e. the value for
   * `indexInfo.targetIndex` ends up in `resultTuple.unsafeValues[indexInfo.targetIndex]`.
   *
   * This is used to simplify function execution to avoid doing lots of looping and searching, which only needs to be
   * done once.
   *
   * The result type keeps every lhs member at its original key and position (an rhs attribute with the same
   * destination name replaces the lhs value), followed by any remaining rhs attributes in declared order.
   *
   * @param mergeAttributes a map version of the mergeAttributes parameter, which maps result name to
   * source name. It is not modified.
   * @param builder a struct builder to append members to as the index is built.
   * @param lhs the type of the function's first argument
   * @param rhs the relation type of the function's second argument
   * @param nullable whether rhs members need to be nullable, which is true when the function returns the difference
   * @param lhsGeomMember the lhs geometry member - no IndexInfo is created for it, as we treat the lhs's geometry
   * specially
   * @return an array of IndexInfo objects that can be used when the function is evaluated
   */
  private IndexInfo[] setupLoop(Map<String, String> mergeAttributes, StructBuilder builder,
      Struct lhs, Struct rhs, boolean nullable, StructMember lhsGeomMember
  ) {

    // copy so we can remove entries as they are consumed, while retaining the declared order
    Map<String, String> remaining = Maps.newLinkedHashMap(mergeAttributes);
    List<IndexInfo> replacements = new LinkedList<>();

    // work through the lhs members first - an lhs member's value may be replaced by an rhs attribute of the same
    // (destination) name, but its key and position in the result stay the same
    int targetIndex = 0;
    for (StructMember originalMember : lhs.getMembers()) {
      String fromRhs = remaining.remove(originalMember.getKey());

      Type memberType;
      if (fromRhs != null) {
        // this is a replacement from the rhs.
        // NOTE(review): presumably getEntry returns null for an attribute the layer doesn't have, which would surface
        // as a NullPointerException here - confirm and report a proper problem instead
        StructMember rhsMember = rhs.getEntry(fromRhs);
        replacements.add(new IndexInfo(true, rhsMember.getIndex(), targetIndex));
        // rhs values are null in the difference result, so (like the appended rhs members below) the type must be
        // nullable when returning the difference
        memberType = Nullable.ifTrue(nullable, rhsMember.getType());
      } else {
        // this is original, but skip copying it if it's the geometry member - that's set with the cut geometry
        if (originalMember != lhsGeomMember) {
          replacements.add(new IndexInfo(false, originalMember.getIndex(), targetIndex));
        }
        // the geom type still gets added to the result type
        memberType = originalMember.getType();
      }

      builder.add(originalMember.getKey(), memberType);
      targetIndex++;
    }

    // add the remainder - this will be anything that wasn't replacing something from the lhs
    for (Entry<String, String> appended : remaining.entrySet()) {
      StructMember rhsMember = rhs.getEntry(appended.getValue());
      replacements.add(new IndexInfo(true, rhsMember.getIndex(), targetIndex++));
      builder.add(appended.getKey(), Nullable.ifTrue(nullable, rhsMember.getType()));
    }

    return replacements.toArray(new IndexInfo[0]);
  }

  @Override
  public ResultOrProblems<RiskscapeFunction> realize(RealizationContext context, FunctionCall functionCall,
      List<Type> givenTypes) {

    if (givenTypes.size() < 2) {
      return ResultOrProblems.failed(ArgsProblems.get().wrongNumber(2, givenTypes.size()));
    }

    return ProblemException.catching(() -> {
      Relation cutBy = arguments.evaluateConstant(context, functionCall, "layer", Relation.class,
          RelationType.WILD).getOrThrow();

      // the 3rd (optional) arg is a struct of attributes in the cutBy layer to merge into the return value,
      // where each member key is the new (destination) name, and member value is the cutBy layer attribute
      // to copy (as text). E.g. { area_name: 'name' }
      StructDeclaration mergeAttributesDecl = getMergeAttributesDeclaration(functionCall);
      Map<String, String> mergeAttributesFrom = createAttributeMap(mergeAttributesDecl);

      // check whether to return the geometry difference (i.e. pieces that don't intersect), as well
      // as the geometry intersection
      boolean returnDifference = isReturnDifference(context, functionCall);

      Struct lhs = givenTypes.get(0).find(Struct.class).orElseThrow(()
          -> new ProblemException(TypeProblems.get().mismatch(
              functionCall.getArguments().get(0).getExpression(),
              Struct.EMPTY_STRUCT,
              givenTypes.get(0)
            )
      ));

      Struct rhs = cutBy.getType();
      StructMember lhsGeomMember = findGeomMember(lhs, functionCall, 0);
      // fail realization early if the layer has no geometry member
      findGeomMember(rhs, functionCall, 1);

      StructBuilder builder = new StructBuilder();
      IndexInfo[] indexInfos = setupLoop(mergeAttributesFrom, builder, lhs, rhs, returnDifference, lhsGeomMember);

      final Struct resultTypeFinal = builder.build();
      StructMember resultGeomMember = findResultGeomMember(resultTypeFinal, lhsGeomMember, functionCall);
      RSList realizedReturnType = RSList.create(resultTypeFinal);

      SRIDSet sridSet = context.getProject().getSridSet();

      // NOTE: for something like this, we'd register the expensive resource with the cache, I think, rather than
      // building the thing in the cache and dropping expensive resource
      ExpensiveResource<IntersectionIndex> indexResource = new ExpensiveResource<>(
        context.getProblemSink(),
        "build index from " + functionCall.getArguments().get(1).toSource(),
        () -> {
          try {
            return IntersectionIndex.populateFromRelation(cutBy, lhsGeomMember.getType(), sridSet,
                IntersectionIndex.defaultOptions());
          } catch (ProblemException e) {
            // Not expected - findGeomMember has already checked that the layer has a geometry member
            throw new RiskscapeException(e.getProblems().get(0));
          }
        }
      );

      return new RiskscapeFunction() {

        @Override
        public Type getReturnType() {
          return realizedReturnType;
        }

        @Override
        public List<Type> getArgumentTypes() {
          return givenTypes;
        }

        @Override
        public Object call(List<Object> args) {
          Tuple lhsTuple = (Tuple) args.get(0);
          Geometry lhsGeometry = lhsTuple.fetch(lhsGeomMember);

          // lhs geometry is null, can't intersect
          if (lhsGeometry == null) {
            if (returnDifference) {
              return Collections.singletonList(createTuple(null, lhsTuple, null));
            } else {
              return Collections.emptyList();
            }
          }
          // If the lhs is not geometry collection we want to unpack any multi-geoms that are returned.
          // But not if the lhs is multi-geom itself. But we check num geometries as shapefiles will
          // always emit a multi geom, even when there is only one part to it.
          boolean unpackMultiGeoms = lhsGeometry.getNumGeometries() == 1;

          IntersectionIndex index = indexResource.get();

          List<Pair<Geometry, Tuple>> results;
          Optional<Geometry> difference = Optional.empty();
          if (returnDifference) {
            Pair<Optional<Geometry>, List<Pair<Geometry, Tuple>>> differenceAndIntersections =
                index.findDifferenceAndIntersections(lhsGeometry);
            results = differenceAndIntersections.getRight();
            difference = differenceAndIntersections.getLeft();
          } else {
            results = index.findIntersections(lhsGeometry);
          }
          List<Tuple> cuts = new ArrayList<>(results.size());

          for (Pair<Geometry, Tuple> result : results) {
            // If we started off with a multigeom being cut up, then we don't need to unpack the 'cut' multigeom
            // out into all its constituent pieces (it makes it look like we excessively cut up the geometry).
            // However, input of single lines/polygons can result in a multigeom intersection if the geometry
            // crosses the cut-by-layer multiple times (e.g. a road winding in and out of a region). In this
            // case, it makes sense to deal with each intersecting piece individually (and it keeps the geometry
            // types in the resulting shapefile consistent)
            addCuts(cuts, result.getLeft(), lhsTuple, result.getRight(), unpackMultiGeoms);
          }

          // add a result in for the difference (bits of geometry that fell outside the cut layer). This is only
          // present when returning the difference, and multi-geoms are treated the same as the pieces above
          if (difference.isPresent()) {
            addCuts(cuts, difference.get(), lhsTuple, null, unpackMultiGeoms);
          }

          return cuts;
        }

        /**
         * Appends result tuple(s) for the given piece of geometry to cuts - one per part when unpack is true,
         * otherwise a single tuple for the whole piece. rhsTuple may be null (for the difference).
         */
        private void addCuts(List<Tuple> cuts, Geometry geometry, Tuple lhsTuple, Tuple rhsTuple, boolean unpack) {
          if (unpack) {
            GeometryUtils.processPerPart(geometry, piece -> cuts.add(createTuple(piece, lhsTuple, rhsTuple)));
          } else {
            cuts.add(createTuple(geometry, lhsTuple, rhsTuple));
          }
        }

        /**
         * Forms a merged result tuple with the given geometry and lhs/rhs tuples. rhsTuple can be null, in which case
         * every value sourced from the rhs is left null. geometry can also be null (when the lhs has no geometry).
         */
        private Tuple createTuple(Geometry geometry, Tuple lhsTuple, Tuple rhsTuple) {
          Tuple tuple = new Tuple(resultTypeFinal);

          for (IndexInfo indexInfo : indexInfos) {
            Object toSet;
            if (indexInfo.fromRhs) {
              // rhs values can be null when we're building the difference
              toSet = rhsTuple == null ? null : rhsTuple.fetch(indexInfo.sourceIndex);
            } else {
              toSet = lhsTuple.fetch(indexInfo.sourceIndex);
            }

            tuple.set(indexInfo.targetIndex, toSet);
          }

          tuple.set(resultGeomMember, geometry);

          return tuple;
        }
      };
    });
  }

  /**
   * Evaluates the optional `return_difference` argument, defaulting to false when it is not given. The argument is
   * looked up by name (like `merge_attributes`) rather than by counting arguments, so it is honoured even when it is
   * given as a keyword argument without `merge_attributes`.
   */
  private boolean isReturnDifference(RealizationContext context, FunctionCall functionCall) throws ProblemException {
    if (!arguments.get("return_difference").getFunctionCallArgument(functionCall).isPresent()) {
      return false;
    }
    return arguments
        .evaluateConstant(context, functionCall, "return_difference", Boolean.class, Types.BOOLEAN)
        .getOrThrow();
  }

  /**
   * Helper for pulling out the first geom member from a struct
   */
  private StructMember findGeomMember(Struct struct, FunctionCall functionCall, int index) throws ProblemException {
    return struct.getMembers().stream()
      .filter(member -> member.getType().findAllowNull(Geom.class).isPresent())
      .findFirst().orElseThrow(() -> new ProblemException(
        Problems.foundWith(
            functionCall.getArguments().get(index).getExpression(),
            TypeProblems.get().structMustHaveMemberType(Types.GEOMETRY, struct)
            )
        )
      );
  }

  /**
   * Finds the member of the result type that the cut geometry gets written to. The result keeps every lhs member
   * under its original key, so this is the member with the lhs geometry member's key. Looking it up by key (rather
   * than taking the result's first geometry member) avoids picking an earlier lhs attribute that merge_attributes
   * replaced with a layer geometry.
   *
   * @throws ProblemException if that member is not geometry-typed, which can happen when merge_attributes replaces
   * the feature's geometry with a non-geometry layer attribute
   */
  private StructMember findResultGeomMember(Struct resultType, StructMember lhsGeomMember, FunctionCall functionCall)
      throws ProblemException {
    StructMember resultGeomMember = resultType.getEntry(lhsGeomMember.getKey());
    if (resultGeomMember == null || !resultGeomMember.getType().findAllowNull(Geom.class).isPresent()) {
      throw new ProblemException(
          Problems.foundWith(
              functionCall,
              TypeProblems.get().structMustHaveMemberType(Types.GEOMETRY, resultType)
          )
      );
    }
    return resultGeomMember;
  }

  /**
   * Helper for returning a {@link StructDeclaration} from function call's arguments, throwing if it's not as expected,
   * or returns an empty one if nothing was given
   */
  private StructDeclaration getMergeAttributesDeclaration(FunctionCall functionCall) throws ProblemException {
    Optional<Argument> arg = arguments.get("merge_attributes").getFunctionCallArgument(functionCall);
    if (arg.isPresent()) {
      Expression structDeclaration = arg.get().getExpression();
      return structDeclaration.isA(StructDeclaration.class).orElseThrow(() -> new ProblemException(
          TypeProblems.get().mismatch(structDeclaration, StructDeclaration.class, structDeclaration.getClass())
      ));
    }
    return new StructDeclaration(Collections.emptyList(), Optional.empty());
  }

  /**
   * Traverses the StructDeclaration, returning a map of strings that represent the list of `ident -> text` members from
   * the original declaration, throwing a ProblemException if something doesn't match what's expected.
   *
   * The returned map iterates in declaration order, so that attributes appended to the result appear in the order
   * the user declared them.
   */
  private Map<String, String> createAttributeMap(StructDeclaration decl) throws ProblemException {
    // linked (not hashed) so declaration order is preserved for the result type
    Map<String, String> map = Maps.newLinkedHashMap();

    for (StructDeclaration.Member member : decl.getMembers()) {

      Constant constant = member.getExpression().isA(Constant.class).orElseThrow(()
          -> new ProblemException(TypeProblems.get().mismatch(
              member.getExpression(),
              Constant.class,
              member.getExpression().getClass()
      )));

      if (constant.getToken().type != TokenTypes.STRING) {
        throw new ProblemException(TypeProblems.get().mismatch(constant, TokenTypes.STRING, constant.getToken().type));
      }

      map.put(member.getIdentifier().getValue(), constant.getToken().getValue());
    }

    return map;
  }

}
