/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.sched;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.junit.Before;
import org.junit.Test;

import nz.org.riskscape.engine.Assert;
import nz.org.riskscape.engine.ProjectTest;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.output.PipelineJobContext;
import nz.org.riskscape.engine.pipeline.ExecutionContext;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.pipeline.TestPipelineJobContext;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.task.ReadPageBuffer;
import nz.org.riskscape.engine.task.ReturnState;
import nz.org.riskscape.engine.task.TaskSpec;
import nz.org.riskscape.engine.task.WorkerTask;
import nz.org.riskscape.engine.task.WritePageBuffer;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.ProblemSink;
import nz.org.riskscape.problem.Problems;

/**
 * Tests for {@link Scheduler}. Each test builds one or more {@link TaskSpec}s backed by
 * {@link TestTask} workers, which read pages of single-integer tuples from an input
 * {@link LinkedListBuffer} and write to an output {@link LinkedListBuffer}.
 *
 * <p>Tasks run on the scheduler's worker threads, so the helpers here poll (with short
 * sleeps) for workers to finish rather than relying on exact timing.
 */
public class SchedulerTest extends ProjectTest {

  // every test tuple holds a single integer value
  Struct struct = Types.INTEGER.asStruct();
  int inputCapacity = 3; // pages
  int inputPageSize = 5; // tuples
  int outputCapacity = 3; // pages
  int outputPageSize = 5; // tuples
  LinkedListBuffer input;
  LinkedListBuffer output;

  int numWorkers = 2;
  // collects any problems the scheduler reports; cleared in setup()
  List<Problem> problems = new ArrayList<>();
  ProblemSink problemSink = p -> problems.add(p);
  Scheduler scheduler = new Scheduler(new SchedulerParams(numWorkers), problemSink);
  SchedulerBasedExecutor executor = new SchedulerBasedExecutor(engine);
  int numTaskSpecs = 1;
  // tests set this before building their task specs to allow multiple workers per spec
  boolean parallelizeTask = false;
  // the worker tasks the scheduler created for the current test's task specs
  List<TestTask> workerTasks;
  ExecutionContext executionContext = executor.newExecutionContext(project);
  RealizedPipeline pipeline = mock(RealizedPipeline.class);
  PipelineJobContext context = new TestPipelineJobContext(pipeline);
  // incremented to give each dummy step a unique name
  int stepId = 1;

  @Before
  public void setup() {
    problems.clear();

    when(pipeline.getContext()).thenReturn(executionContext);
  }

  /**
   * @return a single freshly-named dummy step for a task spec to be "for"
   */
  private List<RealizedStep> linkedSteps() {
    return Arrays.asList(dummyStep());
  }

  /**
   * @return a step whose result is a mocked {@link Relation}, so tests can verify whether
   *         {@code close()} was called on it
   */
  private RealizedStep dummyStep() {
    return RealizedStep.named(String.format("dummy-%d", stepId++))
        // use a mock for verify close gets called
        .withResult(mock(Relation.class), Struct.EMPTY_STRUCT);
  }

  /**
   * Adds one page of {@link #inputPageSize} tuples, each holding {@code value}, to the buffer.
   */
  private void addPageToBuffer(WritePageBuffer buffer, long value) {
    Page page = buffer.newPage();
    for (int i = 0; i < inputPageSize; i++) {
      page.add(Tuple.ofValues(struct, value));
    }
    buffer.add(page);
  }

  /**
   * Adds {@code numPages} pages to the given buffer, where the n-th page (1-based) holds
   * tuples with the value n.
   */
  public void addPages(WritePageBuffer buffer, int numPages) {
    for (int value = 1; value <= numPages; value++) {
      // write to the buffer we were given (previously this always wrote to `input`)
      addPageToBuffer(buffer, value);
    }
  }

  /**
   * Sleeps for the given number of milliseconds. If interrupted, the thread's interrupt
   * status is restored rather than silently discarded.
   */
  private void sleep(int millisecs) {
    try {
      Thread.sleep(millisecs);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Polls (up to 100 times, 10ms apart) until the scheduler reports no running workers,
   * then asserts that none are running.
   */
  private void waitForTasksToRun() {
    for (int i = 0; i < 100; i++) {
      if (scheduler.numWorkersRunning() == 0) {
        break;
      }
      sleep(10);
    }
    assertEquals(0, scheduler.numWorkersRunning());
  }

  /**
   * Creates fresh input/output buffers and a single {@link TestTask} spec between them, and
   * adds it to the scheduler.
   *
   * @return the task spec that was added
   */
  private TaskSpec setupSchedulerForSingleTask() {
    input = new LinkedListBuffer(inputPageSize, inputCapacity);
    output = new LinkedListBuffer(outputPageSize, outputCapacity);
    TaskSpec spec = new TaskSpec(TestTask.class, linkedSteps(), input, output, parallelizeTask, context);

    setupScheduler(Arrays.asList(spec));

    return spec;
  }

  /**
   * Adds the specs to the scheduler, records the worker tasks it creates in
   * {@link #workerTasks}, starts the workers with deadlock detection disabled, and checks
   * that everything starts off waiting.
   */
  private void setupScheduler(List<TaskSpec> specs) {
    workerTasks = new ArrayList<>();
    List<WorkerTask> newTasks;
    try {
      newTasks = scheduler.addTasks(specs);
    } catch (ProblemException e) {
      throw new RuntimeException(e);
    }
    newTasks.forEach(t -> workerTasks.add((TestTask) t));

    scheduler.startWorkers();
    scheduler.detectDeadlocks = false;

    assertTrue(scheduler.getWaitingTasks().size() > 0);
    assertEquals(0, scheduler.getRunningTasks().size());
    assertEquals(0, scheduler.getCompletedTasks().size());
  }

  /**
   * Creates a chain of task specs where the output from one is the input for the next.
   * Afterwards {@link #input} is the first spec's input and {@link #output} is the last
   * spec's output.
   */
  private List<TaskSpec> makeChainOfTasks(int numSpecs) {
    List<TaskSpec> taskSpecs = new ArrayList<>();
    input = new LinkedListBuffer(inputPageSize, inputCapacity);
    ReadPageBuffer nextInput = input;

    for (int i = 0; i < numSpecs; i++) {
      output = new LinkedListBuffer(outputPageSize, outputCapacity);
      TaskSpec spec = new TaskSpec(TestTask.class, linkedSteps(), nextInput, output, parallelizeTask, context);
      taskSpecs.add(spec);

      // use the output as input into the next task
      nextInput = output;
    }
    return taskSpecs;
  }

  /**
   * @return how many worker tasks have read at least one tuple so far
   */
  private long numTasksWithTuplesRead() {
    return workerTasks.stream().filter(t -> t.tuplesRead.size() > 0).count();
  }

  private void runTwoTasksInParallel() {
    runTwoTasksInParallel(inputPageSize);
  }

  /**
   * Starts two workers concurrently. Every worker task is told to block after reading
   * {@code initialTuplesToRead} tuples. This method then waits (up to ~1s) for two more
   * tasks to have read input, unblocks all tasks, and waits for the workers to finish.
   */
  private void runTwoTasksInParallel(int initialTuplesToRead) {
    long initialNumTasksRan = numTasksWithTuplesRead();
    // With multi-threading, these gets a bit tricky to guarantee that both run.
    // Beforehand, we tell the tasks to block after they've read one page
    workerTasks.forEach(t -> t.blockAfter(initialTuplesToRead));
    // kick off two workers - they should both get to the point where they block
    assertTrue(scheduler.runOnce());
    assertTrue(scheduler.runOnce());

    // make sure that both tasks have started reading input
    // (otherwise it's possible we could unblock them before the task's run()
    // has properly started)
    int loopCount = 0;
    while (numTasksWithTuplesRead() < initialNumTasksRan + 2 && loopCount++ < 100) {
      sleep(10);
    }

    // unblock the tasks so they can finish running
    workerTasks.forEach(t -> t.unblock());
    waitForTasksToRun();
  }

  /**
   * Asserts the buffer holds {@code expectedPages * inputPageSize} tuples. Note that this
   * counts in units of the *input* page size, regardless of the buffer's own page size.
   */
  private void assertPagesQueued(LinkedListBuffer buffer, int expectedPages) {
    int expectedTuples = inputPageSize * expectedPages;
    assertEquals(expectedTuples, buffer.size());
  }

  @Test
  public void canRunUnblockedTaskToCompletion() {
    TaskSpec taskSpec = setupSchedulerForSingleTask();
    TestTask task = workerTasks.get(0);

    // scheduler shouldn't do anything because we haven't given the task any input yet
    assertFalse(scheduler.runOnce());

    addPages(input, 3);
    input.markComplete();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.COMPLETE, task.returnState);

    // there shouldn't be anything more to do, but this will cleanup the completed task
    assertFalse(scheduler.runOnce());
    assertEquals(0, scheduler.getWaitingTasks().size());
    assertEquals(0, scheduler.getRunningTasks().size());
    assertEquals(1, scheduler.getCompletedTasks().size());
    assertTrue(scheduler.allTasksComplete());

    assertPagesQueued(output, 3);

    assertStepsClosed(taskSpec);
  }

  @Test
  public void canRunUntilNoMoreInput() {
    TaskSpec taskSpec = setupSchedulerForSingleTask();
    TestTask task = workerTasks.get(0);

    // give it 2 pages of (incomplete) input
    addPages(input, 2);
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.INPUT_EMPTY, task.returnState);
    // should be nothing to do until we give it more input
    assertFalse(scheduler.runOnce());
    assertPagesQueued(output, 2);

    // give it more input and it'll run again
    addPageToBuffer(input, 3);
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.INPUT_EMPTY, task.returnState);
    assertPagesQueued(output, 3);

    input.markComplete();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.COMPLETE, task.returnState);
    assertPagesQueued(output, 3);

    assertStepsNotClosed(taskSpec);
  }

  @Test
  public void canRunUntilNoMoreOutput() {
    TaskSpec taskSpec = setupSchedulerForSingleTask();
    TestTask task = workerTasks.get(0);

    // give it 4 pages of input
    addPages(input, 4);
    input.markComplete();

    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.OUTPUT_FULL, task.returnState);
    assertTrue(output.isFull());

    // should be nothing to do until we free up output buffers
    assertFalse(scheduler.runOnce());
    assertPagesQueued(output, 3);

    // read a page of output to free up space
    output.read();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.COMPLETE, task.returnState);
    assertPagesQueued(output, 3);

    assertStepsNotClosed(taskSpec);
  }

  @Test
  public void canRunMultiThreadedTaskToCompletion() {
    outputCapacity = inputCapacity + 1;
    parallelizeTask = true;
    TaskSpec taskSpec = setupSchedulerForSingleTask();

    addPages(input, 3);
    input.markComplete();
    runTwoTasksInParallel();

    // check the tasks ran to completion
    assertAnyWorkerMatches(ReturnState.COMPLETE);
    assertPagesQueued(output, 3);

    // there should be nothing more to do (but we still need to cleanup
    // after the tasks that ran)
    assertFalse(scheduler.runOnce());
    assertTrue(scheduler.allTasksComplete());

    // check both tasks saw some input
    for (TestTask task : workerTasks) {
      assertTrue(task.tuplesRead.size() >= inputPageSize);
    }

    assertStepsClosed(taskSpec);
  }

  /**
   * Asserts that at least one worker task finished with the expected return state, printing
   * the worker tasks before failing if none did.
   */
  private void assertAnyWorkerMatches(ReturnState expected) {
    // Due to the multi-threaded nature of these tests, we can't really assert
    // that a specific worker should have a given return state. So here we assert at
    // least one worker (don't care which one) matches the state we expect.
    boolean match = workerTasks.stream().anyMatch(t -> t.returnState == expected);
    if (!match) {
      System.out.println(workerTasks);
      fail("No worker in state " + expected);
    }
  }

  @Test
  public void canNotCompleteWithPartiallyReadPage() {
    // make the output buffer not quite big enough for all the input
    outputPageSize = inputPageSize - 1;
    outputCapacity = 1;
    parallelizeTask = true;
    TaskSpec taskSpec = setupSchedulerForSingleTask();

    // add some input and try to run both tasks
    addPageToBuffer(input, 1);
    input.markComplete();
    runTwoTasksInParallel(1);

    // one should've blocked because the write buffer is full.
    // the other should have returned 'complete' because there's no other input
    assertAnyWorkerMatches(ReturnState.OUTPUT_FULL);
    assertAnyWorkerMatches(ReturnState.COMPLETE);

    // one task should be left with an incomplete input page
    assertEquals(1, workerTasks.stream().filter(t -> t.hasPageInProgress()).count());
    assertEquals(1, workerTasks.stream().filter(t -> !t.hasPageInProgress()).count());
    assertFalse(scheduler.allTasksComplete());

    assertStepsNotClosed(taskSpec);

    // they won't run again until there's more room in the write buffer
    assertFalse(scheduler.runOnce());
    assertFalse(scheduler.allTasksComplete());

    assertStepsNotClosed(taskSpec);

    // clear the output and check the task can run to completion
    output.read();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(0, workerTasks.stream().filter(t -> t.hasPageInProgress()).count());
    assertEquals(inputPageSize, output.numTuplesWritten());

    // there shouldn't be anything more to do
    assertFalse(scheduler.runOnce());
    assertTrue(scheduler.allTasksComplete());

    assertStepsClosed(taskSpec);
  }

  @Test
  public void canNotCompleteWithPartiallyWrittenPage() {
    // make the output buffer slightly bigger than the input
    outputPageSize = inputPageSize + 1;
    outputCapacity = 1;
    parallelizeTask = true;
    TaskSpec spec = setupSchedulerForSingleTask();

    // give the first task limited input and run it
    addPageToBuffer(input, 1);
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    TestTask task = workerTasks.get(0);
    assertEquals(ReturnState.INPUT_EMPTY, task.returnState);
    // it should still have a partially written, unflushed output page
    assertTrue(task.hasPageInProgress());
    assertPagesQueued(output, 0);
    assertEquals(0, output.numTuplesWritten());

    // it won't run again until there's more input
    assertFalse(scheduler.runOnce());
    assertFalse(scheduler.allTasksComplete());

    assertStepsNotClosed(spec);

    // once the input is complete the task should run to completion
    // and all output gets flushed
    input.markComplete();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertEquals(ReturnState.COMPLETE, task.returnState);
    assertFalse(task.hasPageInProgress());
    assertPagesQueued(output, 1);
    assertEquals(inputPageSize, output.numTuplesWritten());

    // there shouldn't be anything more to do
    assertFalse(scheduler.runOnce());
    assertTrue(scheduler.allTasksComplete());

    assertStepsClosed(spec);
  }

  @Test
  public void canWaitUntilDependencyIsMet() {
    // create 2 tasks, the second one dependent on the first
    List<TaskSpec> specs = makeChainOfTasks(2);
    specs.get(1).addDependency(specs.get(0));
    setupScheduler(specs);

    // add some input so the first task can run
    addPages(input, inputCapacity);
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    assertFalse(scheduler.runOnce());
    assertFalse(scheduler.allTasksComplete());

    // the downstream task should not be ready even though it has input
    TaskSpec downstreamSpec = specs.get(1);
    TestTask downstreamTask = (TestTask) downstreamSpec.getWorkerTasks().get(0);
    assertFalse(downstreamTask.isReadyToRun());
    assertFalse(downstreamSpec.getInput().get().isEmpty());
    assertPagesQueued(output, 0);
    assertEquals(0, downstreamTask.tuplesRead.size());

    // run the first (upstream) task to completion
    input.markComplete();
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    scheduler.processWorkerResults();
    assertFalse(scheduler.allTasksComplete());

    assertStepsClosed(specs.get(0));
    assertStepsNotClosed(specs.get(1));

    // check the downstream task can now run to completion
    assertTrue(downstreamTask.isReadyToRun());
    assertTrue(scheduler.runOnce());
    waitForTasksToRun();
    scheduler.processWorkerResults();
    assertTrue(scheduler.allTasksComplete());
    assertPagesQueued(output, inputCapacity);

    assertStepsClosed(specs);
  }

  @Test
  public void canRunChainOfTasks() {
    parallelizeTask = true;
    int numInChain = 5;
    List<TaskSpec> specs = makeChainOfTasks(numInChain);
    setupScheduler(specs);

    scheduler.detectDeadlocks = true;

    // add some input so the first task can run, but will get blocked on output
    addPages(input, inputCapacity * 2);

    // move through the chain running each task, and then checking that the next one
    // is then ready to run
    for (int i = 0; i < numInChain - 1; i++) {
      // check the task is ready to run
      TestTask task = (TestTask) specs.get(i).getWorkerTasks().get(0);
      assertTrue(task.isReadyToRun());

      // but the next task isn't (no input yet)
      TestTask nextTaskInChain = (TestTask) specs.get(i + 1).getWorkerTasks().get(0);
      assertFalse(nextTaskInChain.isReadyToRun());

      // kick off 2 Workers/WorkerTasks for the same TaskSpec
      runTwoTasksInParallel();

      assertTrue(nextTaskInChain.isReadyToRun());
      assertTrue(task.tuplesRead.size() > 0);
      assertEquals(0, nextTaskInChain.tuplesRead.size());
      assertFalse(scheduler.allTasksComplete());
    }

    // now just run it as a free-for-all until everything completes
    addPages(input, inputCapacity * 2);
    input.markComplete();

    // loop until everything completes, but give up after a bounded number of attempts so a
    // stuck scheduler fails the assertion below rather than hanging the test. (This used to
    // be `||`, which never incremented loopCount while incomplete, i.e. could loop forever.)
    int loopCount = 0;
    while (!scheduler.allTasksComplete() && loopCount++ < 1000) {
      // clear any output out of the last buffer (as there's no task
      // setup to read from it)
      output.read();

      // try to do something
      if (!scheduler.runOnce()) {
        // nothing to do, try again later
        sleep(10);
      }
    }

    // sanity-check all the input made it to the end of the pipeline
    assertTrue(scheduler.allTasksComplete());
    long expectedTuples = inputCapacity * 4 * inputPageSize;
    assertEquals(expectedTuples, output.numTuplesWritten());

    assertStepsClosed(specs);
  }

  @Test
  public void canRunDifferentTaskSpecsInParallel() {
    // create the task specs we'll use manually
    input = new LinkedListBuffer(inputPageSize, inputCapacity);
    output = new LinkedListBuffer(outputPageSize, outputCapacity);
    TaskSpec spec1 = new TaskSpec(TestTask.class, linkedSteps(), input, output, parallelizeTask, context);

    LinkedListBuffer output2 = new LinkedListBuffer(outputPageSize, outputCapacity);
    TaskSpec spec2 =
        new TaskSpec(TestTask.class, linkedSteps(), input.newReaderClone(), output2, parallelizeTask, context);

    setupScheduler(Arrays.asList(spec1, spec2));

    // add some input and check both tasks can run to completion
    addPages(input, inputCapacity);
    input.markComplete();
    runTwoTasksInParallel();
    scheduler.processWorkerResults();

    assertTrue(scheduler.allTasksComplete());
    long expectedTuples = inputCapacity * inputPageSize;
    assertEquals(expectedTuples, output.numTuplesWritten());
    assertEquals(expectedTuples, output2.numTuplesWritten());

    for (TestTask task : workerTasks) {
      assertEquals(expectedTuples, task.tuplesRead.size());
    }

    assertStepsClosed(spec1, spec2);
  }

  /**
   * A worker task whose constructor always throws a {@link ProblemException} wrapping
   * {@link #pickledProblem}.
   */
  public static class ThrowingTask extends WorkerTask {

    // set by the test before any ThrowingTask is constructed
    static Problem pickledProblem;

    public ThrowingTask(TaskSpec spec) throws ProblemException {
      super(spec);
      throw new ProblemException(pickledProblem);
    }

    @Override
    public ReturnState run() {
      return null;
    }
    @Override
    public boolean producesResult() {
      return false;
    }

  }

  @Test
  public void failsIfAnyTasksFailToConstructWithProblemExceptionsWrappedInExecutionException() throws Exception {
    ThrowingTask.pickledProblem = Problems.foundWith("your face");
    // NB input/output are not assigned in this test (so are null) - the tasks are only
    // constructed, never run
    TaskSpec throwingSpec =
        new TaskSpec(ThrowingTask.class, linkedSteps(), input, output, parallelizeTask, context);
    TaskSpec okSpec =
        new TaskSpec(TestTask.class, linkedSteps(), input, output, parallelizeTask, context);
    TaskSpec anotherThrowingSpec =
        new TaskSpec(ThrowingTask.class, linkedSteps(), input, output, parallelizeTask, context);

    ProblemException ex = Assert.assertThrows(ProblemException.class,
        () -> scheduler.addTasks(Arrays.asList(throwingSpec, okSpec, anotherThrowingSpec)));

    // the code should attempt to instantiate all tasks, that way the user can try to fix all the problems in one go
    assertThat(
      ex.getProblems(),
      contains(
        equalTo(ThrowingTask.pickledProblem),
        equalTo(ThrowingTask.pickledProblem)
      )
    );
  }

  private void assertStepsNotClosed(TaskSpec... taskSpecs) {
    assertStepsNotClosed(Arrays.asList(taskSpecs));
  }
  private void assertStepsClosed(TaskSpec... taskSpecs) {
    assertStepsClosed(Arrays.asList(taskSpecs));
  }

  /**
   * Verifies {@code close()} has never been called on any step result (the mocked
   * {@link Relation} from {@link #dummyStep()}) of the given specs.
   */
  private void assertStepsNotClosed(List<TaskSpec> taskSpecs) {
    for (TaskSpec taskSpec : taskSpecs) {
      for (RealizedStep realizedStep : taskSpec.getForSteps()) {
        verify(realizedStep.getResult().get(), times(0)).close();
      }
    }
  }

  /**
   * Verifies {@code close()} has been called exactly once on every step result (the mocked
   * {@link Relation} from {@link #dummyStep()}) of the given specs.
   */
  private void assertStepsClosed(List<TaskSpec> taskSpecs) {
    for (TaskSpec taskSpec : taskSpecs) {
      for (RealizedStep realizedStep : taskSpec.getForSteps()) {
        verify(realizedStep.getResult().get(), times(1)).close();
      }
    }
  }
}
