/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.steps;

import java.net.URI;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;

import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;

import lombok.RequiredArgsConstructor;
import nz.org.riskscape.dsl.SourceLocation;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.engine.Engine;
import nz.org.riskscape.engine.bind.ParamProblems;
import nz.org.riskscape.engine.bind.ParameterField;
import nz.org.riskscape.engine.pipeline.RealizationInput;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.projection.Projector;
import nz.org.riskscape.engine.relation.EmptyRelation;
import nz.org.riskscape.engine.resource.Resource;
import nz.org.riskscape.engine.resource.ResourceLoadingException;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.pipeline.PipelineParser;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration.Found;
import nz.org.riskscape.pipeline.ast.StepLink;
import nz.org.riskscape.pipeline.ast.StepReference;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ParameterToken;
import nz.org.riskscape.rl.ast.StructDeclaration;

/**
 * Pipeline step that loads another pipeline (the "child", or sub-pipeline) from the {@code location} parameter and
 * realizes it in-line as part of the parent pipeline.
 *
 * <p>
 * When this step is given an input, the child must refer to it via the magic {@value #STEP_REF_IN} step reference.
 * Parameters declared by the child are substituted from this step's {@code parameters} argument. The child's steps are
 * copied in to the parent with their names prefixed by this step's name and a dot (to avoid name collisions), and the
 * child's output - its only end step that produces a non-empty struct, or otherwise the end step named
 * {@value #STEP_REF_OUT} - is exposed to the parent via a no-op step that carries this step's name.
 * </p>
 */
public class SubpipelineStep extends BaseStep<SubpipelineStep.Parameters> {

  interface LocalProblems extends ProblemFactory {
    /**
     * The child has several end steps producing output and none of them is named {@code out}.
     */
    Problem ambiguousOutput(List<Token> candidates, Problem... hintGoesHere);
    Problem ambiguousOutputHint();

    /**
     * The sub-pipeline contains no steps.
     */
    Problem empty();

    /**
     * The child declares parameters that were not given a replacement.
     */
    Problem missingParameters(URI pipelineLocation, Set<String> parameterNames);
    Problem missingParameterHint(String parameterName, SourceLocation occurrence);

    /**
     * This step was given an input, but the child has no {@code in} step reference to receive it.
     */
    Problem inStepRefMissing(String inputStepName);

    /**
     * The child refers to {@code in}, but this step was not given an input.
     */
    Problem inStepRefButNoInput();

    /**
     * A step in the child chains in to {@code in} (it should only ever be chained from).
     */
    Problem chainingToIn(Token at);

    /**
     * The child pipeline realized with failures.
     */
    Problem childFailed(URI location, Problem... failures);

    /**
     * A parameter replacement was given without a name.
     */
    Problem parametersNotNamed(Token at);

    Problem parametersNotNamedHint();

    /**
     * The sub-pipeline at this location is already being realized on this thread, i.e. it includes itself.
     */
    Problem recursion(URI location);
  }

  public static final LocalProblems PROBLEMS = Problems.get(LocalProblems.class);

  public static final String STEP_REF_IN = "in";
  public static final String STEP_REF_OUT = "out";

  /**
   * The locations of the sub-pipelines currently being realized on this thread, outermost first. Finding a location
   * here when we go to realize it again means the sub-pipeline (directly or indirectly) includes itself.
   */
  private final ThreadLocal<Queue<URI>> recursionStack = ThreadLocal.withInitial(ArrayDeque::new);

  public SubpipelineStep(Engine engine) {
    // zero or one input - the pipeline realizer enforces this arity
    super(engine, Range.closed(0, 1), List.of());
  }

  public static class Parameters {
    /**
     * Where to load the sub-pipeline from.
     */
    @ParameterField
    public URI location;

    /**
     * Replacement expressions for the sub-pipeline's parameters - every member must be named.
     */
    @ParameterField
    public Optional<StructDeclaration> parameters;

    // not a @ParameterField, so not bound from the user's step arguments
    public RealizationInput input;
  }

  // this step overrides the top-level realize method on BaseStep because it wants to return a whole pipeline, not just
  // a realized.  If we end up with more than one of these 'creates a whole pipeline' steps, we might want to split
  // BaseStep up in to BaseStep (which just has the support methods) and SimpleStep (which finalizes this method)
  @Override
  public ResultOrProblems<? extends Realized> realize(Parameters parameters) {
    throw new UnsupportedOperationException();
  }

  @RequiredArgsConstructor
  private static class ChildPipeline {

    // the result of realizing the child pipeline
    final RealizedPipeline realized;

    // The step from the parent pipeline that we are feeding in to this child
    final RealizedStep originalInput;

    // The renamed version of the originalInput step - as it appears in realized
    final RealizedStep renamedInput;

    // the step in the child pipeline that produces output that can be fed back in to the parent pipeline
    final RealizedStep outputStep;

    // true if the child was fed input from the parent
    boolean hasInput() {
      return originalInput != null;
    }
  }

  /**
   * Validate that the given parameters are valid enough to be passed to the sub-pipeline.
   *
   * Any problems found here should be due to the parameters themselves, not related to how the sub-pipeline
   * is using them.
   *
   * @return the given parameters, or a failure if any parameter replacement is unnamed
   */
  private ResultOrProblems<Parameters> validateParameters(Parameters params) {
    if (params.parameters.isPresent()) {
      for (StructDeclaration.Member member : params.parameters.get().getMembers()) {
        if (member.getName().isEmpty()) {
          // replacements must be named so we know which parameter they replace
          Token location = member.getExpression().getBoundary()
              .map(Pair::getLeft)
              .orElse(Token.UNKNOWN_LOCATION);
          return ResultOrProblems.failed(
              PROBLEMS.parametersNotNamed(location)
                  .withChildren(PROBLEMS.parametersNotNamedHint())
          );
        }
      }
    }
    return ResultOrProblems.of(params);
  }

  /**
   * Substitutes the given replacement expressions for the parameters declared in the child pipeline.
   *
   * @param ast the parsed child pipeline
   * @param replacements named replacement expressions (names are pre-validated by {@link #validateParameters})
   * @return the child pipeline with its parameters replaced, plus a warning for each replacement the child doesn't
   *         use - or a failure listing each parameter the child needs that was not given
   */
  private ResultOrProblems<PipelineDeclaration> replaceParameters(
      PipelineDeclaration ast, StructDeclaration replacements
  ) {
    // turn replacements in to a map
    Map<String, Expression> replacementsMap = new HashMap<>();
    for (StructDeclaration.Member member : replacements.getMembers()) {
      replacementsMap.put(member.getName().get(), member.getExpression());
    }

    // If a parameter name turns up more than once, keep the token that appears earliest in the source, so that
    // missing-parameter hints point at the first occurrence (and toMap doesn't throw on the duplicate key).
    Map<String, ParameterToken> requiredNamesToTokens = ast.findParameters().keySet().stream()
        .collect(Collectors.toMap(
            ParameterToken::getValue,
            token -> token,
            (lhs, rhs) -> lhs.getToken().getLocation().compareTo(rhs.getToken().getLocation()) <= 0 ? lhs : rhs
        ));

    // create custom error message for missing parameters
    SetView<String> missing = Sets.difference(requiredNamesToTokens.keySet(), replacementsMap.keySet());
    if (!missing.isEmpty()) {
      // build a problem for each missing parameter that lists where it's first encountered in the child pipeline
      List<Problem> hints = missing.stream()
          .map(name -> PROBLEMS.missingParameterHint(name, requiredNamesToTokens.get(name).getToken().getLocation()))
          .toList();

      return ResultOrProblems.failed(
          PROBLEMS.missingParameters(ast.getMetadata().getLocation(), missing).withChildren(hints)
      );
    }

    // find any extra params, create some warnings for them
    SetView<String> surplus = Sets.difference(replacementsMap.keySet(), requiredNamesToTokens.keySet());
    List<Problem> warnings = surplus.stream().map(parameterName -> ParamProblems.get().ignored(parameterName)).toList();

    // NB we are not mapping the result here as we are not expecting replace parameters to fail - we've pre-validated
    // the parameters so we can have custom error messages.
    return ResultOrProblems.of(ast.replaceParameters(replacementsMap).get(), warnings);
  }

  /**
   * Loads the resource at the {@code location} parameter, turning a loading error in to a problem against that
   * parameter.
   */
  private ResultOrProblems<Resource> loadSubpipeline(Parameters params) {
    try {
      return ResultOrProblems.of(getEngine().getResourceFactory().load(params.location));
    } catch (ResourceLoadingException ex) {
      return ResultOrProblems.failed(Problems.foundWith(getParameterSet().get("location"), Problems.caught(ex)));
    }
  }

  /**
   * Parses the child pipeline from the given resource, substituting the given parameters.
   *
   * @return the parsed pipeline, or a failure if it fails to parse, has no steps, or is missing parameters
   */
  private ResultOrProblems<PipelineDeclaration> parseChildPipeline(Resource pipelineResource, Parameters params) {
    return PipelineParser.parseParameterizedPipeline(pipelineResource).flatMap(ast -> {
      if (ast.getChains().isEmpty()) {
        return ResultOrProblems.failed(PROBLEMS.empty());
      }
      return replaceParameters(ast, params.parameters.orElse(new StructDeclaration(List.of(), Optional.empty())));
    });
  }

  /**
   * Realizes the child pipeline, feeding this step's input (if it has one) in to the child's magic {@code in} step
   * reference.
   *
   * @return the realized child along with its input and output steps - or a failure if the child includes itself,
   *         misuses {@code in}, or has an ambiguous output step. A child that realized with failures is still
   *         returned here; its failures are recorded on the realized pipeline.
   */
  private ResultOrProblems<ChildPipeline> realizeChildPipeline(RealizationInput input, PipelineDeclaration pipeline) {
    final URI location = pipeline.getMetadata().getLocation();
    final Queue<URI> inProgress = recursionStack.get();
    if (inProgress.contains(location)) {
      return ResultOrProblems.failed(PROBLEMS.recursion(location));
    }

    // record the sub-pipeline's location so that any nested sub-pipeline steps can detect recursion
    inProgress.add(location);
    try {
      final int numDependencies = input.getDependencies().size();

      if (numDependencies > 1) {
        // subpipeline only supports one input, the pipeline realizer enforces this arity check
        throw new AssertionError("Realization should not allow this");
      }

      // is there an `in` step reference anywhere in the child pipeline?
      List<Found> inStepRefs = pipeline.findAll((ignored, sd) -> sd
              .isA(StepReference.class)
              .map(sr -> sr.getIdent().equals(STEP_REF_IN))
              .orElse(false));
      boolean hasStepRefIn = !inStepRefs.isEmpty();

      if (hasStepRefIn) {
        // `in` should only be used as the source (lhs) of links - a step chaining to `in` would be weird
        StepLink toIn = findLinkChainingToIn(inStepRefs);
        if (toIn != null) {
          return ResultOrProblems.failed(PROBLEMS.chainingToIn(
              toIn.getBoundary()
                  .map(Pair::getRight)
                  .orElse(Token.UNKNOWN_LOCATION)
          ));
        }
      }

      RealizedPipeline startFromPipeline = RealizedPipeline.empty(input.getExecutionContext(), pipeline);
      RealizedStep originalInputStep = null;
      RealizedStep stubInputStep = null;

      if (numDependencies == 1) {
        // the child is being fed input, so we need to 'inject' it in to the sub-pipeline
        originalInputStep = input.getDependencies().get(0);

        // make sure the child pipeline has an `in` step to receive it
        if (!hasStepRefIn) {
          return ResultOrProblems.failed(PROBLEMS.inStepRefMissing(originalInputStep.getStepName()));
        }

        // Create a stub step to use to realize the child pipeline (it won't be executed) and seed realization with it
        stubInputStep = RealizedStep.named(STEP_REF_IN)
            .withResult(new EmptyRelation(originalInputStep.getProduces()));
        startFromPipeline = startFromPipeline.add(stubInputStep);
      } else if (hasStepRefIn) {
        return ResultOrProblems.failed(PROBLEMS.inStepRefButNoInput());
      }

      RealizedPipeline newPipeline =
          input.getExecutionContext().getPipelineRealizer().realize(startFromPipeline, pipeline);

      // only end steps that produce a non-empty struct are candidates for the child's output
      List<RealizedStep> endSteps = newPipeline.getEndSteps().stream()
          .filter(rs -> rs.getProduces().size() > 0)
          .toList();

      RealizedStep outStep;
      if (endSteps.isEmpty()) {
        // nothing to output - use a placeholder step that produces an empty struct
        outStep = RealizedStep.named(STEP_REF_OUT).withResult(new EmptyRelation(Struct.EMPTY_STRUCT));
      } else if (endSteps.size() == 1) {
        outStep = endSteps.get(0);
      } else {
        // several candidates - the end step named `out` wins
        outStep = endSteps.stream().filter(rs -> rs.getName().equals(STEP_REF_OUT)).findFirst().orElse(null);

        // there wasn't a single step named out, fail
        if (outStep == null) {
          // sort by source code position as end steps is otherwise unsorted which if nothing else makes test assertions
          // hard
          List<Token> candidates = endSteps.stream()
              .map(step -> step.getAst().getNameToken().orElse(step.getAst().getIdentToken()))
              .sorted((lhs, rhs) -> lhs.getLocation().compareTo(rhs.getLocation()))
              .toList();

          return ResultOrProblems.failed(Problems.foundWith("pipeline",
              PROBLEMS.ambiguousOutput(candidates, PROBLEMS.ambiguousOutputHint())
          ));
        }
      }

      return ResultOrProblems.of(new ChildPipeline(
          newPipeline,
          originalInputStep,
          stubInputStep,
          outStep
      ));
    } finally {
      // Remove exactly the entry we added. NB the no-arg Queue#remove() would take the *head* of the queue, which is
      // the outermost sub-pipeline rather than this one. Removing by value is safe because the contains check above
      // guarantees there is only one entry for this location.
      inProgress.remove(location);
    }
  }

  /**
   * Searches the chains that contain an {@code in} step reference for a link whose target (rhs) is {@code in}.
   *
   * @return the first such link found, or null if there are none
   */
  private static StepLink findLinkChainingToIn(List<Found> inStepRefs) {
    for (Found found : inStepRefs) {
      for (int i = 0; i < found.getChain().getLinkCount(); i++) {
        StepLink link = found.getChain().getLink(i);
        if (STEP_REF_IN.equals(link.getRhs().getIdent())) {
          return link;
        }
      }
    }
    return null;
  }

  /**
   * Maps each of the given child step's dependencies to the step that replaced it in the pipeline being built.
   *
   * @return a mutable list of the rebuilt dependencies, or empty if any dependency hasn't been rebuilt yet
   */
  private static Optional<List<RealizedStep>> rebuildDependencies(
      RealizedStep childStep,
      Map<RealizedStep, RealizedStep> rebuilt
  ) {
    List<RealizedStep> rebuiltDependencies = new ArrayList<>(childStep.getDependencies().size());
    for (RealizedStep originalDependency : childStep.getDependencies()) {
      RealizedStep rebuiltDependency = rebuilt.get(originalDependency);
      if (rebuiltDependency == null) {
        return Optional.empty();
      }
      rebuiltDependencies.add(rebuiltDependency);
    }
    return Optional.of(rebuiltDependencies);
  }

  /**
   * Returns an updated version of the parent pipeline with all the child steps inserted in to it. Each child step is
   * renamed with this step's name and a dot as a prefix, and its dependencies are pointed at the renamed steps (with
   * the parent's original input step standing in for the child's stub {@code in} step).
   *
   * @param input the realization input for this step, giving access to the parent pipeline
   * @param child the realized child pipeline
   * @param warnings non-fatal problems, attached to the no-op step that exposes the child's output
   * @return the parent pipeline with the child's steps added
   */
  private RealizedPipeline insertChildIntoParent(RealizationInput input, ChildPipeline child, List<Problem> warnings) {

    RealizedPipeline addTo = input.getRealizedPipeline();

    // if the child failed to realize, we start by adding a step in to the pipeline with this step's name that contains
    // all of the child pipeline's errors
    final RealizedStep failed;
    if (child.realized.hasFailures()) {
      failed = input.newPrototypeStep().withProblems(PROBLEMS.childFailed(
          child.realized.getAst().getMetadata().getLocation(),
          child.realized.getFailures().toArray(Problem[]::new))
      );
      addTo = addTo.add(failed);
    } else {
      failed = null;
    }

    // Steps are added in traversal order so each can be rebuilt with a new name and new dependencies. Steps whose
    // dependencies haven't been rebuilt yet are moved to the back of the queue.
    LinkedList<RealizedStep> toVisit = new LinkedList<>(child.realized.getRealizedSteps());

    // the steps we spit out get renamed to avoid collisions with the parent pipeline
    // TODO we might want to spit out warnings if users put dots in their step names?
    final String childStepPrefix = input.newPrototypeStep().getName() + ".";

    // Maps from steps that were in the child pipeline to the ones that we end up putting in the pipeline we return.
    // This lets us rebuild the dependencies for each step as it gets added to the pipeline we return
    Map<RealizedStep, RealizedStep> rebuilt = new HashMap<>();

    if (child.hasInput()) {
      // this seems a little backwards at first, but it makes sense:  The renamed input is the one we put in the child
      // pipeline, but ultimately we want the steps we add to the resulting pipeline to refer to the original input step
      rebuilt.put(child.renamedInput, child.originalInput);

      // we don't need to visit this one, we did it 'manually' above
      toVisit.remove(child.renamedInput);
    }

    // how many steps in a row have been deferred - if every queued step gets deferred, there must be a cycle
    int consecutiveDeferrals = 0;

    while (!toVisit.isEmpty()) {
      RealizedStep childStep = toVisit.removeFirst();

      Optional<List<RealizedStep>> maybeDependencies = rebuildDependencies(childStep, rebuilt);
      if (maybeDependencies.isEmpty()) {
        // TODO see if this is necessary - I suspect it won't ever happen because the steps should already be
        // in traversal order (but then we might need to assert that somewhere else in a test)
        toVisit.add(childStep);

        if (consecutiveDeferrals == toVisit.size()) {
          throw new AssertionError("cycle detected " + toVisit);
        }
        consecutiveDeferrals++;
        continue;
      }
      consecutiveDeferrals = 0;

      List<RealizedStep> rebuiltDependencies = maybeDependencies.get();
      if (failed != null) {
        // if this step depends on our magic failed step, it'll prevent the error coming out twice, but will preserve
        // the entire pipeline for --print (maybe this is a bad reason, let's see how this pans out)
        rebuiltDependencies.add(failed);
      }

      // rename it, rewire it, and add it to the pipeline
      RealizedStep rebuiltStep = childStep
          .withName(childStepPrefix + childStep.getName())
          .withDependencies(rebuiltDependencies);
      addTo = addTo.add(rebuiltStep);
      rebuilt.put(childStep, rebuiltStep);
    }

    // there's no need to add the noop step (below) if failed, we've already added it first (with all the errors)
    // NOTE(review): `warnings` are not attached to anything on this path, nor when the child's output is the empty
    // struct - confirm that dropping them is intended
    if (failed != null) {
      return addTo;
    }

    if (!child.outputStep.getProduces().equals(Struct.EMPTY_STRUCT)) {
      // rather than renaming the output step, we add in a no-op step with this step's name (so it's less confusing)
      Projector noop = Projector.identity(child.outputStep.getProduces());

      addTo = addTo.add(input.newPrototypeStep()
          .withResult(noop)
          .withProblems(warnings)
          .withDependencies(rebuilt.get(child.outputStep)));
    }
    return addTo;
  }

  /**
   * Loads, parses and realizes the sub-pipeline, returning the parent pipeline with the child's steps inserted in to
   * it. Any problems along the way are attached to a single step with this step's name.
   */
  @Override
  public RealizedPipeline realize(RealizationInput input) {
    return buildParametersObject(input)
        .flatMap(params -> validateParameters(params))
        .flatMap(params -> loadSubpipeline(params)
            .flatMap(pipelineResource -> {
              return parseChildPipeline(pipelineResource, params)
                  .flatMap((pipeline) -> realizeChildPipeline(input, pipeline))
                  .composeProblems((severity, problems) -> {
                    // compose the problem with the subpipeline's location so the user knows which
                    // subpipeline file they need to look into for problem diagnosis.
                    return Problems.foundWith(params.location, problems);
                  });
            })
        )
        .flatMap((child, warnings) -> {
          return ResultOrProblems.of(insertChildIntoParent(input, child, warnings));
        })
        .orElseGet(problems -> {
          return input.getRealizedPipeline().add(input.newPrototypeStep().withProblems(problems));
        });
  }
}
