/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.engine.function;


import java.util.ArrayList;
import java.util.List;

import lombok.RequiredArgsConstructor;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.LambdaType;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.Lambda;
import nz.org.riskscape.rl.ast.StructDeclaration;

/**
 * An {@link AggregationFunction} that is composed of one or more child {@link AggregationFunction}s.
 *
 * Expected to be used for aggregating multiple attributes of a {@link Struct} but can also be used
 * to aggregate a simple value with multiple aggregations.
 *
 * In either case a {@link Struct} is produced.
 *
 * The function call takes exactly two arguments: a value expression and a struct declaration whose
 * members are all single-argument lambdas sharing the same argument name,
 * e.g. <code>{mean: v -> mean(v.attr)}</code>.
 */
public class CompositeAggregationFunction implements AggregationFunction {

  /**
   * Problems that are specific to {@link CompositeAggregationFunction}.
   */
  public interface LocalProblems extends ProblemFactory {

    /**
     * @return the {@link LocalProblems} factory obtained from {@link Problems#get(Class)}
     */
    static LocalProblems get() {
      return Problems.get(LocalProblems.class);
    }

    /**
     * Used when struct declaration uses lambdas with differing argument names. We enforce the same
     * name to be used for each because they do all refer to the same thing.
     *
     * @param seen newly used lambda arg
     * @param expected the lambda arg expected to be used by all lambdas
     */
    Problem lambdaArgsDiffer(String seen, String expected);
  }

  /**
   * {@link Accumulator} that evaluates the value expression against each input, wraps any non-null
   * result in a tuple of {@link #lambdaType} and hands that tuple on to the child accumulator.
   * Processing and emptiness are delegated to the child.
   */
  @RequiredArgsConstructor
  private static class AccumInstance implements Accumulator {

    /**
     * {@link RealizedExpression} to obtain the value (tuple) to be accumulated
     */
    private final RealizedExpression valueExpression;

    /**
     * The {@link Struct} type that the child accumulator expects. This will be a single attr
     * struct whose only attr will be of the type returned by value expression.
     */
    private final Struct lambdaType;

    /**
     * The accumulator that does the actual aggregation of the wrapped values.
     */
    private final Accumulator child;

    /**
     * Combines the child accumulators of this and other into a new instance.
     *
     * @param other accumulator to combine with, which must be an {@link AccumInstance}
     */
    @Override
    public Accumulator combine(Accumulator other) {
      AccumInstance rhs = (AccumInstance) other;

      return new AccumInstance(valueExpression, lambdaType, child.combine(rhs.child));
    }

    @Override
    public void accumulate(Object input) {
      Object toAccumulate = valueExpression.evaluate(input);
      if (toAccumulate == null) {
        // null values are not passed on to the child accumulator
        return;
      }
      child.accumulate(Tuple.ofValues(lambdaType, toAccumulate));
    }

    @Override
    public Object process() {
      return child.process();
    }

    @Override
    public boolean isEmpty() {
      return child.isEmpty();
    }

  }

  /**
   * Processes the childExpr which is expected to be a {@link StructDeclaration} containing {@link Lambda}s
   * of matching types. Returns a {@link Pair} containing a {@link StructDeclaration} that is modified
   * to have the lambda expressions replaced and a {@link Struct} type that the modified struct declaration
   * can be realized against.
   *
   * @param childExpr the expression given as the aggregation struct, e.g. {mean: v -> mean(v.attr)}
   * @param itemType the type of the value being aggregated, used as the type of the lambda argument
   * @return the struct declaration with each lambda replaced by its body, paired with a single member
   *         struct type mapping the shared lambda argument name to itemType
   * @throws ProblemException if childExpr is not a struct declaration, has no members, has a member that
   *         is not a lambda, has a lambda that does not accept exactly one argument, or has lambdas
   *         using differing argument names
   */
  private Pair<StructDeclaration, Struct> getChildInputType(Expression childExpr, Type itemType)
      throws ProblemException {

    StructDeclaration childStructDecl = childExpr.isA(StructDeclaration.class)
        .orElseThrow(() -> new ProblemException(ExpressionProblems.get().mismatch(childExpr, StructDeclaration.class,
            "{mean: v -> mean(v.attr)}")));

    // the argument name used by the first lambda, which all other lambdas are required to use as well
    String lambdaArg = null;
    List<StructDeclaration.Member> kludgedMembers = new ArrayList<>(childStructDecl.getMembers().size());
    for (StructDeclaration.Member member : childStructDecl.getMembers()) {
      Lambda lambda = member.getExpression().isA(Lambda.class)
          .orElseThrow(() -> new ProblemException(
              ExpressionProblems.get().mismatch(member.getExpression(), Lambda.class, "v -> mean(v.attr)")));

      // every lambda (not just the first one) must accept exactly one argument. Checking this up front
      // means the arity problem is reported accurately and it is always safe to get the first arg below
      LambdaType lambdaType = LambdaType.create(lambda);
      int arity = lambdaType.getArgs().size();
      if (arity != 1) {
        throw new ProblemException(ExpressionProblems.get().lambdaArityError(lambda, arity, 1));
      }

      String argName = lambdaType.getArgs().get(0);
      if (lambdaArg == null) {
        lambdaArg = argName;
      } else if (!argName.equals(lambdaArg)) {
        // this will occur if different lambda arguments are used
        throw new ProblemException(LocalProblems.get().lambdaArgsDiffer(argName, lambdaArg));
      }

      // And we replace the lambda expression with its child expression.
      kludgedMembers.add(member.cloneWithExpression(lambda.getExpression()));
    }

    if (lambdaArg == null) {
      // the struct declaration had no members, so there is nothing to aggregate and no lambda argument
      // to build the input type from. Report it as a problem rather than failing on a null later
      throw new ProblemException(ExpressionProblems.get().mismatch(childExpr, StructDeclaration.class,
          "{mean: v -> mean(v.attr)}"));
    }

    return Pair.of(childStructDecl.withNewMembers(kludgedMembers, childStructDecl.getBoundary()),
        Struct.of(lambdaArg, itemType));
  }

  /**
   * Realizes this function against a call of the form (value expression, struct declaration of lambdas).
   *
   * The value expression is realized against inputType. The struct declaration has its lambdas replaced
   * by their bodies and is then realized as an aggregate expression against a struct containing
   * only the lambda argument.
   *
   * @param context used to obtain the expression realizer
   * @param inputType the type that the value expression is realized against
   * @param fc the function call being realized, which must have exactly two arguments
   * @return the realized aggregate expression, or the problems that prevented realization
   */
  @Override
  public ResultOrProblems<RealizedAggregateExpression> realize(RealizationContext context, Type inputType,
      FunctionCall fc) {

    return ProblemException.<RealizedAggregateExpression>catching(() -> {

      if (fc.getArguments().size() != 2) {
        // need value expression and a struct declaration
        throw new ProblemException(ArgsProblems.get().wrongNumber(2, fc.getArguments().size()));
      }

      // get and validate type to aggregate. this is always expected to be in arg: 0
      RealizedExpression itemExpression = context.getExpressionRealizer()
          .realize(inputType, fc.getArguments().get(0).getExpression())
          .getOrThrow(Problems.foundWith(fc.getArguments().get(0)));

      // arg: 1 is the struct of lambdas, which is converted to a plain struct of aggregate expressions
      Expression childExpr = fc.getArguments().get(1).getExpression();
      Pair<StructDeclaration, Struct> child = getChildInputType(childExpr, itemExpression.getResultType());
      StructDeclaration aggregationStruct = child.getLeft();
      Struct childInputType = child.getRight();

      RealizedAggregateExpression childAgg = context.getExpressionRealizer()
          .realizeAggregate(childInputType, aggregationStruct)
          .getOrThrow(Problems.foundWith(fc.getArguments().get(1)));

      return new RealizedAggregateExpression() {

        @Override
        public Accumulator newAccumulator() {
          return new AccumInstance(itemExpression, childInputType, childAgg.newAccumulator());
        }

        @Override
        public Type getResultType() {
          return childAgg.getResultType();
        }

        @Override
        public Type getInputType() {
          return inputType;
        }

        @Override
        public Expression getExpression() {
          return fc;
        }
      };
    }).composeProblems(Problems.foundWith(fc));
  }

}
