/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.wizard.model2.input;

import static nz.org.riskscape.wizard.model2.DslHelper.*;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import nz.org.riskscape.engine.GeometryProblems;
import nz.org.riskscape.engine.coverage.TypedCoverage;
import nz.org.riskscape.engine.data.ResolvedBookmark;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.query.TupleUtils;
import nz.org.riskscape.engine.query.TupleUtils.FindOption;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.types.Referenced;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Struct.StructMember;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.ProblemSink;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.problem.StandardCodes;
import nz.org.riskscape.rl.TokenTypes;
import nz.org.riskscape.wizard.QuestionSet;
import nz.org.riskscape.wizard.Survey;
import nz.org.riskscape.wizard.bld.IncrementalBuildState;
import nz.org.riskscape.wizard.bld.InvalidAnswerException;
import nz.org.riskscape.wizard.bld.PipelineChange;
import nz.org.riskscape.wizard.bld.change.AppendJoinChange;
import nz.org.riskscape.wizard.bld.loc.AtStepNamed;
import nz.org.riskscape.wizard.bld.loc.ChangeLocation;
import nz.org.riskscape.wizard.bld.loc.EndOfBranchWith;
import nz.org.riskscape.wizard.survey2.BasePhase;
import nz.org.riskscape.wizard.survey2.DefaultQuestionSet2;

/**
 * Wizard phase where the user picks the input data for their model: an exposure layer, a hazard layer (or several
 * hazard layers that get combined in to one), and optionally area and resource layers.
 * <p>
 * Each picked dataset gets its own pipeline branch that starts with a step named {@code <dataset>_input}, e.g.
 * {@code exposures_input}, and ends with a step named after the dataset, e.g. {@code exposures}.
 */
public class InputDataPhase extends BasePhase {

  private static final String EXPOSURES_DATASET = "exposures";
  private static final String AREAS_DATASET = "areas";
  private static final String RESOURCES_DATASET = "resources";
  private static final String HAZARDS_DATASET = "hazards";

  /**
   * Suffix appended to a dataset name to give the name of the pipeline step that first inputs that dataset.
   */
  private static final String INPUT_LAYER_SUFFIX = "_input";

  /**
   * Id of the question set that combines several individually-picked hazard layers in to a single hazard layer.
   */
  private static final String COMBINE_HAZARDS_QUESTION_SET_ID = "input-combine-hazards";

  /**
   * Name of the question where the user names one part of a multi-hazard layer. The same string is also used as an
   * annotation on that question, so the answers can be found again later.
   */
  private static final String HAZARD_ATTRIBUTE_QUESTION = "hazard-attribute";

  public static final String EXPOSURE_ATTRIBUTE = getAttributeName(EXPOSURES_DATASET);
  public static final String HAZARD_ATTRIBUTE = getAttributeName(HAZARDS_DATASET);
  public static final String RESOURCE_ATTRIBUTE = getAttributeName(RESOURCES_DATASET);

  /**
   * The 'main' branch of the pipeline that subsequent phases should hang off.
   * This differentiates the branch containing the exposures from the various
   * other input layers we add to the pipeline.
   */
  public static final ChangeLocation MAIN_BRANCH = EndOfBranchWith.stepNamed(EXPOSURES_DATASET);

  /**
   * Annotation for finding questions for picking multi hazard - there are many of the same question so they can't be
   * looked up by id.
   */
  private static final String ANNO_PICK_MULTI_HAZARD = "multi-hazard-pick";

  /**
   * Annotation that goes with ANNO_PICK_MULTI_HAZARD that gives the number of the hazard in the order that the user
   * gave them, starting from 1.
   */
  private static final String ANNO_MULTI_HAZARD_INDEX = "multi-hazard-index";

  private final Geoprocessing geoprocessing = new Geoprocessing();

  public InputDataPhase(Survey survey) {
    super(survey);
  }

  /**
   * We refer to the input-data layer in the plural form, but the pipeline attribute in the singular.
   * Generally the user will only ever see the singular form, i.e. when they select attributes or
   * view the final results. The singular form makes more sense in this case, as each row of
   * output represents a single exposure. The plural form makes more sense to us devs as we build the pipeline.
   *
   * @param dataset the (plural) dataset name, e.g. {@code exposures}
   * @return the dataset name with a single trailing {@code s} removed, e.g. {@code exposure}
   */
  public static String getAttributeName(String dataset) {
    // de-pluralize the string
    return dataset.replaceAll("s$", "");
  }

  /**
   * @param inputStep a step whose name is a dataset name followed by the input-layer suffix
   * @return the dataset name, i.e. the step name with the trailing {@code _input} suffix removed. Only the suffix is
   * removed - occurrences elsewhere in the name are left alone.
   */
  public static String getDatasetName(RealizedStep inputStep) {
    return stripInputSuffix(inputStep.getStepName());
  }

  /**
   * @return true once the user has combined several hazard layers in to a single multi-hazard layer
   */
  public static boolean isMultiHazard(IncrementalBuildState buildState) {
    return buildState.isQuestionSetAnswered(COMBINE_HAZARDS_QUESTION_SET_ID);
  }

  /**
   * @return {@code stepName} without a trailing input-layer suffix, or {@code stepName} unchanged if it doesn't have
   * one
   */
  private static String stripInputSuffix(String stepName) {
    return stepName.endsWith(INPUT_LAYER_SUFFIX)
        ? stepName.substring(0, stepName.length() - INPUT_LAYER_SUFFIX.length())
        : stepName;
  }

  /**
   * @return a mutable list of the input layer names, deduced from inspecting the pipeline start steps, which should
   * correspond to some well-known steps in the pipeline that represent the final result of adding and manipulating
   * the data that was input
   */
  private static List<String> getInputLayers(IncrementalBuildState buildState) {
    // look for any start steps that end in '_input'. Collected in to an explicit ArrayList because callers mutate the
    // result, and Collectors.toList() makes no guarantees about mutability
    return buildState.getRealizedPipeline().getStartSteps().stream()
        .map(RealizedStep::getStepName)
        .filter(name -> name.endsWith(INPUT_LAYER_SUFFIX))
        .map(InputDataPhase::stripInputSuffix)
        .collect(Collectors.toCollection(ArrayList::new));
  }

  /**
   * @return a (mutable) list of the datasets that need to be sampled, i.e. every input layer except the exposures
   */
  public static List<String> getDatasetsToSample(IncrementalBuildState buildState) {
    List<String> inputLayers = getInputLayers(buildState);

    // remove the exposure layer - that's the main branch of execution that does the sampling
    inputLayers.remove(EXPOSURES_DATASET);
    return inputLayers;
  }

  @Override
  public List<QuestionSet> getAvailableQuestionSets(IncrementalBuildState buildState) {
    List<QuestionSet> sets = new ArrayList<>();

    RealizedPipeline pipeline = buildState.getRealizedPipeline();
    Set<RealizedStep> startSteps = pipeline.getStartSteps();

    if (!isHasInputStepForDataset(startSteps, EXPOSURES_DATASET)) {
      sets.add(createInputDataQuestionSet(EXPOSURES_DATASET));
    }

    int nextHazardLayerIndex = getNextUnusedHazardLayerIndex(buildState);

    if (!isAnsweredAnyTypeOfHazardQuestion(buildState)) {
      // if nothing is answered, they can choose to go multi hazard or single hazard
      sets.add(createInputDataQuestionSet(HAZARDS_DATASET));
      sets.add(createInputDataQuestionSet(HAZARDS_DATASET, nextHazardLayerIndex));
    } else if (isBuildingMultiHazard(buildState)) {
      // but once multi hazard has begun, they can either finish adding layers or add more
      sets.add(createInputDataQuestionSet(HAZARDS_DATASET, nextHazardLayerIndex));
      sets.add(createCombineMultiHazardQuestionSet(buildState));
    }

    if (!isHasInputStepForDataset(startSteps, AREAS_DATASET)) {
      sets.add(createInputDataQuestionSet(AREAS_DATASET));
    }

    if (!isHasInputStepForDataset(startSteps, RESOURCES_DATASET)) {
      sets.add(createInputDataQuestionSet(RESOURCES_DATASET));
    }

    return sets;
  }

  @Override
  public boolean canSkip(IncrementalBuildState buildState) {
    // can be skipped once exposures answered and either:
    // 1 - a simple hazard layer chosen
    // 2 - many hazard layers chosen and then assembled in to a single layer
    Set<RealizedStep> startSteps = buildState.getRealizedPipeline().getStartSteps();
    boolean exposuresPicked = isHasInputStepForDataset(startSteps, EXPOSURES_DATASET);
    boolean hazardsPicked = isHasInputStepForDataset(startSteps, HAZARDS_DATASET)
        || isHazardsHaveBeenCombined(buildState);

    return exposuresPicked && hazardsPicked;
  }

  /**
   * Construct a predicate for use with askWhen that tests when the response to the last question set's question with
   * name `questionName` is equal to response.
   * TODO move this to the question set api?
   */
  private static <T> Predicate<IncrementalBuildState> isResponse(String questionName, T response) {
    return buildState ->
      response.equals(buildState.getAnswerTo(buildState.getQuestionSet().getId(), questionName, response.getClass()));
  }

  /**
   * @return true if there's a dataset with the given name being added in the set of startSteps
   */
  private boolean isHasInputStepForDataset(Set<RealizedStep> startSteps, String datasetName) {
    String stepName = getInputStepNameFor(datasetName);
    return startSteps.stream().anyMatch(rs -> rs.getStepName().equals(stepName));
  }

  /**
   * @return the input step name for a dataset with the given `datasetName`
   */
  private static String getInputStepNameFor(String datasetName) {
    return datasetName + INPUT_LAYER_SUFFIX;
  }

  /**
   * Renders a number as a plain, locale-independent decimal literal suitable for embedding in pipeline DSL text.
   * <p>
   * Formatting with {@code %f} is locale-sensitive (e.g. {@code 0,5} under a German default locale) and truncates to
   * six decimal places, either of which can produce a broken or wrong pipeline.
   */
  private static String toDslNumber(double value) {
    if (!Double.isFinite(value)) {
      // BigDecimal can't represent NaN/infinity - pass these through and leave it to the pipeline to reject them
      return Double.toString(value);
    }
    return BigDecimal.valueOf(value).toPlainString();
  }

  /**
   * Builds the question set that joins each individually-picked hazard layer on to a new hazard branch and then
   * combines them in to a single coverage named {@link #HAZARD_ATTRIBUTE}.
   *
   * @throws AssertionError if no multi-hazard layers have been picked yet
   */
  private QuestionSet createCombineMultiHazardQuestionSet(IncrementalBuildState buildState) {
    DefaultQuestionSet2 questionSet = new DefaultQuestionSet2(COMBINE_HAZARDS_QUESTION_SET_ID, this);

    List<RealizedStep> hazardEndSteps = buildState.getRealizedPipeline().getEndSteps()
        .stream()
        .filter(rs -> rs.getStepName().startsWith(HAZARDS_DATASET))
        .collect(Collectors.toList());

    boolean anyHazardLayersPicked = buildState.buildStateStream()
        .anyMatch(ibs -> ibs.getQuestion().hasAnnotation(ANNO_PICK_MULTI_HAZARD));

    if (!anyHazardLayersPicked) {
      throw new AssertionError("can not do multi hazard when no multi hazard layers have been picked");
    }

    questionSet
      .addHiddenQuestion("create-hazard-branch")
      .then(input ->
        // this new chain is the one the sampling phase will pick up
        PipelineChange.newChain("input(value: {}, name: '%s') as %s", HAZARD_ATTRIBUTE,
            getInputStepNameFor(HAZARDS_DATASET))
      );

    // join each coverage back to the mainline with a specific name
    for (RealizedStep endStep : hazardEndSteps) {
      String joinStepName =
          TokenTypes.quoteIdent("input-multi-hazard-join-" + endStep.getStepName());

      ChangeLocation otherSideLocation = new AtStepNamed(endStep.getStepName());

      questionSet
        .addHiddenQuestion("join-attribute-" + endStep.getStepName())
        .then(input ->
          answer -> new AppendJoinChange.Builder(answer)
            .append("join(on: true) as %s", joinStepName)
            .joins(
                otherSideLocation,
                "%s.rhs",
                joinStepName
            )
            .build()
        );
    }

    questionSet
      .addQuestion("grid-resolution", Double.class)
      .then((input, gridResolution) ->
        // this new chain is the one the sampling phase will pick up. The resolution is pre-rendered as text so the
        // generated pipeline doesn't depend on the JVM's default locale or lose precision
        PipelineChange.chainStep(
          "select({combine_coverages(%s, %s) as %s}) as %s",
          createCombineCoveragesStructExpression(input.getBuildState()),
          toDslNumber(gridResolution),
          HAZARD_ATTRIBUTE,
          HAZARDS_DATASET
        )
      );

    return questionSet;
  }

  /**
   * @return a struct expression, e.g. {@code {depth, speed}}, listing the (quoted) names the user gave to each part
   * of a multi-hazard layer
   */
  private String createCombineCoveragesStructExpression(IncrementalBuildState buildState) {
    return buildState.buildStateStream()
      // find all the responses where we named the hazard layer
      .filter(ibs -> ibs.getQuestion().hasAnnotation(HAZARD_ATTRIBUTE_QUESTION))
      .map(ibs -> {
        // this is the name the user gave for this part of the hazard, e.g depth
        // TODO as per A/Cs we need to allow the user to apply some sort of mapping to this, in case it's come from a
        // vector layer, i.e. pick a specific attribute or attributes
        String coverageMember = ibs.getAnswer().getValueAs(String.class);

        return TokenTypes.quoteIdent(coverageMember);
      })
      .collect(Collectors.joining(", ", "{", "}"));
  }

  /**
   * @return true if the user has answered questions meaning that we're building a multi hazard model, but hasn't yet
   * combined the hazard layers
   */
  private boolean isBuildingMultiHazard(IncrementalBuildState buildState) {
    boolean doingMultiHazard = buildState.getRealizedPipeline()
        .getStartSteps().stream()
        .anyMatch(rs -> rs.getStepName().startsWith(HAZARDS_DATASET + "_1"));

    return doingMultiHazard && !isHazardsHaveBeenCombined(buildState);
  }

  /**
   * @return true if all the hazard layers have been combined in to a single hazard layer
   */
  private boolean isHazardsHaveBeenCombined(IncrementalBuildState buildState) {
    return buildState.isQuestionSetAnswered(COMBINE_HAZARDS_QUESTION_SET_ID);
  }

  /**
   * Find the next index to use for a subsequent hazard layer with multi hazard, i.e. one more than the number of
   * questions annotated with a multi-hazard index so far.
   */
  private int getNextUnusedHazardLayerIndex(IncrementalBuildState buildState) {
    return (int) buildState
        .buildStateStream()
        .map(ibs -> ibs.getQuestion().getAnnotation(ANNO_MULTI_HAZARD_INDEX))
        .filter(Optional::isPresent)
        .count() + 1;
  }

  /**
   * @return true if there are any input steps that begin with hazard, then they've started doing *something* already
   */
  private boolean isAnsweredAnyTypeOfHazardQuestion(IncrementalBuildState buildState) {
    return buildState.getRealizedPipeline()
        .getStartSteps().stream()
        .anyMatch(rs -> rs.getStepName().startsWith(HAZARDS_DATASET));
  }

  /**
   * Lookup a UI label for a dataset with the given name, falling back to the dataset name itself.
   */
  private String getDatasetLabel(String datasetName, Locale locale) {
    return survey.getMessageSource()
        .getMessage("question.input.dataset." + datasetName, null, datasetName, locale);
  }

  /**
   * Creates the question set for picking a single (non-multi-hazard) dataset.
   */
  public QuestionSet createInputDataQuestionSet(String datasetName) {
    return createInputDataQuestionSet(datasetName, 0);
  }

  /**
   * Creates the question set for picking and preparing an input dataset.
   *
   * @param originalDatasetName the dataset being picked, e.g. {@code hazards}
   * @param multiHazardIndex 0 for a regular dataset, or the 1-based position of this layer within a multi-hazard
   * model, in which case the dataset is named e.g. {@code hazards_2}
   * @throws IllegalArgumentException if {@code multiHazardIndex} is negative
   */
  public QuestionSet createInputDataQuestionSet(String originalDatasetName, int multiHazardIndex) {
    if (multiHazardIndex < 0) {
      throw new IllegalArgumentException("multiHazardIndex must not be negative, was " + multiHazardIndex);
    }
    boolean isMultiHazardLayer = multiHazardIndex > 0;
    String datasetName = isMultiHazardLayer ? originalDatasetName + "_" + multiHazardIndex : originalDatasetName;
    String startStepName = getInputStepNameFor(datasetName);
    DefaultQuestionSet2 questionSet = new DefaultQuestionSet2("input", datasetName, this);

    // make sure we always append to this input layer branch by default,
    // even if we end up adding new branches for geoprocessing, etc
    questionSet.setDefaultLocation(EndOfBranchWith.stepNamed(startStepName));

    //
    // pick
    //
    questionSet.addQuestion("layer", ResolvedBookmark.class)
    .customizeQuestion(q -> {
      // some extra annotations on this question for multi hazard so we can find the answers later
      if (isMultiHazardLayer) {
        q = q.withAnnotations(
            ANNO_MULTI_HAZARD_INDEX, String.valueOf(multiHazardIndex),
            ANNO_PICK_MULTI_HAZARD, ""
        );
      }
      q = q.withI18nLookup((suffix, locale) ->
            survey.getMessageSource().getMessage(
              "question.input.layer." + suffix + "." + datasetName,
              getDatasetLabel(datasetName, locale),
              "Select which layer you want to use",
              locale)
      );

      return q;
    })
    .then((input, response) -> {
      String dataArg;
      String id = TokenTypes.quoteText(response.getId());
      if (response.isType(Relation.class)) {
        ProblemException.catching(() -> validateRelation(response.getData(Relation.class)))
          .orElseThrow((problems) -> new InvalidAnswerException(input, problems));
        dataArg = "relation: " + id;
      } else if (datasetName.equals(EXPOSURES_DATASET)) {
        // Always treat an exposure layer as a relation, even if its native data format isn't relational - rasters
        // support being iterated over pixel-by-pixel - others might also support this in the future - let's leave it
        // to the input step to figure this out
        ProblemException.catching(() -> validateRelation(response))
          .orElseThrow((problems) -> new InvalidAnswerException(input, problems));
        dataArg = "relation: " + id;
      } else {
        dataArg = "value: bookmark(" + id + ")";
      }

      return PipelineChange.newChain(
        "input(%s, name: %s) as %s",
        dataArg,
        TokenTypes.quoteText(getAttributeName(datasetName)),
        startStepName
      );
    });

    if (isMultiHazardLayer) {
      // user needs to name the dataset for multi hazard so it's not just hazard_1 and hazard_2
      questionSet
        .addQuestion(HAZARD_ATTRIBUTE_QUESTION, String.class)
        .customizeQuestion(q ->
          q.withAnnotations(HAZARD_ATTRIBUTE_QUESTION, datasetName)
        ).thenNoChange();
    }

    questionSet.addQuestion(Geoprocessing.GEOPROCESS_QUESTION, Boolean.class)
      .customizeQuestion(q -> q.askWhen(isDataSuitableForProcessing(datasetName)))
      .customizeQuestion(q ->
        q.withI18nLookup((suffix, locale) ->
          survey.getMessageSource().getMessage(
            "question.input.geoprocess." + suffix,
            getDatasetLabel(datasetName, locale)
          )
        )
      )
      .then(r -> PipelineChange.noChange());

    geoprocessing.addQuestions(questionSet, getAttributeName(datasetName));

    questionSet.addHiddenQuestion("finalize")
      .then(input -> {
        String attributeName = getAttributeName(datasetName);
        boolean isExposureLayer = datasetName.equals(EXPOSURES_DATASET);

        if (isExposureLayer) {
          return PipelineChange.chainStep("select({%s}) as %s", attributeName, datasetName);
        } else {
          // see if we should rename the input data for multi hazard attribute
          String newName = input.getBuildState()
              .getAnswer(questionSet.getId(), HAZARD_ATTRIBUTE_QUESTION)
              .map(a -> a.getValueAs(String.class))
              .orElse(attributeName); // otherwise dataset-based name is fine

          if (isDataSuitableForProcessing(datasetName).test(input.getBuildState())) {
            // non-exposure vector layers need to be indexed
            return PipelineChange.chainStep("group({to_coverage(%s) as %s}) as %s",
                attributeName, newName, datasetName);
          } else {
            return PipelineChange.chainStep("select({%s as %s}) as %s", attributeName, newName, datasetName);
          }
        }
      });

    return questionSet;
  }

  /**
   * @return a predicate that is true when the dataset's input step produces a struct attribute that contains a
   * geometry member, i.e. it is geometric relation data that our processing steps can work with
   */
  private static Predicate<IncrementalBuildState> isDataSuitableForProcessing(String datasetName) {
    return ibs -> {
      String attributeName = getAttributeName(datasetName);

      // TODO ideally we could locate the final step in the chain starting with this step - it might be more robust
      return ibs.getRealizedPipeline().getStep(getInputStepNameFor(datasetName))
          .flatMap(step -> getAttributeType(step.getProduces(), attributeName, Struct.class))
          // and it has to have a single geometry member
          .map(struct -> TupleUtils.findGeometryMember(struct, FindOption.OPTIONAL) != null)
          // otherwise it's not geometrical relation data and it's not suitable for our processing steps
          .orElse(false);
    };
  }

  /**
   * @return true if the dataset name is exactly the (single) hazards dataset name
   */
  public static boolean isHazardDataset(String datasetName) {
    return datasetName.equals(HAZARDS_DATASET);
  }

  /**
   * Sanity-checks the given bookmark to see if it can be used as a relation in our model. A coverage that can be
   * viewed as a relation is accepted as-is; anything else must resolve to a relation with referenced geometry.
   *
   * @throws ProblemException if the bookmark's data can't be used as a relation
   */
  private boolean validateRelation(ResolvedBookmark bookmark) throws ProblemException {
    if (bookmark.isType(TypedCoverage.class)
        && bookmark.getData(TypedCoverage.class).map(tc -> tc.asRelation().isPresent()).orElse(false)) {
      return true;
    }

    return validateRelation(bookmark.getData(Relation.class));
  }

  /**
   * Sanity-checks the given relation has enough geometry info to be used in a model.
   *
   * @throws ProblemException if the relation failed to resolve, has no geometry member, or its geometry has no CRS
   */
  private boolean validateRelation(ResultOrProblems<Relation> relationOr) throws ProblemException {
    // sanity-check the relation contains geometry (we won't get far without it)
    Relation relation = relationOr
        // discard any warnings here as they would be duplicating those emitted from the bookmark binding
        .drainWarnings(ProblemSink.DEVNULL)
        .getOrThrow();
    StructMember geomMember = TupleUtils.findGeometryMember(relation.getType(), FindOption.OPTIONAL);
    if (geomMember == null) {
      throw new ProblemException(Problem.error(StandardCodes.GEOMETRY_REQUIRED, relation.getType()));
    }
    // it will also need CRS info, otherwise spatial sampling won't work
    geomMember.getType().find(Referenced.class).orElseThrow(() ->
      new ProblemException(GeometryProblems.get().notReferenced(geomMember.getType())));

    return true;
  }
}
