/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.pipeline.ast;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;

import com.google.common.collect.Lists;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.dsl.SourceLocation;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.engine.OsUtils;
import nz.org.riskscape.engine.bind.Parameter;
import nz.org.riskscape.engine.pipeline.PipelineProblems;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.pipeline.PipelineMetadata;
import nz.org.riskscape.pipeline.PipelineParser;
import nz.org.riskscape.pipeline.StepNamingPolicy;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.MinimalVisitor;
import nz.org.riskscape.rl.ast.ParameterToken;
import nz.org.riskscape.util.ListUtils;


/**
 * AST for a pipeline declaration - the root of the AST.
 */
@RequiredArgsConstructor @EqualsAndHashCode(callSuper = false)
public final class PipelineDeclaration extends BaseExpr {

  /**
   * Simple value holder for the outcome of a step search such as {@link PipelineDeclaration#find(BiPredicate)}: the
   * declaration that was searched, the chain the match sits in, and the matching step itself.  This started life as
   * a {@link Pair}, but the generic signature was hard on the eyes and a named type reads better at call sites.
   */
  @Data
  public static class Found {

    private final PipelineDeclaration pipeline;
    private final StepChain chain;
    private final StepDeclaration step;

    /**
     * @return a {@link Found} that points at the last step of the given chain
     */
    public static Found last(PipelineDeclaration decl, StepChain chain) {
      StepDeclaration finalStep = chain.getLast();
      return new Found(decl, chain, finalStep);
    }

    /**
     * @return a {@link Found} that points at the last step of the last chain of the given declaration
     */
    public static Found last(PipelineDeclaration decl) {
      StepChain finalChain = decl.getLast();
      return new Found(decl, finalChain, finalChain.getLast());
    }

    /**
     * Append a {@link StepChain} to the found {@link StepChain}.
     * @return a new {@link PipelineDeclaration} in which the found chain has been swapped for the extended one
     */
    public PipelineDeclaration append(StepChain newChain) {
      StepChain extended = chain.append(newChain);
      return pipeline.replace(chain, extended);
    }

    /**
     * @return the list index of the found chain within the pipeline declaration, as reported by
     * {@link List#indexOf(Object)}
     */
    public int getChainIndex() {
      List<StepChain> siblings = pipeline.chains;
      return siblings.indexOf(chain);
    }

    /**
     * Swap the found {@link StepDeclaration} for the given one.
     * @return a new {@link PipelineDeclaration} in which the found chain has been swapped for one containing the
     * replacement step
     */
    public PipelineDeclaration replace(StepDeclaration replacement) {
      return pipeline.replace(chain, chain.replace(step, replacement));
    }
  }

  /**
   * An empty {@link PipelineDeclaration} - it has no chains and, because it is built with the chains-only
   * constructor, anonymous metadata.  Useful in various situations where a null might otherwise be used, or where a
   * different API method would otherwise be needed to handle a base case.
   */
  public static final PipelineDeclaration EMPTY = new PipelineDeclaration(Collections.emptyList());

  /**
   * The step chains that make up this declaration, in the order they were supplied.
   */
  @Getter
  private final List<StepChain> chains;

  /**
   * Metadata attached to this declaration.  This is {@link PipelineMetadata#ANONYMOUS} when the declaration was
   * built with the chains-only constructor.
   */
  @Getter
  private final PipelineMetadata metadata;

  /**
   * Construct a new {@link PipelineDeclaration} from the given chains, using {@link PipelineMetadata#ANONYMOUS} in
   * place of any real metadata.
   * @param chains the step chains that make up the declaration
   */
  public PipelineDeclaration(List<StepChain> chains) {
    this.metadata = PipelineMetadata.ANONYMOUS;
    this.chains = chains;
  }

  /**
   * @return the token span of this declaration, running from the left token of the first chain's boundary to the
   * right token of the last chain's boundary.  Empty if there are no chains, or if either of those chains has no
   * boundary of its own.
   */
  @Override
  public Optional<Pair<Token, Token>> getBoundary() {
    // nothing declared means there is no span to report
    if (chains.isEmpty()) {
      return Optional.empty();
    }

    Optional<Pair<Token, Token>> opening = chains.get(0).getBoundary();
    if (!opening.isPresent()) {
      return Optional.empty();
    }

    // the last chain is only consulted once we know the first one has a boundary
    Token start = opening.get().getLeft();
    return chains
        .get(chains.size() - 1)
        .getBoundary()
        .map(closing -> Pair.of(start, closing.getRight()));
  }


  /**
   * Look for the {@link StepDefinition} that carries the given name.
   * @param stepName the name to look for, compared against each definition's {@link StepDefinition#getName()}
   * @return the first definition (in ast order) with that name, along with the step chain that it was found in, or
   * empty if nothing was found
   */
  public Optional<Found> findDefinition(@NonNull String stepName) {
    // references, and definitions without a name, never match
    BiPredicate<StepChain, StepDeclaration> definesName = (chain, candidate) -> candidate
        .isA(StepDefinition.class)
        .flatMap(StepDefinition::getName)
        .filter(stepName::equals)
        .isPresent();

    return find(definesName);
  }

  /**
   * Convenience version of {@link #find(BiPredicate)} for predicates that only care about the step and not the chain
   * it lives in.
   * @param predicate test each step against this predicate
   * @return the first step (and chain) that matched, or empty if nothing found.
   */
  public Optional<Found> find(Predicate<StepDeclaration> predicate) {
    BiPredicate<StepChain, StepDeclaration> chainAgnostic = (chain, candidate) -> predicate.test(candidate);
    return find(chainAgnostic);
  }

  /**
   * General purpose step-finding method.  Walks every chain, and every step within each chain, in ast order and
   * stops at the first step the predicate accepts.
   * @param predicate test each step (together with the chain that holds it) against this predicate
   * @return the first step (and chain) that matched, or empty if nothing found.
   */
  public Optional<Found> find(BiPredicate<StepChain, StepDeclaration> predicate) {
    // the stream is lazy, so the predicate is not consulted again once a match has been produced
    return chains.stream()
        .flatMap(chain -> chain.getSteps().stream()
            .filter(candidate -> predicate.test(chain, candidate))
            .map(match -> new Found(this, chain, match)))
        .findFirst();
  }

  /**
   * General purpose step-finding method.  Walks every chain, and every step within each chain, in ast order and
   * gathers up each step the predicate accepts.
   * @param predicate test each step (together with the chain that holds it) against this predicate
   * @return a list of {@link Found} objects, in ast order, for the steps that matched the predicate - empty if
   * none did
   */
  public List<Found> findAll(BiPredicate<StepChain, StepDeclaration> predicate) {
    List<Found> matches = new LinkedList<>();

    chains.forEach(chain -> chain.getSteps().forEach(candidate -> {
      if (predicate.test(chain, candidate)) {
        matches.add(new Found(this, chain, candidate));
      }
    }));

    return matches;
  }

  /**
   * Return a new {@link PipelineDeclaration}, with a replacement step chain.  The first chain that is equal to
   * {@code found} is the one that gets swapped out; all other chains, and this declaration's metadata, are carried
   * across to the new declaration unchanged.
   * @param found the step chain to replace
   * @param newChain the replacement chain
   * @return a new {@link PipelineDeclaration} with the replacement made
   * @throws IllegalStateException if found is not part of this pipeline declaration
   */
  public PipelineDeclaration replace(@NonNull StepChain found, @NonNull StepChain newChain)
  throws IllegalStateException {
    // work on a copy so that this declaration is left untouched
    List<StepChain> clone = new ArrayList<>(this.chains);
    int index = clone.indexOf(found);

    if (index < 0) {
      throw new IllegalStateException("chain not a member of this declaration");
    }

    clone.set(index, newChain);

    // keep the metadata with the edited pipeline - the chains-only constructor would silently reset it to anonymous
    return new PipelineDeclaration(clone, this.metadata);
  }

  /**
   * Collect the {@link ParameterToken}s that appear in the step definitions of this declaration.
   * @return a map, in order of first encounter, from each parameter token to the step definitions it was found in.
   * A step definition is listed at most once per token, even if it uses the token several times.
   */
  public Map<ParameterToken, List<StepDefinition>> findParameters() {
    Map<ParameterToken, List<StepDefinition>> usages = new LinkedHashMap<>();

    MinimalVisitor<StepDefinition> collector = new MinimalVisitor<StepDefinition>() {
      @Override
      public StepDefinition visit(ParameterToken parameterToken, StepDefinition data) {
        List<StepDefinition> users = usages.computeIfAbsent(parameterToken, unused -> new LinkedList<>());
        // a step that mentions the same parameter twice should still only be listed once
        if (!users.contains(data)) {
          users.add(data);
        }
        return null;
      }
    };

    // each definition is handed to the visitor as its own data, so the visitor knows which step it is walking
    stepDefinitionIterator().forEachRemaining(defn -> defn.getStep().accept(collector, defn));

    return usages;
  }

  /**
   * Builds a new {@link PipelineDeclaration} by substituting each parameter token with the expression registered for
   * it in {@code replacements}.  The substitution is done on the pipeline's source text, so that whitespace is
   * preserved where possible, and the resulting text is then parsed again.  The metadata of this declaration is
   * carried over to the new one.
   *
   * @param replacements parameter names mapped to replacement expressions.
   * @return a new {@link PipelineDeclaration}, or one problem for each parameter token found in this pipeline
   * declaration that had no entry in {@code replacements}.
   */
  public ResultOrProblems<PipelineDeclaration> replaceParameters(
      Map<String, Expression> replacements
  ) {
    Optional<Pair<Token, Token>> boundary = getBoundary();

    // extract the original pipeline source text, we call it pipelineWithParamValues because we will be
    // updating it later
    String pipelineWithParamValues;
    if (! boundary.isPresent()) {
      // this isn't going to preserve whitespace, but it will at least work
      // NOTE(review): the substitution below indexes in to this text using the parameter tokens' own boundaries,
      // which presumably don't exist (or don't line up with toSource()) for a declaration that was built
      // programmatically - confirm this path is only hit when there are no parameter tokens to replace
      pipelineWithParamValues = toSource();
    } else {
      pipelineWithParamValues = boundary.get().getLeft().source;
    }

    // now we find all the parameter tokens that have been used.
    List<ParameterToken> parameterTokens = new ArrayList<>();
    MinimalVisitor<List<ParameterToken>> tokenCollector = new MinimalVisitor<List<ParameterToken>>() {
      @Override
      public List<ParameterToken> visit(ParameterToken parameterToken, List<ParameterToken> tokens) {
        tokens.add(parameterToken);
        return tokens;
      }
    };
    Iterator<StepDefinition> stepIterator = stepDefinitionIterator();
    while (stepIterator.hasNext()) {
      StepDefinition defn = stepIterator.next();
      defn.getStep().accept(tokenCollector, parameterTokens);
    }

    // note that these end up in reverse source order, as that's the order the tokens get visited in below
    List<ParameterToken> unresolved = new LinkedList<>();

    // We iterate over the tokens in reverse order so we are replacing from last -> first.
    // We do this because substituting the $param token for the value is going to mess up the source boundary
    // of later tokens (hence doing it in reverse).
    // This approach will only work when the parameterized pipeline has come from a single source. But this
    // is the case for this model as the whole pipeline will come from the location/pipeline.
    for (int i = parameterTokens.size() - 1; i >= 0; i--) {
      ParameterToken token = parameterTokens.get(i);

      Expression paramExpr = replacements.get(token.getValue());
      if (paramExpr == null) {
        unresolved.add(token);
        continue;
      }

      String beforeParam = pipelineWithParamValues.substring(0, token.getBoundary().get().getLeft().begin);
      String afterParam = pipelineWithParamValues.substring(token.getBoundary().get().getRight().end);

      // We get the paramSource by extracting it from the token source if possible. This is to keep any
      // white space that may exist. Otherwise we need to fall back to expression.toSource() - which is only
      // rendered when it's actually needed
      String paramSource = paramExpr.getBoundary()
          .map(paramBoundary -> {
            // TODO we should really be ensuring all parts of the parameter have a common source.
            // but we know they will have for parameter in this model
            return paramBoundary.getLeft().source.substring(
                paramBoundary.getLeft().begin, paramBoundary.getRight().end);
          }).orElseGet(paramExpr::toSource);
      pipelineWithParamValues = beforeParam + paramSource + afterParam;
    }

    if (! unresolved.isEmpty()) {
      return ResultOrProblems.failed(unresolved.stream()
          // As long as the original pipeline has been parsed (rather than built
          // programmatically) including the token here should report the location
          .map(token -> Problems.foundWith(token, GeneralProblems.required(token.getValue(), Parameter.class)))
          .toList());
    }

    return ResultOrProblems.of(
        PipelineParser.INSTANCE.parsePipelineAllowParameters(pipelineWithParamValues)
          .withMetadata(this.getMetadata())
    );
  }

  /**
   * Return a new {@link PipelineDeclaration} with the given step chain added to the end of this
   * {@link PipelineDeclaration}'s list of chains.  This declaration's metadata is carried across to the new
   * declaration.
   * @param toAdd the chain to add
   * @return a new {@link PipelineDeclaration} containing the given chain
   */
  public PipelineDeclaration add(@NonNull StepChain toAdd) {
    List<StepChain> extended = ListUtils.concat(this.chains, Collections.singletonList(toAdd));

    // keep the metadata with the extended pipeline - the chains-only constructor would silently reset it to anonymous
    return new PipelineDeclaration(extended, this.metadata);
  }

  /**
   * Return a new {@link PipelineDeclaration} whose chains are this pipeline's chains followed by the other's - no
   * merge or squishing or anything clever is going on.  The result is built from the chains alone, so it has
   * anonymous metadata.
   * @param toAdd the declaration whose chains are added after this one's
   */
  public PipelineDeclaration add(PipelineDeclaration toAdd) {
    List<StepChain> combined = ListUtils.concat(this.chains, toAdd.chains);
    return new PipelineDeclaration(combined);
  }

  /**
   * @return the chain at the head of this declaration's list of chains
   * @throws IndexOutOfBoundsException if this declaration is empty
   */
  public StepChain getFirst() {
    final int headIndex = 0;
    return chains.get(headIndex);
  }

  /**
   * @return the chain at the tail of this declaration's list of chains
   * @throws IndexOutOfBoundsException if this declaration is empty
   */
  public StepChain getLast() {
    int tailIndex = chains.size() - 1;
    return chains.get(tailIndex);
  }

  /**
   * @return true if this declaration contains no chains at all
   */
  public boolean isEmpty() {
    return chains.isEmpty();
  }

  /**
   * Appends each chain's string form to the builder, in order, with a comma and a space between neighbours.
   */
  @Override
  protected void appendString(StringBuilder appendTo) {
    Iterator<StepChain> remaining = chains.iterator();
    while (remaining.hasNext()) {
      remaining.next().appendString(appendTo);

      // only separate when there is another chain still to come
      if (remaining.hasNext()) {
        appendTo.append(", ");
      }
    }
  }

  /**
   * Appends each chain's source form to the builder, in order, with a line separator between neighbours.
   */
  @Override
  protected void appendSource(StringBuilder appendTo) {
    Iterator<StepChain> remaining = chains.iterator();
    while (remaining.hasNext()) {
      remaining.next().appendSource(appendTo);

      // only separate when there is another chain still to come
      if (remaining.hasNext()) {
        // TODO support an optional separator - maybe a comma?
        appendTo.append(OsUtils.LINE_SEPARATOR);
      }
    }
  }

  /**
   * Iterates over the {@link StepDefinition}s of the enclosing declaration, chain by chain and step by step, passing
   * over any step that is not a {@link StepDefinition}.
   */
  private class StepDefnIterator implements Iterator<StepDefinition> {

    // position of the next step to inspect
    private int chainIndex = 0;
    private int stepIndex = 0;

    // a definition that has been located but not yet handed out, or null if we need to go looking
    private StepDefinition lookahead = null;

    @Override
    public boolean hasNext() {
      advance();
      return lookahead != null;
    }

    @Override
    public StepDefinition next() {
      advance();
      if (lookahead == null) {
        throw new NoSuchElementException();
      }

      StepDefinition result = lookahead;
      lookahead = null;
      return result;
    }

    /**
     * Moves forward until a definition is sitting in the lookahead slot or the chains are exhausted.  Does nothing if
     * the lookahead slot is already occupied.
     */
    private void advance() {
      while (lookahead == null && chainIndex < chains.size()) {
        StepChain currentChain = chains.get(chainIndex);

        if (stepIndex >= currentChain.size()) {
          // ran off the end of this chain - start at the top of the next one
          chainIndex++;
          stepIndex = 0;
          continue;
        }

        StepDeclaration candidate = currentChain.getSteps().get(stepIndex++);
        if (candidate instanceof StepDefinition) {
          lookahead = (StepDefinition) candidate;
        }
      }
    }
  }

  /**
   * @return a fresh {@link Iterator} over all of the {@link StepDefinition}s in this pipeline declaration, chain by
   * chain, in the order they appear within each chain.  Skips all the {@link StepReference}s.
   */
  public Iterator<StepDefinition> stepDefinitionIterator() {
    Iterator<StepDefinition> definitions = new StepDefnIterator();
    return definitions;
  }

  /**
   * @return how many chains this pipeline declaration holds - note this counts chains, not the steps within them
   */
  public int size() {
    return this.chains.size();
  }

  /**
   * Used in cycle detection to track edges along with the depth they were seen at.  Instances are placed on both
   * the visit list and the stack that drive the traversal.
   */
  @RequiredArgsConstructor
  private static final class StackEl {
    // the link between two steps that this element stands for
    final StepLink edge;
    // how deep in to the traversal this edge was queued - the starting edge is 0, each edge followed on from
    // another is queued at one more than the edge it follows
    final int depth;
  }

  /**
   * Build the function that names the steps of this declaration under the given policy.
   * @param policy the step naming policy that will be applied
   * @return a function that gives the name for a step definition, returning either a unique generated name,
   * if none was assigned, or the explicitly assigned name
   */
  public Function<StepDeclaration, String> getStepNameFunction(StepNamingPolicy policy) {
    // the policy builds its function against this declaration; broaden then adapts it to the returned type
    var policyFunction = policy.getStepNameFunction(this);
    return StepNamingPolicy.broaden(policyFunction);
  }

  /**
   * Check that this {@link PipelineDeclaration} is valid. This check includes
   * checking for step redefinition and pipeline cycles.
   *
   * @param nameFunction maps a StepDefinition to a name, see {@link #getStepNameFunction(StepNamingPolicy)}
   * @return this declaration if it passed the checks, or the problems that were found
   */
  public ResultOrProblems<PipelineDeclaration> checkValid(Function<StepDeclaration, String> nameFunction) {
    // step redefinition is checked first, as it can throw off the cycle detection
    ResultOrProblems<PipelineDeclaration> redefinitionCheck = checkStepRefinition();

    // cycle detection is chained on to the outcome of the redefinition check
    return redefinitionCheck.flatMap(ignored -> detectCycles(nameFunction));
  }

  /**
   * Checks that no explicitly assigned step name has been given to more than one step definition.  Definitions
   * without a name are not considered.
   * @return this declaration if every assigned name is unique, otherwise a failed result with one problem for each
   * definition that re-uses a name already taken by an earlier definition
   */
  private ResultOrProblems<PipelineDeclaration> checkStepRefinition() {
    List<Problem> redefinitions = new ArrayList<>();

    // remembers the first definition seen for each user-assigned name
    Map<String, StepDefinition> firstByName = new HashMap<>();

    stepDefinitionIterator().forEachRemaining(candidate -> {
      Optional<String> assignedName = candidate.getName();
      if (assignedName.isPresent()) {
        String name = assignedName.get();

        StepDefinition original = firstByName.putIfAbsent(name, candidate);
        if (original != null) {
          // the pipeline author has assigned the same name to multiple steps.
          // point at the explicit name of the first definition, or failing that its step id (it should pretty much
          // always be a name, but maybe it's possible to have a collision with implicit names?)
          SourceLocation firstDefinition = original.getNameToken().orElse(original.getIdentToken()).getLocation();
          redefinitions.add(PipelineProblems.get().stepRedefinition(name,
              firstDefinition, candidate.getNameToken().get().getLocation()));
        }
      }
    });

    if (redefinitions.isEmpty()) {
      return ResultOrProblems.of(this);
    }
    return ResultOrProblems.failed(redefinitions);
  }

  /**
   * Checks the graph formed by the links between this declaration's steps for cycles.  Every link of every chain is
   * gathered up, and a depth-first traversal is started from each link that was not reached by an earlier
   * traversal.  Steps are identified by the name that {@code nameFunc} gives them, so two declarations that map to
   * the same name are treated as the same node in the graph.
   *
   * @param nameFunc gives the name of a step declaration - names are compared to decide whether two declarations
   * refer to the same step
   * @return this declaration if no cycle was found, otherwise a failed result describing the first cycle that was
   * detected
   */
  private ResultOrProblems<PipelineDeclaration> detectCycles(Function<StepDeclaration, String> nameFunc) {

    // we need to keep a track of which edges we've followed so that disconnected parts of the graph are also visited
    LinkedList<StepLink> unseen = new LinkedList<>();
    for (StepChain stepChain : chains) {
      for (int i = 0; i < stepChain.getLinkCount(); i++) {
        StepLink link = stepChain.getLink(i);
        unseen.add(link);
      }
    }

    // this algorithm does a DFS search over the edges of the pipeline's graph, maintaining a stack of the visit
    // to spot cycles.  One important detail here is that we record a depth on the stack alongside the node,  we
    // use this to pop 'dead ends' off the stack when we reach the terminal part of the tree to avoid mis-reporting
    // cycles
    while (!unseen.isEmpty()) {
      // each pass of this loop starts a fresh traversal from an edge that no earlier traversal reached
      StepLink node = unseen.removeFirst();
      // the visit list guides our DFS - edges are taken from the end of it, so it behaves as a LIFO
      LinkedList<StackEl> visitList = new LinkedList<>();
      // the stack is used to detect cycles - it holds the edges on the path to the edge being visited
      LinkedList<StackEl> stack = new LinkedList<>();

      // seed the visit list with the first 'starting' edge
      visitList.add(new StackEl(node, 0));

      // the stack has to have a 'pretend' edge pushed on so that the starting point is recorded properly - its rhs
      // is the starting edge's lhs, so an edge that leads back to the starting step is picked up as a cycle
      stack.add(new StackEl(new StepLink(node.getLhs(), null, node.getLhs()), 0));

      // start traversal
      while (!visitList.isEmpty()) {
        StackEl visiting = visitList.removeLast();
        // remove nodes from unseen once they've been visited - this is used to find 'islands' - a completely
        // connected pipeline will have empty unseen after the first search while loop
        unseen.remove(visiting.edge);

        // pop 'dead ends' off the stack - if the current node has a smaller depth than the one at the top of the stack
        // then we back track up the stack to the place in the tree where the current element was added
        while (stack.size() > 0 && stack.getLast().depth > visiting.depth) {
          stack.removeLast();
        }

        // now check if we've seen an edge to the rhs already - i.e. whether any edge on the current path leads to
        // a step with the same name as the step this edge leads to
        StackEl alreadySeen = stack
            .stream()
            .filter(pair -> {
              String stackName = nameFunc.apply(pair.edge.getRhs());
              String visitingName = nameFunc.apply(visiting.edge.getRhs());
              return stackName.equals(visitingName);
            })
            .findFirst()
            .orElse(null);

        if (alreadySeen != null) {
          StepLink offendingEdge = visiting.edge;

          // we have a cycle - report it straight away, so only the first cycle found is ever reported
          return ResultOrProblems.failed(PipelineProblems.get().cycleDetected(
              offendingEdge.toSource(),
              offendingEdge.getChain().getLocation(),
              offendingEdge.getRhs().toSource(),
              alreadySeen.edge.getRhs().getIdentToken().getLocation()
            ));
        }

        // add this edge to the stack for the next cycle check
        stack.addLast(visiting);

        // find next lot of edges for visiting - these are the not-yet-visited edges that lead out of the step that
        // the current edge leads in to
        ListIterator<StepLink> unseenIter = unseen.listIterator();
        while (unseenIter.hasNext()) {
          StepLink candidate = unseenIter.next();
          String fromName = nameFunc.apply(candidate.getLhs());
          String toName = nameFunc.apply(visiting.edge.getRhs());

          if (fromName.equals(toName)) {
            // add the next lot of nodes to the visit list - we might want to add these in reverse order so that
            // we visit 'earlier' branches before later ones
            visitList.addLast(new StackEl(candidate, visiting.depth + 1));
          }
        }
      }

    }

    // every edge has been visited without finding a cycle
    return ResultOrProblems.of(this);
  }

  /**
   * Produce a copy of this PipelineDeclaration that carries different metadata.  The copy is given this
   * declaration's list of chains as-is.
   * @param newMetadata the metadata for the returned declaration
   * @return a new {@link PipelineDeclaration} with this declaration's chains and the given metadata
   */
  public PipelineDeclaration withMetadata(PipelineMetadata newMetadata) {
    return new PipelineDeclaration(this.chains, newMetadata);
  }
}
