/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.wizard.model2.smp;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.locationtech.jts.geom.Point;

import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.query.TupleUtils;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.wizard.Question.I18nLookupFunction;
import nz.org.riskscape.wizard.QuestionSet;
import nz.org.riskscape.wizard.Survey;
import nz.org.riskscape.wizard.bld.IncrementalBuildState;
import nz.org.riskscape.wizard.bld.PipelineChange;
import nz.org.riskscape.wizard.bld.change.AppendJoinChange;
import nz.org.riskscape.wizard.bld.dsl.IncompletePipelineChange;
import nz.org.riskscape.wizard.bld.loc.EndOfBranchWith;
import nz.org.riskscape.wizard.model2.input.InputDataPhase;
import nz.org.riskscape.wizard.survey2.BasePhase;
import nz.org.riskscape.wizard.survey2.Choices;
import nz.org.riskscape.wizard.survey2.DefaultQuestionSet2;

/**
 * Wizard phase that asks how each of the other input datasets should be sampled against the exposure-layer.
 * For every dataset to sample, the dataset's input branch is joined on to the main exposures branch and a
 * sampling step is appended. Once all datasets are dealt with, a step named {@link #WELL_KNOWN_STEP} is appended.
 */
public class SamplePhase extends BasePhase {

  private static final String QS_SAMPLING = "sample";

  private static final String SAMPLE_TYPE_QID = "-by";

  /**
   * The all-intersections sample result from the hazard-layer. This includes the
   * intersecting geometry, which we remove from the 'hazard' attribute.
   */
  public static final String SAMPLED_HAZARD = InputDataPhase.HAZARD_ATTRIBUTE + "_sampled";

  /**
   * The ratio of sampled hazard geometry to original exposure geometry that can be used to scale losses.
   */
  public static final String EXPOSED_RATIO = "exposed_ratio";

  /**
   * A step with this name is appended to the pipeline when this phase is finished
   */
  public static final String WELL_KNOWN_STEP = "sampled";

  /**
   * Looks up i18n messages using a key that is shared by all datasets, i.e. a question named
   * 'DATASET-buffer' is looked up as 'question.sample.buffer.SUFFIX'. The dataset name (everything before the
   * first '-' in the question's name) is passed through as the sole message argument.
   */
  private static final I18nLookupFunction I18N_LOOKUP = (q, suffix, locale) -> {
    String questionName = q.getName();
    // split on the first '-' using plain string operations. The dataset name is data, not a pattern, so
    // we must not build a regex out of it (regex metacharacters in the name would break the lookup)
    int separator = questionName.indexOf('-');
    String datasetName = separator < 0 ? questionName : questionName.substring(0, separator);
    // drop the dataset name from the i18n key lookup
    String genericName = separator < 0 ? questionName : questionName.substring(separator + 1);
    String i18nKey = "question.sample." + genericName + "." + suffix;
    String i18nMsg = q.getMessageSource().getMessage(i18nKey, new Object[] {datasetName}, locale);
    return Optional.ofNullable(i18nMsg);
  };

  private static final List<SampleType> ONLY_CENTROID = Arrays.asList(SampleType.CENTROID);

  /**
   * @param survey the survey that this phase belongs to
   */
  public SamplePhase(Survey survey) {
    super(survey);
  }

  /**
   * Builds the sampling questions. For each dataset that needs sampling (hazard datasets first) this adds:
   * <ul>
   * <li>a sample-type question (hazard datasets only, skipped when there is only one option)</li>
   * <li>an optional buffer distance question (not asked for centroid sampling)</li>
   * <li>a hidden question that adds the sample step if no buffer distance was given</li>
   * </ul>
   * A final hidden question appends the {@link #WELL_KNOWN_STEP} step.
   */
  @Override
  protected QuestionSet buildQuestionSet(IncrementalBuildState buildState) {
    DefaultQuestionSet2 questions = new DefaultQuestionSet2(QS_SAMPLING, this);

    questions.setDefaultLocation(InputDataPhase.MAIN_BRANCH);

    List<RealizedStep> startSteps = applicableStepStream(buildState).collect(Collectors.toList());

    sortSamplingInputs(buildState, startSteps);

    Class<?> exposureGeometryType = findExposureGeometryType(buildState);

    for (RealizedStep step : startSteps) {
      // get the original dataset name
      String datasetName = InputDataPhase.getDatasetName(step);
      String coverage = InputDataPhase.getAttributeName(datasetName);

      // sample-type question is mandatory for the hazard-layer, defaults to closest for everything else
      if (InputDataPhase.isHazardDataset(datasetName)) {
        List<SampleType> sampleTypeOptions = getSampleTypeOptions(exposureGeometryType, datasetName);

        questions.addQuestion(datasetName + SAMPLE_TYPE_QID, SampleType.class)
        // skip it when there aren't any options, i.e. there is nothing for the user to choose between.
        // We check the number of options rather than comparing list identity, so this keeps working
        // regardless of which list instance getSampleTypeOptions hands back
        .customizeQuestion(q -> q.askWhen(ibs -> sampleTypeOptions.size() > 1))
        .customizeQuestion(q -> q.withI18nLookup(I18N_LOOKUP))
        // customize which enum options are allowed
        .customizeQuestion(q -> q.withChoices(ibs -> Choices.forEnums(sampleTypeOptions)))
        .thenNoChange();
      }

      // optionally add a buffer distance when sampling (this is pointless with centroid)
      questions.addQuestion(datasetName + "-buffer", Expression.class)
        .customizeQuestion(q -> q.withI18nLookup(I18N_LOOKUP).optionalOne())
        // we might have skipped the last question, so we need to check for the SampleType using getSampleType
        // instead of using the simpler QuestionSet.askWhenDependencyIs
        .customizeQuestion(q -> q.askWhen(ibs -> getSampleType(ibs, datasetName) != SampleType.CENTROID))
        .then((input, distance) -> {
            SampleType type = getSampleType(input.getBuildState(), datasetName);
            return joinBranchWithSample(datasetName, getSampleStep(coverage, type, distance));
        });

      // if no buffer distance was specified, add in the sample step
      questions.addHiddenQuestion(datasetName + "-finish")
        .ifNotAnswered(datasetName + "-buffer")
        .then(input -> {
            SampleType type = getSampleType(input.getBuildState(), datasetName);
            return joinBranchWithSample(datasetName, getSampleStep(coverage, type));
        });
    }

    questions.addHiddenQuestion("append-name").then(input ->
      PipelineChange.chainStep("select({*}) as %s", WELL_KNOWN_STEP)
    );

    return questions;
  }

  /**
   * @return the internal (java) type of the geometry member in the exposure struct on the main branch, or null
   * if there is no exposure struct or it has no geometry member.
   */
  private Class<?> findExposureGeometryType(IncrementalBuildState buildState) {
    // fall back to an empty struct if the exposure attribute isn't there (or isn't a struct)
    Struct exposureType = buildState.getInputStruct(InputDataPhase.MAIN_BRANCH)
      .getMember(InputDataPhase.EXPOSURE_ATTRIBUTE)
      .flatMap(sm -> sm.getType().findAllowNull(Struct.class)).orElse(Struct.EMPTY_STRUCT);

    return Optional
        .ofNullable(TupleUtils.findGeometryMember(exposureType, TupleUtils.FindOption.OPTIONAL))
        .map(member -> member.getType().internalType())
        .orElse(null);
  }

  /**
   * @param geometryType the java type of the exposure-layer's geometry, or null if it has none
   * @param datasetName the name of the dataset being sampled
   * @return the sample types that the user is allowed to pick from. This is {@link #ONLY_CENTROID} when
   * there is no exposure geometry.
   */
  private List<SampleType> getSampleTypeOptions(Class<?> geometryType, String datasetName) {
    if (geometryType != null) {
      // for point exposures and non-hazard layers, we offer choices that can emit at most a single value
      //  - For points, they don't have area, so screw up the exposed ratio if we do all intersections
      //  - for non hazard layers, it doesn't make sense to have many values - each exposure needs to correspond to
      //  *a* resource or *an* area
      // NB we should probably just always go with closest for area layer, but it makes the UX a bit weird when we
      // suddenly throw up a Question to the user about a buffer distance that's related to the closest operation.  By
      // letting them choose that, it is more consistent with the hazard layer experience and clearer why they are
      // being asked the question.
      if (geometryType.equals(Point.class) || !InputDataPhase.isHazardDataset(datasetName)) {
        return Arrays.asList(SampleType.CENTROID, SampleType.CLOSEST);
      } else {
        return Arrays.asList(SampleType.CENTROID, SampleType.CLOSEST, SampleType.ALL_INTERSECTIONS);
      }
    } else {
      // don't show any option if there's no geometry - This can fail later with a
      // realization error
      return ONLY_CENTROID;
    }
  }

  /**
   * @return the answer to the SampleType question for the given dataset, or {@link SampleType#CLOSEST} if
   * that question has not been answered (e.g. because it was never asked)
   */
  private SampleType getSampleType(IncrementalBuildState buildState, String datasetName) {
    return buildState.getResponse(QS_SAMPLING, datasetName + SAMPLE_TYPE_QID, SampleType.class)
        // default to 'closest' sampling if the user didn't pick a sampling type
        .orElse(SampleType.CLOSEST);
  }

  /**
   * Joins another input branch for a given dataset to the main exposures branch,
   * and adds in the specified sampling step.
   * @param datasetName the name of the branch to join, e.g. 'hazards', 'resources', etc
   * @param sampleStep the pipeline step to append to do the actual sampling
   * @return a change that appends a join step named 'exposures_join_DATASET' followed by the sample step,
   * and connects the end of the dataset's branch to the rhs of that join
   */
  private IncompletePipelineChange joinBranchWithSample(String datasetName, String sampleStep) {
    String joinStepName = String.format("exposures_join_%s", datasetName);
    return answer -> AppendJoinChange.builder(answer)
        .append("join(on: true).lhs as %s -> %s", joinStepName, sampleStep)
        .joins(EndOfBranchWith.stepNamed(datasetName), joinStepName + ".rhs")
        .build();
  }

  /**
   * @param coverageAttribute the attribute that holds the coverage to sample, which gets replaced with the
   * sampled value
   * @param sampleType how the coverage should be sampled
   * @return A pipeline step to do the sampling for the given coverage and sample type.
   */
  private String getSampleStep(String coverageAttribute, SampleType sampleType) {
    String sampleFunctionCall = sampleType.getFunctionCall(InputDataPhase.EXPOSURE_ATTRIBUTE, coverageAttribute);

    if (sampleType == SampleType.ALL_INTERSECTIONS) {
      boolean isHazardLayer = InputDataPhase.HAZARD_ATTRIBUTE.equals(coverageAttribute);
      // save the sampled hazard geom in a separate variable so we can show our workings
      String sampledGeom = isHazardLayer ? SAMPLED_HAZARD : coverageAttribute;

      // include an exposed ratio when we are sampling the hazard (for scaling the losses later)
      String exposedRatioStep = "";
      if (isHazardLayer) {
        // measure against the same exposure attribute that was handed to the sample function, rather
        // than hard-coding the attribute's name
        exposedRatioStep = String.format("-> select({*, map(%s, geom -> measure(geom) / measure(%s)) as %s})",
            SAMPLED_HAZARD, InputDataPhase.EXPOSURE_ATTRIBUTE, EXPOSED_RATIO);
      }

      return String.format(""
          // sample, then split apart the actual hazard intensity value from the sampled geom.
          // 'hazards' should always match the hazard layer's type, even if it's a list.
          + "select({*, %s as %s}) as \"sample_%s_layer\" -> "
          + "select({*, map(%s, h -> h.sampled) as %s}) %s",
          sampleFunctionCall, sampledGeom, coverageAttribute, sampledGeom, coverageAttribute, exposedRatioStep);
    } else {
      // replace the coverage attribute with the sampled value
      return String.format("select({*, %s as %s}) as \"sample_%s_layer\"",
          sampleFunctionCall, coverageAttribute, coverageAttribute);
    }
  }

  /**
   * Same as {@link #getSampleStep(String, SampleType)} but adds in a buffer
   * distance as a margin of error when sampling.
   * @param bufferDistance expression giving the buffer distance, which is inserted in to the step as source
   */
  private String getSampleStep(String coverageAttribute, SampleType sampleType, Expression bufferDistance) {
    // NOTE(review): unlike the unbuffered variant, the steps built here are not named 'sample_X_layer' -
    // confirm whether that difference is intentional
    if (sampleType == SampleType.ALL_INTERSECTIONS) {
      // buffer the geom first, before we pass it to the sample function. This caters for
      // cases where we want to extend 1m out from a building to look for the max inundation depth
      String bufferedGeom = String.format("buffer(%s, %s)",
          InputDataPhase.EXPOSURE_ATTRIBUTE,
          bufferDistance.toSource());

      boolean isHazardLayer = InputDataPhase.HAZARD_ATTRIBUTE.equals(coverageAttribute);
      // save the sampled hazard geom in a separate variable so we can show our workings
      String sampledGeom = isHazardLayer ? SAMPLED_HAZARD : coverageAttribute;

      return String.format(""
          // sample, then split apart the actual hazard intensity value from the sampled geom.
          // Note we don't include exposed_ratio for scaling here because we've buffered.
          + "select({*, %s as %s}) -> "
          + "select({*, map(%s, h -> h.sampled) as %s})",
          sampleType.getFunctionCall(bufferedGeom, coverageAttribute),
          sampledGeom,
          sampledGeom,
          coverageAttribute);
    } else {
      // other sample types get given the buffer distance as an extra function argument
      return String.format("select({*, %s as %s})",
          sampleType.getFunctionCall(InputDataPhase.EXPOSURE_ATTRIBUTE, coverageAttribute, bufferDistance.toSource()),
          coverageAttribute);
    }
  }

  /**
   * Sorts the given steps in place. Steps for hazard datasets come first, then steps are ordered by where they
   * were declared in the ast (which will match the order in which the question sets were answered).
   */
  private void sortSamplingInputs(IncrementalBuildState buildState, List<RealizedStep> startSteps) {
    PipelineDeclaration decl = buildState.getAst();
    startSteps.sort((lhs, rhs) -> {
      // sort hazard layers first, as they're generally more important
      boolean lhsIsHazard = InputDataPhase.isHazardDataset(InputDataPhase.getDatasetName(lhs));
      boolean rhsIsHazard = InputDataPhase.isHazardDataset(InputDataPhase.getDatasetName(rhs));

      if (lhsIsHazard != rhsIsHazard) {
        // arguments are flipped so that true (is a hazard) sorts before false
        return Boolean.compare(rhsIsHazard, lhsIsHazard);
      }

      // NOTE(review): assumes every realized step has a definition in the ast - the get() is unchecked
      return Integer.compare(
        decl.findDefinition(lhs.getStepName()).get().getChainIndex(),
        decl.findDefinition(rhs.getStepName()).get().getChainIndex()
      );
    });
  }

  /**
   * @return the end steps of the realized pipeline whose names match one of the datasets to sample
   */
  private Stream<RealizedStep> applicableStepStream(IncrementalBuildState buildState) {
    List<String> toSample = InputDataPhase.getDatasetsToSample(buildState);

    // the end of each input chain should match the layer's name
    return buildState.getRealizedPipeline().getEndSteps().stream()
        .filter(step -> toSample.contains(step.getStepName()));
  }

}
