/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.steps;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import java.util.Optional;

import com.google.common.collect.Lists;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.Engine;
import nz.org.riskscape.engine.Project;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.ParameterField;
import nz.org.riskscape.engine.pipeline.Collector;
import nz.org.riskscape.engine.pipeline.RealizationInput;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.relation.TupleIterator;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.sort.SortBy;
import nz.org.riskscape.engine.sort.SortBy.Direction;
import nz.org.riskscape.engine.sort.TupleComparator;
import nz.org.riskscape.engine.types.LambdaType;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.typexp.BadTypeExpressionException;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.Lambda;
import nz.org.riskscape.rl.ast.ListDeclaration;
import nz.org.riskscape.util.ListUtils;

/**
 * Sorts output based on user configured {@link SortBy} expressions.
 * <p>
 * Optionally, a two-argument {@code delta} lambda can be given. It is evaluated against each
 * (previous, current) pair of sorted tuples, and its result is appended to each output tuple as the
 * {@code deltaAttribute} member. When a {@code deltaType} is also given, the previous tuple passed to the
 * lambda carries its own delta value, so deltas can cascade from one tuple to the next.
 */
public class SortStep extends BaseStep<SortStep.Parameters> {

  interface LocalProblems extends ProblemFactory {
    static LocalProblems get() {
      return Problems.get(LocalProblems.class);
    }

    /**
     * The configured delta attribute name clashes with a member that already exists in the input type.
     */
    Problem deltaAttributeAlreadyExists(String memberName, Struct existing);

    /**
     * The type the delta lambda actually returns is not assignable to the user-specified delta type.
     */
    Problem deltaTypeDifferentToActual(Type specifiedType, Type actualReturnType);
  }

  public static class Parameters {

    /**
     * The step whose output is being sorted.
     */
    @Input
    RealizedStep input;

    /**
     * The expression(s) to sort by. This is turned into a list of sort expressions via
     * {@link ExpressionParser#toList}.
     */
    @ParameterField
    Expression by;

    /**
     * The sort direction for each {@link #by} expression, matched by position. Expressions that have no
     * corresponding direction are sorted {@link Direction#ASC}.
     */
    @ParameterField
    List<SortBy.Direction> direction;

    /**
     * Optional two-argument lambda, {@code (previous, current)}, whose result is appended to each sorted
     * tuple. {@code previous} is null for the first tuple.
     */
    @ParameterField
    Optional<Lambda> delta;

    /**
     * The name of the attribute that the delta result is stored in.
     */
    @ParameterField
    Optional<String> deltaAttribute = Optional.of("delta");

    /**
     * Optional type expression for the delta result. When given, the previous tuple passed to the delta
     * lambda includes its own delta value, and the lambda's actual return type must be assignable to it.
     */
    @ParameterField
    Optional<String> deltaType;


    //Injected field
    public RealizationInput rInput;
  }

  /**
   * Collects every tuple in memory, then sorts them and (optionally) appends a delta value to each one.
   */
  @RequiredArgsConstructor
  static class InMemorySortCollector implements Collector<List<Tuple>> {

    // NB field order matters: it defines the @RequiredArgsConstructor argument order

    @Getter
    private final Struct sourceType;

    @Getter
    private final Struct producedType;

    private final Comparator<Tuple> comparator;

    // Non-private for testing
    final Optional<RealizedExpression> deltaExpression;

    /**
     * If set to true, when calculating the delta expression will store the result
     * and add it to `previous` so that it can be used in the next delta calculation.
     */
    private final boolean cascade;


    @Override
    public List<Tuple> newAccumulator() {
      return Lists.newArrayList();
    }

    @Override
    public void accumulate(List<Tuple> accumulator, Tuple tuple) {
      accumulator.add(tuple);
    }

    @Override
    public List<Tuple> combine(List<Tuple> lhs, List<Tuple> rhs) {
      return ListUtils.concat(lhs, rhs);
    }

    @Override
    public Optional<Long> size(List<Tuple> accumulator) {
      return Optional.of((long) accumulator.size());
    }

    /**
     * Sorts the collected tuples (in place) and, if there is a delta expression, appends the delta to each.
     * <p>
     * The delta expression is given the previous tuple (null for the first tuple) and the current tuple.
     * When {@link #cascade} is set, the previous tuple includes its own delta value; otherwise it is the
     * unmodified source tuple.
     */
    @Override
    public TupleIterator process(List<Tuple> accumulator) {
      // we always start by sorting the collected tuples
      accumulator.sort(comparator);

      if (!deltaExpression.isPresent()) {
        return getTupleIterator(accumulator);
      }

      // a delta expression is present, so we need to calculate the delta between the tuples and add it
      // in to the tuples
      List<Tuple> accumulatorWithDeltas = Lists.newArrayListWithCapacity(accumulator.size());

      RealizedExpression deltaExpr = deltaExpression.get();
      Struct deltaInputType = deltaExpr.getInputType().find(Struct.class).get();
      // the delta attribute is always the last member of the produced type (see getTargetType)
      int deltaIndex = producedType.size() - 1;

      Tuple previous = null;
      for (Tuple current : accumulator) {
        // we need to make a Tuple with prev/current tuples to evaluate
        Tuple deltaInput = Tuple.ofValues(deltaInputType, previous, current);
        Object delta = deltaExpr.evaluate(deltaInput);

        Tuple currentWithDelta = new Tuple(producedType);
        currentWithDelta.setAll(current);
        currentWithDelta.set(deltaIndex, delta);

        accumulatorWithDeltas.add(currentWithDelta);

        // If we're not cascading, can't pass in currentWithDelta because it's
        // the wrong type.
        previous = cascade ? currentWithDelta : current;
      }
      return getTupleIterator(accumulatorWithDeltas);
    }

    /**
     * Returns a {@link TupleIterator} that yields the given tuples in their original order. The iterator is
     * backed by a private copy of the list, so {@link TupleIterator#remove()} never touches the given list.
     */
    private TupleIterator getTupleIterator(List<Tuple> list) {
      // we reverse the list and then iterate through the items backwards.
      // This lets us free up the tuples efficiently as we go (whereas ArrayList.remove()
      // would be notoriously slow if we iterated through the list forwards)
      List<Tuple> copy = new ArrayList<>(list);
      Collections.reverse(copy);
      ListIterator<Tuple> iterator = copy.listIterator(copy.size());
      return new TupleIterator() {
        @Override
        public Tuple next() {
          return iterator.previous();
        }

        @Override
        public void remove() {
          iterator.remove();
        }

        @Override
        public boolean hasNext() {
          return iterator.hasPrevious();
        }
      };
    }

    @SuppressWarnings({ "rawtypes", "unchecked" })
    @Override
    public Class<List<Tuple>> getAccumulatorClass() {
      // defeat java's type safety
      Class list = List.class;
      return list;
    }

  }

  public SortStep(Engine engine) {
    super(engine);
  }

  /**
   * Normalizes any struct within the given type via the realization context, preserving nullability.
   * Non-struct types are returned unchanged (rewrapped as nullable if the given type was nullable).
   */
  private Type normalize(RealizationContext context, Type given) {
    // strip off any nullability, normalize the struct, then rewrap it
    return Nullable.rewrap(given,
              type -> type.findAllowNull(Struct.class)
                .map(s -> (Type) context.normalizeStruct(s))
                .orElse(type));
  }


  @Override
  public ResultOrProblems<? extends Realized> realize(Parameters parameters) {
    return ProblemException.catching(() -> {
      Struct sourceType = parameters.input.getProduces();

      // First we need to concatenate all the sortbys into one.
      SortBy sortBy = buildSortBy(parameters);

      // Check for a delta attribute name clash up front. Otherwise we could try to add a duplicate member
      // to the source type (when building the lambda's `previous` type below) before the clash is reported.
      if (parameters.delta.isPresent()) {
        checkDeltaAttributeIsNew(sourceType, parameters.deltaAttribute.get());
      }

      Project project = parameters.rInput.getExecutionContext().getProject();

      // the user-specified type of the delta attribute, if any. When given, deltas cascade
      Type deltaAttributeType = null;
      // the type of the delta lambda's `previous` argument (includes the delta attribute when cascading)
      Struct previousType = sourceType;
      if (parameters.deltaType.isPresent()) {
        RealizationContext realizationContext = parameters.rInput.getRealizationContext();
        try {
          deltaAttributeType = normalize(realizationContext,
              project.getTypeBuilder().build(parameters.deltaType.get()));
        } catch (BadTypeExpressionException e) {
          throw new ProblemException(e.getProblem());
        }

        previousType = sourceType.add(parameters.deltaAttribute.get(), deltaAttributeType);
      }

      // NB assigned exactly once per branch so it stays effectively final for the lambda below
      Optional<RealizedExpression> deltaExpression;
      if (parameters.delta.isPresent()) {
        RealizedExpression realizedDelta = realizeDeltaExpression(parameters.delta.get(), previousType,
            sourceType, parameters.rInput.getRealizationContext());

        // If deltaType is provided, make sure that the *actual* return type of delta lambda is compatible
        // with it. If we don't check it here, users can hit some nasty error messages when we actually
        // process the realized expression.
        if (deltaAttributeType != null) {
          Type actualReturnType = realizedDelta.getResultType();
          if (!project.getTypeSet().isAssignable(actualReturnType, deltaAttributeType)) {
            throw new ProblemException(
              LocalProblems.get().deltaTypeDifferentToActual(deltaAttributeType, actualReturnType)
            );
          }
        }
        deltaExpression = Optional.of(realizedDelta);
      } else {
        deltaExpression = Optional.empty();
      }

      Struct targetType = getTargetType(sourceType, deltaExpression, parameters.deltaAttribute);

      boolean cascade = parameters.deltaType.isPresent();
      return TupleComparator.createComparator(sourceType, sortBy,
          parameters.rInput.getExecutionContext().getExpressionRealizer())
          .map(c -> new InMemorySortCollector(sourceType, targetType, c, deltaExpression, cascade))
          .getOrThrow();
    });
  }

  /**
   * Builds a single {@link SortBy} from the {@code by} expressions, pairing each one with the direction at
   * the same position (defaulting to {@link Direction#ASC} when no direction was given for it).
   */
  private SortBy buildSortBy(Parameters parameters) throws ProblemException {
    ListDeclaration sortByAsList = ExpressionParser.INSTANCE.toList(parameters.by);
    int count = sortByAsList.getElements().size();
    List<SortBy> sortBys = new ArrayList<>(count);

    for (int i = 0; i < count; i++) {
      Expression expr = sortByAsList.getElements().get(i);
      Direction dir = parameters.direction.size() > i ? parameters.direction.get(i) : Direction.ASC;
      sortBys.add(new SortBy(expr, dir));
    }

    return SortBy.concatenate(sortBys);
  }

  /**
   * Checks that an attribute with the given name can be appended to the source type.
   *
   * @throws ProblemException if the source type already has a member with that name
   */
  private void checkDeltaAttributeIsNew(Struct sourceType, String attributeName) throws ProblemException {
    if (sourceType.getEntry(attributeName) != null) {
      throw new ProblemException(
          LocalProblems.get().deltaAttributeAlreadyExists(attributeName, sourceType));
    }
  }

  /**
   * Returns the type this step produces: the source type, with the delta attribute appended as its last
   * member when there is a delta expression.
   *
   * @throws ProblemException if the delta attribute clashes with an existing source member
   */
  private Struct getTargetType(Struct sourceType, Optional<RealizedExpression> deltaExpression,
      Optional<String> deltaAttribute) throws ProblemException {
    if (!deltaExpression.isPresent()) {
      return sourceType;
    }
    String attributeName = deltaAttribute.get();
    checkDeltaAttributeIsNew(sourceType, attributeName);
    return sourceType.add(attributeName, deltaExpression.get().getResultType());
  }

  /**
   * Realizes the body of the two-argument delta lambda against a struct of its arguments: the first
   * (previous) is a nullable {@code prevType}, the second (current) is {@code currType}.
   *
   * @throws ProblemException if the lambda does not take exactly two arguments, or fails to realize
   */
  private RealizedExpression realizeDeltaExpression(Lambda lambda, Struct prevType, Struct currType,
      RealizationContext context) throws ProblemException {
    LambdaType type = LambdaType.create(lambda);
    if (type.getArgs().size() != 2) {
      // error, two lambda args expected
      throw new ProblemException(ExpressionProblems.get().lambdaArityError(
          lambda,
          type.getArgs().size(),
          2
      ));
    }
    Struct lambdaInput = Struct.of(
        type.getArgs().get(0), Nullable.of(prevType),  // first arg (previous) is null because it will be
                                                        // null on the first call as there is no previous tuple.
        type.getArgs().get(1), currType);

    return context.getExpressionRealizer().realize(lambdaInput, lambda.getExpression()).getOrThrow();
  }

}
