/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.pipeline;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.CaseFormat;
import com.google.common.collect.Range;

import lombok.RequiredArgsConstructor;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.Parameter;
import nz.org.riskscape.engine.pipeline.PipelineBuilder.ProblemCodes;
import nz.org.riskscape.pipeline.StepNamingPolicy;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration;
import nz.org.riskscape.pipeline.ast.PipelineExpression;
import nz.org.riskscape.pipeline.ast.StepChain;
import nz.org.riskscape.pipeline.ast.StepDeclaration;
import nz.org.riskscape.pipeline.ast.StepDefinition;
import nz.org.riskscape.pipeline.ast.StepReference;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.TokenTypes;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.FunctionCall.Argument;

@RequiredArgsConstructor
public class DefaultPipelineRealizer implements PipelineRealizer {

  /*
   * A lot of this code is an amalgam from the DefaultPipelineExecutor and the PipelineBuilder
   */

  /**
   * Represents a step definition in the pipeline. Helps to track the step
   * dependencies and the realization result
   */
  @RequiredArgsConstructor
  private class Node {
    // name this step is known by in the pipeline - used to look the step up once it has been realized
    final String stepName;
    // the AST element that declared this step, including the arguments given to it
    final StepDefinition definition;
    // the step implementation - consulted for input names, arity and parameters, and used to realize the step
    final Step implementation;
    // nodes whose output is chained in to this node
    final List<Node> dependencies = new LinkedList<>();
    // nodes that this node's output is chained to
    final List<Node> dependents = new LinkedList<>();
    // problems found with this step's definition or its parameters
    final List<Problem> problems = new LinkedList<>();
    // the node chained to each of this node's named inputs, keyed by input name
    final Map<String, Node> namedInputs = new HashMap<>();


    /**
     * Chains this node's output in to the given node, recording the link on both nodes.
     *
     * @param to the node that will receive this node's output
     * @param namedInput the input on {@code to} that was explicitly targeted, if any
     * @return the problems that stopped the edge from being added, or an empty list if it was added
     */
    List<Problem> addEdgeTo(Node to, Optional<String> namedInput) {
      Problem invalid = checkEdgeValid(to, namedInput);
      if (invalid != null) {
        return Arrays.asList(invalid);
      }

      // link the two nodes together, in both directions
      to.dependencies.add(this);
      dependents.add(to);

      // an edge without an explicit input name lands on the target's default input, if it has one
      Optional<String> targetInput = namedInput.isPresent()
          ? namedInput
          : to.implementation.getDefaultInputName();

      // remember which node is feeding that named input
      targetInput.ifPresent(name -> to.namedInputs.put(name, this));
      return Collections.emptyList();
    }

    /**
     * Checks whether an edge from this node to the given node can be added. Neither node is modified.
     *
     * @param to the node that this node's output would be chained to
     * @param namedInput the input on {@code to} that was explicitly targeted, if any
     * @return a problem describing why the edge can not be added, or null if the edge is fine
     */
    private Problem checkEdgeValid(Node to, Optional<String> namedInput) {
      if (dependents.contains(to)) {
        // the exact same edge has been chained twice
        return Problem.error(Pipeline.ProblemCodes.EDGE_ALREADY_EXISTS,
            stepName, to.stepName);
      }

      if (namedInput.isPresent()) {
        String name = namedInput.get();

        // check that named input are supported on the target step
        if (to.implementation.getInputNames().isEmpty()) {
          return Problem.error(Pipeline.ProblemCodes.NAMED_INPUT_NOT_ALLOWED,
              to.implementation.getId(), name, to.stepName);
        } else if (!to.implementation.hasNamedInput(name)) {
          // step supports named inputs, but not the one specified
          return PipelineProblems.get().namedInputUnknown(
              to.implementation, name, to.implementation.getInputNames());
        }
      }

      // otherwise check if there are different 'from' steps connecting to the same named input
      Optional<Node> existing = to.getAlreadyChained(namedInput);
      if (existing.isEmpty()) {
        return null; // no problemo
      }

      // use the full target the user specified in the error
      String targetName = to.stepName + namedInput.map(name -> "." + name).orElse("");

      // fine tune the error we give the user. If we've ended up using the default named
      // input, then this implicit behaviour might not be immediately obvious to the user
      Optional<String> defaultInput = to.implementation.getDefaultInputName();
      // there is only something to suggest if one of the target's named inputs is still free. When every
      // named input is already taken we fall through to the plain error rather than blowing up on an empty
      // Optional
      Optional<String> unusedInput = to.getNamedInputAvailable();
      if (namedInput.isEmpty() && defaultInput.isPresent() && unusedInput.isPresent()) {
        // suggest what named input they should be using
        String suggestedTarget = to.stepName + "." + unusedInput.get();
        return Problem.error(Pipeline.ProblemCodes.DEFAULT_INPUT_ALREADY_CHAINED,
            stepName + " -> " + targetName,
            suggestedTarget, to.stepName + "." + defaultInput.get(),
            existing.get().stepName);
      }

      // simple case - it's an unnamed input, or there's no alternative input we can suggest
      return Problem.error(Pipeline.ProblemCodes.INPUT_ALREADY_CHAINED, stepName, targetName,
          existing.get().stepName);
    }

    /**
     * Looks for a node that is already occupying the input that a new edge in to this node would use.
     *
     * @param namedInput the input being chained to. When empty, the implementation's default input name (if
     * it has one) is checked instead
     * @return the node already chained to this node's input, if any
     */
    Optional<Node> getAlreadyChained(Optional<String> namedInput) {
      Optional<String> inputName = namedInput.isPresent() ? namedInput : implementation.getDefaultInputName();

      Optional<Node> occupant = inputName.map(name -> namedInputs.get(name));
      if (occupant.isPresent()) {
        return occupant;
      }

      // the named input is free (or there isn't one) - see whether the step can accept any more inputs at all
      Range<Integer> arity = implementation.getInputArity();
      boolean atCapacity = arity.hasUpperBound() && arity.upperEndpoint() == dependencies.size();
      if (!atCapacity) {
        return Optional.empty();
      }

      // already at maximum capacity, just return the first dependency
      // NB this is failing if a step's input arity is 0?  Found with exec step
      assert ! dependencies.isEmpty() : "dependencies are empty " + this;
      return Optional.of(dependencies.get(0));
    }

    /**
     * @return the first of the implementation's input names that nothing has been chained to yet, if any
     */
    Optional<String> getNamedInputAvailable() {
      for (String inputName : implementation.getInputNames()) {
        if (!namedInputs.containsKey(inputName)) {
          return Optional.of(inputName);
        }
      }
      return Optional.empty();
    }

    /**
     * @return true if the given pipeline contains a step with this node's name
     */
    boolean isRealized(RealizedPipeline pipeline) {
      boolean present = pipeline.getStep(stepName).isPresent();
      return present;
    }

    /**
     * @return true if every node this node depends on has a step in the given pipeline (trivially true when
     * there are no dependencies)
     */
    boolean isDependenciesAllRealized(RealizedPipeline pipeline) {
      for (Node dependency : dependencies) {
        if (!dependency.isRealized(pipeline)) {
          return false;
        }
      }
      return true;
    }

    /**
     * Works out which of the step's parameters the argument at the given position applies to. Keyword arguments
     * are matched by name, anonymous arguments by position.
     *
     * @param index zero-based position of the argument among the definition's step parameters
     * @return the matching parameter, or null if there isn't one - in which case a problem has been added to
     * this node
     */
    public Parameter getParameter(int index) {
      // positions are reported to the user counting from one
      final int humanIndex = index + 1;

      Argument argument = definition.getStepParameters().get(index);

      // quoted identifiers are used as is, so that parameter names in a non-expected format still work in
      // the pipeline dsl. Plain identifiers are mapped from the dsl's case format to the one steps use.
      String name = argument.getNameToken()
          .map(token -> token.type == TokenTypes.QUOTED_IDENTIFIER
              ? token.getValue()
              : dslToStep(token.getValue()))
          .orElse(null);

      if (name == null) {
        // anonymous argument - only allowed if none of the arguments before it were keyworded
        boolean keywordAlreadyUsed = definition.getStepParameters().subList(0, index).stream()
            .anyMatch(Argument::isKeywordArgument);
        if (keywordAlreadyUsed) {
          problems.add(Problem.error(ProblemCodes.KEYWORD_REQUIRED, humanIndex));
          return null;
        }

        if (index >= implementation.getParameterSet().size()) {
          // more arguments have been given than there are parameters
          problems.add(ArgsProblems.get().wrongNumber(implementation.getParameterSet().size(), humanIndex));
          return null;
        }

        // match the argument to a parameter by position
        name = implementation.getParameterSet().toList().get(index).getName();
      }

      try {
        return implementation.getParameterSet().get(name);
      } catch (IllegalArgumentException ex) {
        // no such parameter - tell the user what they could have used, in the format the dsl expects
        List<String> available = implementation.getParameterSet().getDeclared().stream()
            .map(Parameter::getName)
            .map(DefaultPipelineRealizer::stepToDsl)
            .collect(Collectors.toList());

        problems.add(
            Problem.error(ProblemCodes.STEP_PARAMETER_UNKNOWN, name, available, implementation.getId()));

        return null;
      }
    }

    /**
     * @return true if at least one problem has been recorded against this node
     */
    public boolean hasProblems() {
      return problems.size() > 0;
    }

    /**
     * Collects the realized steps that feed in to this node. All of this node's dependencies must already have
     * been realized in to the given pipeline.
     *
     * @param pipeline the pipeline to look the realized steps up in
     * @return the realized inputs. If the step declares named inputs these are in the order the input names
     * are declared, otherwise they are in the order the edges were added.
     */
    public List<RealizedStep> getRealizedDependencies(RealizedPipeline pipeline) {
      boolean usesNamedInputs = !implementation.getInputNames().isEmpty();

      if (usesNamedInputs) {
        // go through the input names in declaration order, picking up whatever has been chained to each
        return implementation.getInputNames().stream()
            .map(inputName -> namedInputs.get(inputName))
            // inputs with nothing chained to them are left out - realization will fail if any are missing
            .filter(chained -> chained != null)
            .map(chained -> pipeline.getStep(chained.stepName).get())
            .toList();
      }

      // input order does not concern this step
      return dependencies.stream()
          .map(dependency -> pipeline.getStep(dependency.stepName).get())
          .toList();
    }
  }

  /**
   * Constructs the {@link RealizedPipeline} from a given {@link PipelineDeclaration}
   */
  @RequiredArgsConstructor
  private class Instance {

    // supplies the expression realizer that is used to turn step arguments in to constant values
    private final ExecutionContext context;
    // the parsed pipeline that is being realized
    private final PipelineDeclaration ast;
    // the steps that step definitions in the ast can refer to by id
    private final PipelineSteps availableSteps;
    // gives the name that each step declaration in the ast is known by in the pipeline
    private final Function<StepDeclaration, String> stepNamer;

    // problems found while building the graph, e.g. unknown or redefined steps and invalid edges
    private final List<Problem> problems = new LinkedList<>();
    // every node in the graph, keyed by step name
    private final Map<String, Node> nodes = new HashMap<>();

    /**
     * Fills up the node map with pre-populated values for the given pipeline, so that incremental realization works.
     * A node is added for each step that is already realized, and the nodes are then linked to mirror the
     * dependencies between those steps.
     *
     * @param addTo the already realized pipeline that new steps are going to be added to
     * @throws AssertionError if the pipeline contains more than one step with the same name
     */
    void populateNodeMapFromRealized(RealizedPipeline addTo) {
      // add all nodes in
      for (RealizedStep alreadyRealized : addTo.getRealizedSteps()) {
        Node node =
            new Node(alreadyRealized.getName(), alreadyRealized.getAst(), alreadyRealized.getImplementation());

        Node replaced = nodes.put(node.stepName, node);

        if (replaced != null) {
          throw new AssertionError("Existing pipeline is invalid, has multiple steps with the same name");
        }
      }

      // link the nodes up. This needs to be done eagerly - a stream with no terminal operation never runs, which
      // would leave the dependencies (and so the dependents) of every pre-realized node empty
      for (RealizedStep alreadyRealized : addTo.getRealizedSteps()) {
        Node node = nodes.get(alreadyRealized.getName());

        alreadyRealized.getDependencies().forEach(depStep -> {
          Node dependencyNode = nodes.get(depStep.getName());
          if (dependencyNode == null) {
            // the dependency isn't one of the pipeline's realized steps, so there is no node to link to
            return;
          }

          node.dependencies.add(dependencyNode);
          dependencyNode.dependents.add(node);
        });
      }
    }

    /**
     * Adds a {@link Node} to the node map for every step definition found in the pipeline AST. Nodes that are
     * already in the map (i.e. from a pre-realized pipeline) are kept if a definition clashes with them.
     */
    void populateNodeMapFromAst() {
      ast.stepDefinitionIterator().forEachRemaining(definition -> {
        Node node = addDefinition(definition);
        nodes.put(node.stepName, node);
      });
    }

    /**
     * Links the pipeline's nodes together according to the chains in the AST, so that steps can later be
     * realized in dependency order. Any problems found are added to this instance's problems.
     *
     * @return false if any edge was rejected, true otherwise
     */
    boolean addEdges() {
      boolean allEdgesAdded = true;

      for (StepChain chain : ast.getChains()) {
        List<StepDeclaration> steps = chain.getSteps();
        validateFirstStepInChain(steps);

        // walk the chain pair-wise, adding an edge from steps[i - 1] to steps[i]
        for (int i = 1; i < steps.size(); i++) {
          StepDeclaration source = steps.get(i - 1);
          StepDeclaration target = steps.get(i);

          Node from = findStep(source);
          Node to = findStep(target);

          if (from != null && to != null) {
            // adding the edge can fail if the user has joined things up incorrectly
            List<Problem> edgeErrors = from.addEdgeTo(to, target.getNamedInput());
            if (!edgeErrors.isEmpty()) {
              problems.add(Problems.foundWith(target, edgeErrors));
              allEdgesAdded = false;
            }
          }
        }
      }

      return allEdgesAdded;
    }

    /**
     * Looks up the Node that the given step declaration refers to.
     *
     * @return the node, or null if there is no node with the step's name - in which case a problem is added
     */
    private Node findStep(StepDeclaration stepDecl) {
      String name = stepNamer.apply(stepDecl);
      Node found = nodes.get(name);

      if (found != null) {
        return found;
      }

      Problem unknown = Problem.error(Pipeline.ProblemCodes.STEP_NAME_UNKNOWN, stepDecl.getIdent(),
          nodes.keySet());
      problems.add(sourceError(stepDecl, unknown));
      return null;
    }

    /**
     * Checks the first step of a chain for declarations that have no effect, adding a problem for each one found.
     *
     * @param steps the steps of a single chain, in the order they were chained
     */
    private void validateFirstStepInChain(List<StepDeclaration> steps) {
      final StepDeclaration firstStep = steps.get(0);

      if (steps.size() == 1) {
        // the lookup comes first, as it reports the step being unknown
        boolean known = findStep(firstStep) != null;

        // a reference by itself does nothing
        if (known && firstStep instanceof StepReference) {
          problems.add(sourceError(firstStep, Problem.error(ProblemCodes.UNUSED_STEP_REFERENCE, firstStep.getIdent())));
        }
      }

      // foo.bar -> baz - the bar is pointless as it receives no input
      firstStep.getNamedInput().ifPresent(inputName ->
          problems.add(
              sourceError(firstStep,
                Problem.error(ProblemCodes.UNUSED_NAMED_INPUT, firstStep.getIdent(), inputName))));
    }

    /**
     * Builds the Node for a step definition, without adding it to the node map.
     *
     * @return a Node for the given step definition, with any problems looking up the step's implementation added
     * to it. If a node with the same name already exists then that node is returned instead and a redefinition
     * problem is recorded.
     */
    private Node addDefinition(StepDefinition toAdd) {
      String name = stepNamer.apply(toAdd);
      ResultOrProblems<Step> lookup = availableSteps.getOr(toAdd.getStepId());

      // fall back to a placeholder step when the lookup failed, so the node still has an implementation
      Node created = new Node(name, toAdd, lookup.orElse(NullStep.INSTANCE));
      if (lookup.hasProblems()) {
        created.problems.add(Problems.foundWith(toAdd, lookup.getProblems()));
      }

      Node clash = nodes.get(name);
      if (clash == null) {
        return created;
      }

      // Incremental realization means that we can get this problem from adding a second pipeline in, although steps
      // would do well to alter the step names to avoid this
      Problem redefinition =
          Problem.error(ProblemCodes.STEP_REDEFINITION, name, clash.definition, created.definition.getIdentToken());
      this.problems.add(sourceError(toAdd, redefinition));

      // hand back the existing node so callers never see null (putting it in the map again changes nothing)
      return clash;
    }

    /**
     * @return the token on the left of the expression's boundary, or {@link Token#UNKNOWN_LOCATION} if the
     * expression has no boundary
     */
    private Token getLocation(PipelineExpression expr) {
      Token start = expr.getBoundary()
          .map(boundary -> boundary.getLeft())
          .orElse(Token.UNKNOWN_LOCATION);
      return start;
    }

    /**
     * Wraps the given problems in a problem that says where in the pipeline source they were found.
     *
     * @param expr the part of the pipeline that the problems relate to
     * @param children the problems that were found
     */
    private Problem sourceError(PipelineExpression expr, Problem... children) {
      Token location = getLocation(expr);
      return Problems.foundWith(location, children);
    }

    /**
     * Realizes every node that isn't already part of the given pipeline, making sure that a node is only
     * realized once everything it depends on has been.
     *
     * @param addTo the pipeline to add the newly realized steps to
     * @return the pipeline with the new steps added. Realization stops at the first failure, in which case the
     * failed pipeline is returned.
     * @throws AssertionError if none of the remaining nodes can be realized because their dependencies never
     * will be, e.g. because of a cycle
     */
    public RealizedPipeline traverseAndRealize(final RealizedPipeline addTo) {
      LinkedList<Node> toVisit = new LinkedList<>();
      // don't include pre-realized steps
      toVisit.addAll(nodes.values().stream().filter(node -> !node.isRealized(addTo)).toList());

      // number of nodes that have been put back on the queue since a node was last realized
      int skipCount = 0;
      RealizedPipeline ptr = addTo;
      while (!toVisit.isEmpty()) {
        // safety to make sure we've not got cycles - should already have been checked during parsing? This is
        // deliberately not an assert statement: the count has to be maintained, and the check made, even when
        // assertions are disabled, otherwise this loop would spin forever
        if (skipCount >= toVisit.size()) {
          throw new AssertionError("Unable to realize steps, their dependencies can never be satisfied: "
              + toVisit.stream().map(node -> node.stepName).toList());
        }

        Node node = toVisit.removeFirst();
        // dependencies not satisfied, skip it
        if (!node.isDependenciesAllRealized(ptr)) {
          toVisit.addLast(node);
          skipCount++;
          continue;
        }

        skipCount = 0;        // reset failsafe
        ptr = realize(ptr, node);
        if (ptr.hasFailures()) {
          // pipeline has failed, bail fast otherwise we'll trigger the skip count failsafe
          return ptr;
        }
      }

      return ptr;
    }

    /**
     * Realizes a single node and adds the result to the given pipeline. All of the node's dependencies must
     * already be realized in that pipeline.
     *
     * @param realized the pipeline realized so far
     * @param node the node to realize
     * @return the pipeline with the node's step added, which will be a failed step if the node has problems
     */
    private RealizedPipeline realize(RealizedPipeline realized, Node node) {
      // realizing the parameters of a node that is already broken would only add noise
      Map<String, List<?>> parameterMap = node.hasProblems() ? Map.of() : realizeParameters(node);

      RealizationInput input = new RealizationInputImpl(
          realized,
          node.definition,
          node.stepName,
          node.getRealizedDependencies(realized),
          parameterMap
      );

      // checked again, as realizing the parameters may have found problems
      if (node.hasProblems()) {
        // there's no point trying to realize a node with problems, instead add in a failed step with the node's
        // problems so it can be debugged by the user
        return realized.add(input.newPrototypeStep().withProblems(node.problems));
      }

      RealizedPipeline result = node.implementation.realize(input);

      // if this node has dependents then it needs to have added a step with its name.  Strictly speaking,
      // we don't need to report an error here if we can work out what step was just added, but that doesn't
      // seem worth the extra loc - can't see a use case for that just yet.
      if (node.dependents.isEmpty() || result.getStep(input.getName()).isPresent()) {
        return result;
      }

      // a step that produces no output should not have any dependents
      return result.addProblems(PipelineProblems.get().chainingFromStepWithNoOutput(
          input.getName(),
          node.dependents.stream().map(dependent -> dependent.stepName).toList()
      ));
    }

    /**
     * Converts the arguments in the node's step definition in to the values to be given to the step's parameters.
     * Arguments that can not be matched to a parameter, or evaluated, are left out and a problem is added to
     * the node.
     *
     * @return the values for each parameter that was given, keyed by parameter name
     */
    private Map<String, List<?>> realizeParameters(Node node) {
      Map<String, List<?>> realized = new HashMap<>();

      int position = 0;
      for (Argument arg : node.definition.getStepParameters()) {
        int index = position++;

        Parameter parameter = node.getParameter(index);
        if (parameter == null) {
          // the node has been told why there's no parameter for this argument
          assert(node.hasProblems());
          continue;
        }

        // TODO want to move away from this and make it the step's job to go from expressions -> objects?
        Expression expression = arg.getExpression();
        if (Expression.class.isAssignableFrom(parameter.getType())) {
          // the parameter wants the expression itself, rather than what it evaluates to
          realized.put(parameter.getName(), Collections.singletonList(expression));
          continue;
        }

        // TODO clean this up later - the type is only used for the error message
        Object constant = context.getRealizationContext().getExpressionRealizer()
            .realizeConstant(expression)
            .map(re -> re.evaluate(Tuple.EMPTY_TUPLE)).orElse(node.problems::addAll, null);

        if (constant == null) {
          continue;
        }

        // lists are used as is, anything else becomes a single value list
        List<?> values = constant instanceof List<?> list ? list : Arrays.asList(constant);
        realized.put(parameter.getName(), values);
      }

      return realized;
    }
  }

  /**
   * Realizes the given pipeline from scratch, i.e. by adding it to an empty realized pipeline.
   */
  @Override
  public RealizedPipeline realize(ExecutionContext context, PipelineDeclaration pipeline) {
    RealizedPipeline nothingRealized = RealizedPipeline.empty(context, pipeline);
    return realize(nothingRealized, pipeline);
  }

  /**
   * Realizes the given pipeline on top of an existing realized pipeline, using the pipeline steps from the
   * existing pipeline's engine.
   */
  @Override
  public RealizedPipeline realize(RealizedPipeline addTo, PipelineDeclaration pipeline) {
    PipelineSteps engineSteps = addTo.getContext().getEngine().getPipelineSteps();
    return realize(addTo, pipeline, engineSteps);
  }

  /**
   * Realizes the given pipeline on top of an existing realized pipeline.
   *
   * @param addTo the already realized pipeline to add to
   * @param pipeline the pipeline to realize
   * @param engineSteps the steps that the pipeline's step definitions can refer to
   * @return the realized pipeline. If the pipeline's structure has errors then a new pipeline is returned
   * that has those problems and a placeholder step for each node, and no steps are realized.
   */
  protected RealizedPipeline realize(RealizedPipeline addTo, PipelineDeclaration pipeline,
      PipelineSteps engineSteps) {

    // the naming policy is told the names that have already been used by the existing pipeline
    StepNamingPolicy stepNamingPolicy = new DefaultStepNamingPolicy(
        addTo.getRealizedSteps().stream().map(RealizedStep::getName).collect(Collectors.toSet())
    );
    Function<StepDeclaration, String> stepNamer = pipeline.getStepNameFunction(stepNamingPolicy);

    Instance instance = new Instance(addTo.getContext(), pipeline, engineSteps, stepNamer);

    // make sure steps in the ast can be linked to the existing pipeline
    instance.populateNodeMapFromRealized(addTo);

    // turn the ast in to a dag
    instance.populateNodeMapFromAst();

    // look for cycles once we've built nodes but before we've added edges
    pipeline.checkValid(stepNamer).addProblemsTo(instance.problems);

    // only try and add edges if checkValid worked
    if (!Problem.hasErrors(instance.problems)) {
      instance.addEdges();
    }

    if (!Problem.hasErrors(instance.problems)) {
      return instance.traverseAndRealize(addTo);
    }

    // the pipeline's structure is broken - give back the problems, along with a placeholder for each step
    List<RealizedStep> dummySteps =
        instance.nodes.values().stream().map(node -> RealizedStep.named(node.stepName)).toList();
    // TODO - try and add to passed in one?
    return new RealizedPipeline(addTo.getContext(), pipeline, dummySteps, instance.problems);
  }


  /**
   * Maps a parameter name as written in the pipeline DSL (lower case with underscores) to the format that
   * pipeline steps use (lower case with hyphens).
   *
   * @param parameterName expected to be in pipeline dsl parameter format
   * @return value converted to be in pipeline step format
   */
  static String dslToStep(String parameterName) {
    final CaseFormat dslFormat = CaseFormat.LOWER_UNDERSCORE;
    final CaseFormat stepFormat = CaseFormat.LOWER_HYPHEN;
    return dslFormat.to(stepFormat, parameterName);
  }

  /**
   * Maps a parameter name in the format that pipeline steps use (lower case with hyphens) to the format used
   * in the pipeline DSL (lower case with underscores).
   *
   * @param parameterName expected to be in pipeline step parameter format
   * @return value converted to be in DSL format
   */
  static String stepToDsl(String parameterName) {
    final CaseFormat stepFormat = CaseFormat.LOWER_HYPHEN;
    final CaseFormat dslFormat = CaseFormat.LOWER_UNDERSCORE;
    return stepFormat.to(dslFormat, parameterName);
  }

}
