/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.sched;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import nz.org.riskscape.engine.join.LopsidedJoiner;
import nz.org.riskscape.engine.output.PipelineJobContext;
import nz.org.riskscape.engine.output.SinkParameters;
import nz.org.riskscape.engine.pipeline.Collector;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.pipeline.Sink;
import nz.org.riskscape.engine.pipeline.SinkConstructor;
import nz.org.riskscape.engine.pipeline.TupleInput;
import nz.org.riskscape.engine.projection.AsyncProjector;
import nz.org.riskscape.engine.projection.FlatProjector;
import nz.org.riskscape.engine.projection.Projector;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.restriction.Restrictor;
import nz.org.riskscape.engine.task.AccumulatorProcessorTask;
import nz.org.riskscape.engine.task.ChainTask;
import nz.org.riskscape.engine.task.LinkedSteps;
import nz.org.riskscape.engine.task.PageBuffer;
import nz.org.riskscape.engine.task.ReadPageBuffer;
import nz.org.riskscape.engine.task.SinkTask;
import nz.org.riskscape.engine.task.TaskSpec;
import nz.org.riskscape.engine.task.TupleInputTask;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.util.ListUtils;

/**
 * Produces a set of {@link TaskSpec}s for execution on a {@link Scheduler}
 */
@RequiredArgsConstructor
@Slf4j
public class TaskBuilder {

  private static final boolean SINGLE_THREADED = false;
  private static final boolean MULTI_THREADED = true;

  private final SchedulerParams params;

  /**
   * Holds the state for one conversion of {@link LinkedSteps} into {@link TaskSpec}s: the tasks built so far,
   * the job context, and the estimate of how many buffers the tasks will need between them.
   */
  private class TaskBuildingState {
    // the tasks constructed so far, keyed by the steps they were built from
    final Map<LinkedSteps, List<TaskSpec>> built = new HashMap<>();
    final int estimatedBuffers;
    final PipelineJobContext context;

    /**
     * @param context the job context that gets handed to every task that is built
     * @param steps   all the linked steps that are going to be converted, used to estimate the number of buffers
     */
    TaskBuildingState(PipelineJobContext context, List<LinkedSteps> steps) {
      this.context = context;
      // sinks don't have an output buffer, but all other steps need one
      // (technically slow joins use 2 buffers, but these share the same Tuple objects)
      int numSinks = (int) steps.stream().filter(step -> step.containsOnly(SinkConstructor.class)).count();
      this.estimatedBuffers = steps.size() - numSinks;
    }

    /**
     * @return the output buffer of the tasks that have already been built for the given steps
     */
    private PageBuffer getOutputBuffer(LinkedSteps steps) {
      // we've cached the tasks we've already constructed for these steps
      List<TaskSpec> builtTasks = built.get(steps);

      // usually the list will only have one task, but where there are several
      // it's the last task in the list that should always have the output buffer
      TaskSpec lastTask = builtTasks.get(builtTasks.size() - 1);

      // output is a WriteBuffer, but really it is some kind of PageBuffer underneath
      return (PageBuffer) lastTask.getOutput().get();
    }

    /**
     * @param steps the steps that need an input buffer
     * @param index which of the steps' ordered predecessors to read from
     * @return the buffer that the task for the given steps should read its input from
     */
    public ReadPageBuffer findInputForSteps(LinkedSteps steps, int index) {
      // we lookup our predecessor and use their output buffer as our input buffer
      LinkedSteps predecessor = steps.getOrderedPredecessor(index);
      PageBuffer buffer = getOutputBuffer(predecessor);

      // If the predecessor fans out, then its output will need to be shared amongst
      // multiple downstream tasks. The first descendant still needs to 'own' the
      // input buffer, but everyone else will need to read from a cloned buffer
      boolean isFirstDescendant = (predecessor.getDescendants().indexOf(steps) == 0);
      if (predecessor.hasFanOut() && !isFirstDescendant) {
        return buffer.newReaderClone();
      } else {
        return buffer;
      }
    }

    /**
     * @return the input buffer for steps that read from their first (index zero) predecessor
     */
    public ReadPageBuffer getInputFor(LinkedSteps steps) {
      return findInputForSteps(steps, 0);
    }

    /**
     * @return the user-specified tuples per task if there is one, otherwise the maximum number of queued tuples
     *         shared evenly across the estimated number of buffers
     */
    public int numTuplesPerBuffer() {
      // the fallback is computed lazily so the division is skipped when the user has set a value, and the
      // divisor is clamped so that a zero buffer estimate can't cause a divide-by-zero
      return params.getTuplesPerTask()
          .orElseGet(() -> params.getMaxTuplesQueued() / Math.max(1, estimatedBuffers));
    }

    /**
     * @return the user-specified page size if there is one, otherwise a default derived from the tuples per
     *         buffer and the maximum threads per task (never less than one)
     */
    public int getPageSize() {
      // now that the user can set the tuples per task, we need some protection against page size
      // less than one (that ends in a nasty exception).
      return params.getPageSize()
          .orElseGet(() -> Math.max(1, numTuplesPerBuffer() / (params.getMaxThreadsPerTask() * 2)));
    }
  }

  /**
   * Collects the run of steps, beginning at the given step, that should be processed together in a single
   * task (e.g. projectors, restrictors, etc).
   *
   * @return the chained steps in order, or just the given step by itself if it can not be chained
   */
  private ArrayList<RealizedStep> groupSteps(RealizedPipeline pipeline, RealizedStep step) {
    ArrayList<RealizedStep> chain = new ArrayList<>();
    RealizedStep current = step;

    while (canChain(current)) {
      // a step that fans in (i.e. union) can start a chain, but it can never go in the middle of one
      boolean singleDependency = current.getDependencies().size() == 1;
      if (!singleDependency && !chain.isEmpty()) {
        break;
      }

      chain.add(current);

      // a step that fans out (or has nothing after it) has to be the end of the chain
      Set<RealizedStep> downstream = pipeline.getDependents(current);
      if (downstream.size() != 1) {
        break;
      }
      current = downstream.iterator().next();
    }

    // nothing was chainable, so the step we were given forms a group of one
    if (chain.isEmpty()) {
      chain.add(current);
    }

    return chain;
  }

  /**
   * @return true if the step's type is a projector, a flat projector or a restrictor
   */
  private boolean canChain(RealizedStep step) {
    Class<? extends Realized> stepType = step.getStepType();
    if (stepType == Projector.class) {
      return true;
    }
    if (stepType == FlatProjector.class) {
      return true;
    }
    return stepType == Restrictor.class;
  }

  /**
   * @return true if every one of the given steps can be chained (which is also the case for an empty list)
   */
  public boolean canChain(List<RealizedStep> steps) {
    for (RealizedStep candidate : steps) {
      if (!canChain(candidate)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Splits a pipeline up into {@link LinkedSteps}, where each one is a group of steps that can be
   * processed together as the same task. A step is only processed once all of its dependencies have been.
   */
  List<LinkedSteps> decompose(RealizedPipeline pipeline) {
    ArrayList<LinkedSteps> result = new ArrayList<>();
    Set<RealizedStep> processed = new HashSet<>();
    LinkedList<RealizedStep> pending = new LinkedList<>(pipeline.getStartSteps());

    while (!pending.isEmpty()) {
      RealizedStep next = pending.removeFirst();

      if (!processed.containsAll(next.getDependencies())) {
        // the dependencies jump to the front of the queue, and this step gets retried from the back
        pending.addAll(0, next.getDependencies());
        pending.addLast(next);
        continue;
      }

      // a step can be queued more than once, but should only ever end up in one group
      if (processed.contains(next)) {
        continue;
      }

      // gather up everything that can be chained on to this step
      ArrayList<RealizedStep> group = groupSteps(pipeline, next);

      // link the new group up to the already-built groups that it depends on
      List<LinkedSteps> upstream = findPredecessors(next, result);
      result.add(new LinkedSteps(upstream, group));
      processed.addAll(group);

      // carry on from whatever follows the end of the group
      RealizedStep groupEnd = group.get(group.size() - 1);
      pending.addAll(pipeline.getDependents(groupEnd));
    }

    return result;
  }

  /**
   * @return the already-decomposed groups whose last step is a dependency of the given step, in the order
   *         that they appear in the decomposed list
   */
  private List<LinkedSteps> findPredecessors(RealizedStep step, ArrayList<LinkedSteps> decomposed) {
    return decomposed.stream()
        .filter(candidate -> step.getDependencies().contains(candidate.getLastStep()))
        .collect(Collectors.toCollection(ArrayList::new));
  }

  /**
   * Decomposes the pipeline held by the given context and converts the result into tasks.
   */
  public List<TaskSpec> convertToTasks(PipelineJobContext context) {
    List<LinkedSteps> decomposed = decompose(context.getPipeline());
    return convertToTasks(context, decomposed);
  }

  /**
   * Converts each of the given steps into one or more tasks. Any steps that have no descendants, and are
   * not already a sink, get an extra sink task added to cap them off.
   */
  List<TaskSpec> convertToTasks(PipelineJobContext context, List<LinkedSteps> steps) {
    TaskBuildingState buildState = new TaskBuildingState(context, steps);
    ArrayList<TaskSpec> allTasks = new ArrayList<>(steps.size());

    for (LinkedSteps linked : steps) {
      List<TaskSpec> builtTasks = toTasks(buildState, linked);
      buildState.built.put(linked, builtTasks);
      allTasks.addAll(builtTasks);

      // only dangling, non-sink steps need capping
      if (!linked.getDescendants().isEmpty() || linked.containsOnly(SinkConstructor.class)) {
        continue;
      }

      RealizedStep terminalStep = linked.getLastStep();

      // the sink is created on demand by the job's output container, using the terminal step's name and type
      SinkConstructor cappingSink = new SinkConstructor() {

        @Override
        public ResultOrProblems<Sink> newInstance(PipelineJobContext context) {
          SinkParameters sinkParameters =
              new SinkParameters(terminalStep.getStepName(), terminalStep.getProduces());
          return context.getOutputContainer().createSinkForStep(sinkParameters);
        }
      };

      RealizedStep cappingStep = RealizedStep
          .named(terminalStep.getStepName() + "-capped")
          .withResult(cappingSink)
          .withDependencies(terminalStep);

      LinkedSteps cappingLinkedSteps = new LinkedSteps(Arrays.asList(linked), Arrays.asList(cappingStep));
      allTasks.add(toSinkTask(buildState, cappingLinkedSteps));
    }

    if (!allTasks.isEmpty()) {
      log.info("Scheduler adding {} tasks, {} tuples per page, {} tuples per task",
          allTasks.size(), buildState.getPageSize(), buildState.numTuplesPerBuffer());
    }
    return allTasks;
  }

  /**
   * Picks the kind of task(s) to build for the given steps, based on the type of step they contain.
   *
   * @throws RuntimeException if the steps are not of a type that can be converted in to a task
   */
  List<TaskSpec> toTasks(TaskBuildingState state, LinkedSteps convert) {
    if (convert.containsOnly(Relation.class)) {
      return Arrays.asList(toRelationTask(state, convert));
    }
    if (canChain(convert.getGrouped())) {
      return Arrays.asList(toChainTask(state, convert));
    }
    if (convert.containsOnly(Collector.class)) {
      return toCollectorTasks(state, convert);
    }
    if (convert.containsOnly(SinkConstructor.class)) {
      return Arrays.asList(toSinkTask(state, convert));
    }
    if (convert.containsOnly(LopsidedJoiner.class)) {
      return toJoinerTasks(state, convert);
    }
    if (convert.containsOnly(AsyncProjector.class)) {
      return toAsyncProjectorTask(state, convert);
    }

    throw new RuntimeException("Could not convert steps into tasks: " + convert);
  }

  /**
   * Builds the two tasks for an {@link AsyncProjector} step: a sink task that feeds the upstream tuples in
   * to the projector's output, and an input task that reads from the projector's input in to a new buffer.
   */
  private List<TaskSpec> toAsyncProjectorTask(TaskBuildingState state, LinkedSteps steps) {
    RealizedStep asyncStep = steps.getFirstStep();
    AsyncProjector projector = asyncStep.getRealized(AsyncProjector.class).get();

    // consuming side - reads from our predecessor's buffer, so it has no output buffer of its own
    SinkConstructor projectorSink = projector.getOutput();
    ReadPageBuffer upstream = state.getInputFor(steps);
    List<RealizedStep> sinkSteps = List.of(asyncStep.withResult(projectorSink));
    TaskSpec sinkTask =
        new TaskSpec(SinkTask.class, sinkSteps, upstream, null, SINGLE_THREADED, state.context);

    // producing side - has no input buffer, its output is what the downstream steps will read from
    TupleInput projectorInput = projector.getInput();
    List<RealizedStep> inputSteps = List.of(asyncStep.withResult(projectorInput));
    TaskSpec inputTask =
        new TaskSpec(TupleInputTask.class, inputSteps, null, newOutputBuffer(state), SINGLE_THREADED, state.context);

    return List.of(sinkTask, inputTask);
  }

  /**
   * Creates a new buffer for a task to write its output in to.
   *
   * Sizing the buffer is a trade-off between making use of the CPU and making use of memory. Things to
   * bear in mind:
   * <ol>
   * <li>All the buffers should fit the same number of tuples. Otherwise there could be a deadlock, where not
   * enough output is generated for the next task to run (especially when a multi-threaded task fans out
   * across workers).</li>
   * <li>The page capacity of a buffer shouldn't be less than the number of threads (otherwise multi-threaded
   * tasks won't really be processed in parallel).</li>
   * <li>The smaller the capacity, the longer steps will spend waiting for a slower upstream step to fill
   * pages (i.e. a larger capacity means that downstream work can be done in parallel, despite a bottleneck
   * upstream).</li>
   * <li>If the page size is too small, then a lot of CPU time may be wasted dealing with the buffers' mutex
   * locks.</li>
   * </ol>
   * The settings used here seem to work OK when run with 8 threads (but more trial and error may uncover a
   * more optimal combination).
   */
  private LinkedListBuffer newOutputBuffer(TaskBuildingState state) {
    int threadsPerTask = params.getMaxThreadsPerTask();
    int maxPageSize = state.getPageSize();
    int tuplesPerBuffer = state.numTuplesPerBuffer();

    // TODO use .withBackoff() to respond to memory pressure
    PageAllocator pageAllocator = new PageAllocator(1)
        .doubleEvery(threadsPerTask)
        .withMaxPageSize(maxPageSize);

    return new LinkedListBuffer(pageAllocator, tuplesPerBuffer);
  }

  /**
   * Builds the two tasks for a {@link LopsidedJoiner} step: a sink task that consumes the rhs input, and a
   * chain task that processes the lhs input. The chain task depends on the sink task.
   *
   * @return the sink task followed by the chain task (the chain task being the one with the output buffer)
   */
  private List<TaskSpec> toJoinerTasks(TaskBuildingState state, LinkedSteps steps) {
    RealizedStep joinStep = steps.getFirstStep();
    LopsidedJoiner<?> joiner = (LopsidedJoiner<?>) joinStep.getResult().get();
    LopsidedJoinAdapter<?> adapter = new LopsidedJoinAdapter<>(joiner);

    // the rhs sink accumulates tuples in to the index
    SinkConstructor sinkConstructor = adapter.newSinkConstructor();
    ReadPageBuffer rhsInputBuffer = state.findInputForSteps(steps, LopsidedJoiner.RHS_STEP_INDEX);
    RealizedStep sinkStep = joinStep.withResult(sinkConstructor);

    // NB we could probably chuck this at the end of a chain if we can find it...  Group steps do this.
    TaskSpec sinkTask = new TaskSpec(SinkTask.class, List.of(sinkStep), rhsInputBuffer,
        null, SINGLE_THREADED, state.context);

    // the lhs is a flat projector that can only proceed once the rhs is built
    FlatProjector flatProjector = adapter.newFlatProjector();
    ReadPageBuffer lhsInputBuffer = state.findInputForSteps(steps, LopsidedJoiner.LHS_STEP_INDEX);
    LinkedListBuffer outputBuffer = newOutputBuffer(state);
    RealizedStep projectorStep = joinStep.withResult(flatProjector);

    TaskSpec chainTask = new TaskSpec(ChainTask.class, List.of(projectorStep),
        lhsInputBuffer, outputBuffer, MULTI_THREADED, state.context);

    // makes sure the rhs has been fully consumed before any of the lhs is processed
    chainTask.addDependency(sinkTask);

    return Arrays.asList(sinkTask, chainTask);
  }

  /**
   * Builds a single-threaded sink task that reads from the steps' predecessor and has no output buffer.
   */
  private TaskSpec toSinkTask(TaskBuildingState state, LinkedSteps steps) {
    List<RealizedStep> sinkSteps = steps.getGrouped();
    ReadPageBuffer upstream = state.getInputFor(steps);
    return new TaskSpec(SinkTask.class, sinkSteps, upstream, null, SINGLE_THREADED, state.context);
  }

  /**
   * Builds the tasks for a {@link Collector} step. There is always a processor task that produces the
   * collector's output. If the preceding task doesn't already end in an {@link AccumulatorSink}, then a
   * standalone sink task is added as well to do the accumulating.
   *
   * @return the tasks in dependency order, with the processor task (which has the output buffer) last
   * @throws IllegalStateException if the steps do not have exactly one predecessor
   */
  private List<TaskSpec> toCollectorTasks(TaskBuildingState state, LinkedSteps steps) {
    Collector<?> collector = steps.getFirstStep().getRealized(Collector.class).get();

    int numPredecessors = steps.getPredecessors().size();
    if (numPredecessors != 1) {
      throw new IllegalStateException(
          "A collector should be preceded by exactly one step, but found " + numPredecessors);
    }
    LinkedSteps predecessor = steps.getPredecessors().get(0);

    TaskSpec processorSpec = new TaskSpec(AccumulatorProcessorTask.class, steps.getGrouped(), null,
        newOutputBuffer(state), SINGLE_THREADED, state.context);

    // check whether we applied a sink to our dependency - if so, we just mark it as the dependency and don't add an
    // accumulator task
    List<TaskSpec> builtForPredecessor = state.built.get(predecessor);
    if (builtForPredecessor.size() == 1) {
      TaskSpec singleTask = builtForPredecessor.get(0);

      if (singleTask.getLastStep().getRealized(AccumulatorSink.Constructor.class).isPresent()) {
        processorSpec.addDependency(singleTask);
        return Arrays.asList(processorSpec);
      }
    }

    // if we're here, it means we need an explicit task for accumulating - we use the same AccumulatorSink as
    // the chain does, but we run it standalone in a SinkTask
    boolean parallelizable = collector.getCharacteristics().contains(Collector.Characteristic.PARALLELIZABLE)
        ? MULTI_THREADED : SINGLE_THREADED;

    SinkConstructor sinkConstructor = new AccumulatorSink.Constructor(Collections.singletonList(collector));
    RealizedStep originalStep = steps.getFirstStep();

    // replace the collector with a sink constructor
    RealizedStep fauxStep = originalStep.withResult(sinkConstructor);

    TaskSpec accumulatorSpec = new TaskSpec(
        SinkTask.class,
        Collections.singletonList(fauxStep),
        state.getInputFor(steps),
        null,
        parallelizable,
        state.context
    );

    // the processor can't start until the accumulating has finished
    processorSpec.addDependency(accumulatorSpec);
    return Arrays.asList(accumulatorSpec, processorSpec);
  }

  /**
   * Builds a single-threaded input task that reads from the step's {@link Relation} (as a tuple input) and
   * writes in to a new output buffer.
   */
  private TaskSpec toRelationTask(TaskBuildingState state, LinkedSteps steps) {
    RealizedStep relationStep = steps.getFirstStep();
    Relation relation = relationStep.getRealized(Relation.class).get();

    // swap the relation out for its tuple input, as that's what the task reads from
    List<RealizedStep> inputSteps = List.of(relationStep.withResult(relation.toTupleInput()));
    LinkedListBuffer output = newOutputBuffer(state);

    return new TaskSpec(TupleInputTask.class, inputSteps, null, output, SINGLE_THREADED, state.context);
  }

  /**
   * @return the buffer a chain task should read from. This is a combination of all the upstream buffers
   *         when the first step in the chain has more than one dependency (i.e. a union step)
   */
  private ReadPageBuffer getChainInput(TaskBuildingState state, LinkedSteps steps) {
    int dependencyCount = steps.getFirstStep().getDependencies().size();

    // the usual case - there's nothing fanning in, so we read straight from the one predecessor
    if (dependencyCount <= 1) {
      return state.getInputFor(steps);
    }

    // fanning in from multiple parents means gluing their outputs together in to a single readable buffer
    List<ReadPageBuffer> upstreamBuffers = new ArrayList<>();
    for (int parent = 0; parent < dependencyCount; parent++) {
      upstreamBuffers.add(state.findInputForSteps(steps, parent));
    }
    return new CombinedPageBuffer(upstreamBuffers);
  }

  /**
   * Builds a multi-threaded chain task for the given steps. When the chain is followed only by
   * parallelizable collectors, a sink step for those collectors is appended to the chain and the task gets
   * no output buffer. Otherwise the task writes to a new output buffer.
   */
  private TaskSpec toChainTask(TaskBuildingState state, LinkedSteps steps) {
    // look ahead to see whether the chain can dump straight in to the collectors that follow it
    List<Collector<?>> followingCollectors = findFollowedOnlyByParallelizableCollectors(steps);

    List<RealizedStep> chainSteps;
    LinkedListBuffer chainOutput;
    if (followingCollectors.isEmpty()) {
      // nothing to collect in to, so this is a plain chain that buffers its output
      chainSteps = steps.getGrouped();
      chainOutput = newOutputBuffer(state);
    } else {
      // the collectors' sink goes on the end of the chain, which leaves nothing to write to a buffer
      RealizedStep collectorSink = createParallelizedCollectorStep(steps.getLastStep(), followingCollectors);
      chainSteps = ListUtils.append(steps.getGrouped(), collectorSink);
      chainOutput = null;
    }

    ReadPageBuffer chainInput = getChainInput(state, steps);
    return new TaskSpec(ChainTask.class, chainSteps, chainInput, chainOutput, MULTI_THREADED, state.context);
  }

  /**
   * Creates a new step that wraps the given collectors up in an {@link AccumulatorSink}. The step is named
   * after the step it follows (with a "-sink" suffix) and depends on it.
   */
  private RealizedStep createParallelizedCollectorStep(
      RealizedStep follows,
      List<Collector<?>> collectors
  ) {
    String sinkName = follows.getStepName() + "-sink";
    SinkConstructor accumulatorSink = new AccumulatorSink.Constructor(collectors);

    return RealizedStep.named(sinkName).withResult(accumulatorSink).withDependencies(follows);
  }

  /**
   * @return the collectors from the descendants of the given steps, but only if every single descendant is
   *         a lone, parallelizable collector step. In any other case the result is an empty list.
   */
  private List<Collector<?>> findFollowedOnlyByParallelizableCollectors(LinkedSteps steps) {
    List<Collector<?>> found = new ArrayList<>();

    for (LinkedSteps descendant : steps.getDescendants()) {
      List<RealizedStep> descendantSteps = descendant.getGrouped();
      // has to be a single step
      if (descendantSteps.size() != 1) {
        return Collections.emptyList();
      }

      // and a collector, and parallelizable
      Collector<?> candidate = descendantSteps.get(0).getRealized(Collector.class).orElse(null);
      if (candidate == null || !candidate.isParallelizable()) {
        return Collections.emptyList();
      }

      found.add(candidate);
    }

    // getting this far means that none of the descendants were disqualified
    return found;
  }

}
