/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.sched;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import org.junit.Before;
import org.junit.Test;

import com.google.common.collect.ImmutableMap;

import lombok.NonNull;
import nz.org.riskscape.config.BootstrapIniSettings;
import nz.org.riskscape.engine.ProjectTest;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.pipeline.Collector;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.pipeline.TestPipelineJobContext;
import nz.org.riskscape.engine.projection.AsyncProjector;
import nz.org.riskscape.engine.projection.FlatProjector;
import nz.org.riskscape.engine.projection.Projector;
import nz.org.riskscape.engine.relation.ListRelation;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.restriction.Restrictor;
import nz.org.riskscape.engine.steps.AsyncSelectStep;
import nz.org.riskscape.engine.steps.FilterStep;
import nz.org.riskscape.engine.steps.GroupByStep;
import nz.org.riskscape.engine.steps.JoinStep;
import nz.org.riskscape.engine.steps.RelationInputStep;
import nz.org.riskscape.engine.steps.SelectStep;
import nz.org.riskscape.engine.steps.UnnestStep;
import nz.org.riskscape.engine.task.ChainTask;
import nz.org.riskscape.engine.task.LinkedSteps;
import nz.org.riskscape.engine.task.SinkTask;
import nz.org.riskscape.engine.task.TaskSpec;
import nz.org.riskscape.engine.task.TupleInputTask;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.pipeline.PipelineParser;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration;
import nz.org.riskscape.problem.ProblemSink;

public class TaskBuilderTest extends ProjectTest {

  // job context passed to TaskBuilder#convertToTasks; assigned by the individual tests that need one
  TestPipelineJobContext jobContext;
  // the most recently parsed pipeline declaration - updated as a side effect of parse(String)
  PipelineDeclaration pipeline;
  // struct type built from Types.TEXT, used as the type of the list relation below
  Struct type = Types.TEXT.asStruct();
  // backing list for the relation below; starts out empty
  List<Tuple> tuples = new ArrayList<>();
  // relation over `tuples` - not referenced by any test method in this class
  ListRelation relation = new ListRelation(type, tuples);
  // executor built from the test engine - not referenced by any test method in this class
  SchedulerBasedExecutor executor = new SchedulerBasedExecutor(engine);

  // NOTE(review): the meaning of the `2` argument isn't visible from here (presumably a thread/worker count) -
  // confirm against SchedulerParams
  SchedulerParams params = new SchedulerParams(2);
  // the object under test
  TaskBuilder subject = new TaskBuilder(params);

  /**
   * Registers with the test engine every pipeline step implementation that the pipelines in these tests refer to.
   */
  @Before
  public void setupSteps() {
    var registeredSteps = engine.getPipelineSteps();

    registeredSteps.add(new SelectStep(engine));
    registeredSteps.add(new FilterStep(engine));
    registeredSteps.add(new RelationInputStep(engine));
    registeredSteps.add(new JoinStep(engine));
    registeredSteps.add(new GroupByStep(engine));
    registeredSteps.add(new UnnestStep(engine));
    registeredSteps.add(new AsyncSelectStep(engine));
  }

  /**
   * Grows a strictly linear pipeline one step at a time, checking after each addition how many groups of linked
   * steps the pipeline decomposes in to, what each group contains and how the groups are related to each other.
   */
  @Test
  public void testALinearDag() throws Exception {
    // a pipeline that consists of nothing but an input step
    List<LinkedSteps> inputOnly = subject.decompose(realize(parse("input(value: 'hi')")));
    assertEquals(1, inputOnly.size());

    // input followed by one more step
    List<LinkedSteps> twoSteps = subject.decompose(realize(parse("input(value: 'hi') -> filter(false)")));
    assertEquals(2, twoSteps.size());
    // the first group must be the predecessor of the second (and the second a descendant of the first)
    assertRelated(twoSteps, 0, 1);

    // a third step that gets chained together with the second one
    List<LinkedSteps> filterAndSelect =
        subject.decompose(realize(parse("input(value: 'hi') -> filter(true) -> select(*)")));
    assertEquals(2, filterAndSelect.size());
    assertThat(
        filterAndSelect.get(1).getGrouped(),
        contains(producesResult(Restrictor.class), producesResult(Projector.class)));

    // ...and one more step on the end of that chain
    List<LinkedSteps> withUnnest = subject.decompose(realize(parse("""
        input(value: 'hi') -> filter(true) -> select({[*] as list}) -> unnest(list)
        """)));
    assertEquals(2, withUnnest.size());
    assertThat(
        withUnnest.get(1).getGrouped(),
        contains(
            producesResult(Restrictor.class),
            producesResult(Projector.class),
            producesResult(FlatProjector.class)));

    // a collector step can't be chained, so it ends up in a group of its own
    List<LinkedSteps> withCollector = subject.decompose(realize(parse("""
        input(value: 'hi') -> filter(true) -> select({[*] as list}) -> unnest(list) -> group(select: list, by: list)
        """)));
    assertEquals(3, withCollector.size());
    assertThat(withCollector.get(2).getGrouped(), contains(producesResult(Collector.class)));
    assertRelated(withCollector, 0, 1);
    assertRelated(withCollector, 1, 2);

    // two more chainable steps after the collector form a fourth group
    List<LinkedSteps> withTrailingSelects = subject.decompose(realize(parse("""
        input(value: 'hi') -> filter(true) -> select({[*] as list}) -> unnest(list) -> group(select: list, by: list)
        -> select(*) -> select(*)
        """)));
    assertEquals(4, withTrailingSelects.size());
    assertThat(
        withTrailingSelects.get(3).getGrouped(),
        contains(producesResult(Projector.class), producesResult(Projector.class)));

    assertRelated(withTrailingSelects, 0, 1);
    assertRelated(withTrailingSelects, 1, 2);
    assertRelated(withTrailingSelects, 2, 3);
  }

  /**
   * Parses pipeline source with {@link PipelineParser#INSTANCE} and remembers the outcome in the {@code pipeline}
   * field, so a test can either use the return value or read the field afterwards.
   *
   * @param string the pipeline source to parse
   * @return the parsed declaration (the same object that was stored in the {@code pipeline} field)
   */
  private @NonNull PipelineDeclaration parse(String string) {
    PipelineDeclaration parsed = PipelineParser.INSTANCE.parsePipeline(string);
    this.pipeline = parsed;
    return parsed;
  }

  /**
   * Two inputs feeding one join: expects three groups of linked steps, with each input group being a predecessor
   * of the join group.
   */
  @Test
  public void testDecomposingADagWithFanIn() throws Exception {
    PipelineDeclaration fanIn = parse("""
        input(value: 'foo', name: 'foo') as lhs -> join.lhs
        input(value: 'bar', name: 'bar') as rhs -> join.rhs

        join(on: true)
        """);

    List<LinkedSteps> groups = subject.decompose(realize(fanIn));

    assertEquals(3, groups.size());
    assertRelated(groups, "lhs", "join");
    assertRelated(groups, "rhs", "join");
  }

  /**
   * Two inputs feed a join whose output then goes to two separate group steps: expects five groups of linked
   * steps, with the join group related to both of its inputs and to both of its consumers.
   */
  @Test
  public void testDecomposingADagWithFanInAndFanOut() throws Exception {
    PipelineDeclaration fanInAndOut = parse("""
        input(value: 'foo', name: 'foo') as r1 -> join.lhs
        input(value: 'bar', name: 'bar') as r2 -> join.rhs
        join(on: true)
        join -> group(select: foo, by: foo) as c1
        join -> group(select: bar, by: bar) as c2
        """);

    List<LinkedSteps> groups = subject.decompose(realize(fanInAndOut));
    assertEquals(5, groups.size());

    // fan in
    assertRelated(groups, "r1", "join");
    assertRelated(groups, "r2", "join");
    // fan out
    assertRelated(groups, "join", "c1");
    assertRelated(groups, "join", "c2");
  }

  /**
   * A join where only the left hand side has extra steps ahead of it. The filter and select on that side are
   * expected to share one group, which sits between the lhs input and the join. Appending a filter after the join
   * is expected to add one more group.
   */
  @Test
  public void testALopsidedJoinerWithPredecessors() throws Exception {
    List<LinkedSteps> endingAtJoin = subject.decompose(realize(parse("""
       input(value: 'lhs', name: 'lhs') as lhs -> filter(true) as "lhs-filter" -> select(*) as "lhs-mapper" -> join
       input(value: 'rhs', name: 'rhs') as rhs -> join.rhs
       join(on: true)
       """)));
    assertEquals(4, endingAtJoin.size());

    assertRelated(endingAtJoin, "lhs", "lhs-filter");
    assertGrouped(endingAtJoin, "lhs-filter", "lhs-mapper");
    assertRelated(endingAtJoin, "lhs-filter", "join");
    assertRelated(endingAtJoin, "rhs", "join");

    // same again, but with a step hanging off the join's output
    List<LinkedSteps> withStepAfterJoin = subject.decompose(realize(parse("""
        input(value: 'lhs', name: 'lhs') as lhs -> filter(true) as "lhs-filter" -> select(*) as "lhs-mapper" -> join
        input(value: 'rhs', name: 'rhs') as rhs -> join.rhs
        join(on: true) -> filter(true)
        """)));
    assertEquals(5, withStepAfterJoin.size());
  }

  /**
   * A chain of steps whose last member has two consumers: the chain is expected to stay together as one group,
   * with each consumer ending up in a separate group that descends from it.
   */
  @Test
  public void testFanOutBreaksChain() throws Exception {
    // chain2 is chainable, but its output fans out to two different steps
    PipelineDeclaration fanOut = parse("""
        input(value: 'foo') -> filter(true) as chain1 -> select(*) as chain2
        chain2 -> filter(true) as fanout1
        chain2 -> select(*) as fanout2
        """);

    List<LinkedSteps> groups = subject.decompose(realize(fanOut));
    assertEquals(4, groups.size());

    assertRelated(groups, "input", "chain1");
    assertGrouped(groups, "chain1", "chain2");
    assertRelated(groups, "chain1", "fanout1");
    assertRelated(groups, "chain1", "fanout2");
  }

  /**
   * A pipeline made of a single input step decomposes to one group, but converting that to tasks is expected to
   * produce a sink task as well as the tuple input task.
   */
  @Test
  public void willCapUncappedStepsIfNecessary() throws Exception {
    jobContext = new TestPipelineJobContext(realize(parse("input(value: 'foo')")));

    List<LinkedSteps> groups = subject.decompose(jobContext.getPipeline());
    assertThat(groups, contains(hasProperty("grouped", contains(producesResult(Relation.class)))));

    List<TaskSpec> specs = subject.convertToTasks(jobContext, groups);
    assertThat(
        specs,
        contains(
            hasProperty("workerTaskClass", is(TupleInputTask.class)),
            hasProperty("workerTaskClass", is(SinkTask.class))));
  }

  /**
   * As {@code willCapUncappedStepsIfNecessary}, except that the pipeline finishes with a chain of steps: the sink
   * task is expected to come after the chain task.
   */
  @Test
  public void willCapUncappedStepsIfNecessaryFromTheLastStepInChain() throws InterruptedException, Exception {
    // filter and select can be chained together
    jobContext = new TestPipelineJobContext(realize(parse("input(value: 'foo') -> filter(true) -> select(*)")));

    List<LinkedSteps> groups = subject.decompose(jobContext.getPipeline());
    assertThat(
        groups,
        contains(
            hasProperty("grouped", contains(producesResult(Relation.class))),
            hasProperty("grouped", contains(producesResult(Restrictor.class), producesResult(Projector.class)))));

    List<TaskSpec> specs = subject.convertToTasks(jobContext, groups);
    assertThat(
        specs,
        contains(
            hasProperty("workerTaskClass", is(TupleInputTask.class)),
            // the filter and the select both run inside this one task
            hasProperty("workerTaskClass", is(ChainTask.class)),
            hasProperty("workerTaskClass", is(SinkTask.class))));
  }

  /**
   * An async_select step in the middle of otherwise chainable steps: the steps either side of it are expected to
   * land in separate groups, and the resulting tasks are expected to be input, chain, sink and then input, chain,
   * sink again.
   */
  @Test
  public void canHandleAnAsyncProjector() throws Exception {
    jobContext = new TestPipelineJobContext(
        realize(parse("input(value: 'foo') -> filter(true) -> async_select(*) -> select(*)")));

    List<LinkedSteps> groups = subject.decompose(jobContext.getPipeline());
    assertThat(
        groups,
        contains(
            hasProperty("grouped", contains(producesResult(Relation.class))),
            hasProperty("grouped", contains(producesResult(Restrictor.class))),
            hasProperty("grouped", contains(producesResult(AsyncProjector.class))),
            hasProperty("grouped", contains(producesResult(Projector.class)))));

    List<TaskSpec> specs = subject.convertToTasks(jobContext, groups);
    assertThat(
        specs,
        contains(
            hasProperty("workerTaskClass", is(TupleInputTask.class)),
            hasProperty("workerTaskClass", is(ChainTask.class)),
            hasProperty("workerTaskClass", is(SinkTask.class)),
            hasProperty("workerTaskClass", is(TupleInputTask.class)),
            hasProperty("workerTaskClass", is(ChainTask.class)),
            hasProperty("workerTaskClass", is(SinkTask.class))));
  }

  /**
   * Page size and tuples-per-task are empty on the params built without settings, and take on the values given
   * via {@link BootstrapIniSettings} when settings are supplied.
   */
  @Test
  public void canConfigureBufferSettings() throws Exception {
    // nothing is set when no settings were given
    assertEquals(Optional.empty(), params.getPageSize());
    assertEquals(Optional.empty(), params.getTuplesPerTask());

    BootstrapIniSettings iniSettings = new BootstrapIniSettings(
        ImmutableMap.of(
            SchedulerParams.PAGE_SIZE_SETTING, Arrays.asList("123"),
            SchedulerParams.TUPLES_PER_TASK_SETTING, Arrays.asList("456")),
        ProblemSink.DEVNULL);

    SchedulerParams configured = new SchedulerParams(2, iniSettings);
    assertEquals(Optional.of(123), configured.getPageSize());
    assertEquals(Optional.of(456), configured.getTuplesPerTask());
  }

  /**
   * Realizes a parsed pipeline using the test's execution context.
   *
   * @param ast the pipeline declaration to realize
   * @return whatever the execution context's {@code realize} call returns for the declaration
   */
  private RealizedPipeline realize(PipelineDeclaration ast) {
    RealizedPipeline realizedPipeline = executionContext.realize(ast);
    return realizedPipeline;
  }

  /**
   * Asserts that the named steps all belong to one and the same {@link LinkedSteps}, and that that group consists
   * of exactly those steps, in the order given.
   *
   * @param decomposed the groups to look through
   * @param stepNames the names of the steps expected to make up a single group, in order
   */
  private void assertGrouped(List<LinkedSteps> decomposed, String... stepNames) {
    List<String> expectedNames = Arrays.asList(stepNames);
    LinkedSteps owner = null;

    // locate the group holding the named steps - they must not be spread across more than one group
    for (LinkedSteps candidate : decomposed) {
      for (RealizedStep member : candidate.getGrouped()) {
        if (!expectedNames.contains(member.getStepName())) {
          continue;
        }

        if (owner == null) {
          owner = candidate;
        } else if (owner != candidate) {
          fail("step " + member + " was found in " + candidate + ", not " + owner);
        }
      }
    }

    if (owner == null) {
      // fail always throws, so nothing below runs without an owner
      fail("did not find any steps containing " + expectedNames);
    }

    // the group must hold exactly the expected steps - nothing extra, nothing missing, same order
    List<String> actualNames = new ArrayList<>();
    for (RealizedStep member : owner.getGrouped()) {
      actualNames.add(member.getStepName());
    }
    assertEquals(expectedNames, actualNames);
  }

  /**
   * Looks up the first {@link LinkedSteps} that has a step with the given name among its grouped steps.
   *
   * @param decomposed the groups to search, in order
   * @param withStepName the step name to look for
   * @return the first matching group, or {@code null} when no group contains a step with that name
   */
  private LinkedSteps findStep(List<LinkedSteps> decomposed, String withStepName) {
    return decomposed.stream()
        .filter(group -> group.getGrouped().stream().anyMatch(rs -> rs.getStepName().equals(withStepName)))
        .findFirst()
        .orElse(null);
  }

  /**
   * Asserts that the group containing {@code parentStep} and the group containing {@code childStep} are directly
   * related: the child's group is among the parent group's descendants, and the parent's group is among the child
   * group's predecessors.
   *
   * @param decomposed the groups to search for the two steps
   * @param parentStep name of a step within the expected upstream group
   * @param childStep name of a step within the expected downstream group
   */
  private void assertRelated(List<LinkedSteps> decomposed, String parentStep, String childStep) {
    LinkedSteps parent = findStep(decomposed, parentStep);
    LinkedSteps child = findStep(decomposed, childStep);

    // findStep returns null for an unknown name - report that clearly rather than as a NullPointerException below
    assertNotNull("no linked steps contain a step named '" + parentStep + "'", parent);
    assertNotNull("no linked steps contain a step named '" + childStep + "'", child);

    assertThat(parent.getDescendants(), hasItem(child));
    assertThat(child.getPredecessors(), hasItem(parent));
  }

  /**
   * Asserts that the group at index {@code i} and the group at index {@code j} are directly related: the latter is
   * among the former's descendants, and the former is among the latter's predecessors.
   *
   * @param decomposed the decomposed groups
   * @param i index of the expected upstream group
   * @param j index of the expected downstream group
   */
  private void assertRelated(List<LinkedSteps> decomposed, int i, int j) {
    final LinkedSteps upstream = decomposed.get(i);
    final LinkedSteps downstream = decomposed.get(j);

    assertThat(upstream.getDescendants(), hasItem(downstream));
    assertThat(downstream.getPredecessors(), hasItem(upstream));
  }

  /**
   * Builds a matcher that accepts a {@link RealizedStep} whose result is an instance of the given type.
   *
   * <p>A step whose result yields no value is treated as a mismatch.
   *
   * @param expected the type that the step's result must be an instance of
   * @return a matcher suitable for use with {@code contains(...)} over a list of realized steps
   */
  private Matcher<RealizedStep> producesResult(Class<? extends Realized> expected) {
    return new TypeSafeMatcher<RealizedStep>(RealizedStep.class) {

      @Override
      public void describeTo(Description description) {
        // appendText, not appendValue - appendValue would render the fixed prefix as a quoted string value
        description.appendText("step with result ").appendValue(expected);
      }

      @Override
      protected void describeMismatchSafely(RealizedStep item, Description mismatchDescription) {
        mismatchDescription.appendText("step with result ").appendValue(item.getResult().orElse(null));
      }

      @Override
      protected boolean matchesSafely(RealizedStep item) {
        // orElse(null) mirrors describeMismatchSafely: isInstance(null) is false, so a step with no result value
        // is reported as an ordinary mismatch (presumably get() would throw in that case instead)
        return expected.isInstance(item.getResult().orElse(null));
      }
    };
  }
}
