/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import com.google.common.collect.Lists;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import nz.org.riskscape.dsl.LexerException;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Struct.StructBuilder;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.typeset.MissingTypeException;
import nz.org.riskscape.engine.typexp.TypeBuilder;
import nz.org.riskscape.engine.typexp.TypeBuildingException;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;

/**
 * AST for classifier functions.
 *
 * <p>Every node is anchored to a {@link Token} ({@link #getBoundaryToken()}), which also serves as
 * the node's identifier unless a subclass provides its own. Nodes that hold type or expression
 * source text are built on demand: their {@code build} methods parse the stored source and append
 * any failures to a caller-supplied problem list instead of throwing.
 *
 * @author nickg
 *
 */
public abstract class AST {

  /**
   * @return the token this node is anchored to, e.g. the token that starts a function declaration
   */
  public abstract Token getBoundaryToken();

  /**
   * @return the token that identifies this node - defaults to {@link #getBoundaryToken()}
   */
  public Token getIdentifier() {
    return getBoundaryToken();
  }

  /**
   * The root of a parsed classifier function: its metadata, argument and return type declarations,
   * and its pre, body (filters), default and post expressions.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class FunctionDecl extends AST {
    public final Token start;

    public final Optional<Metadata> id;
    public final Optional<Metadata> description;
    public final Optional<Metadata> category;

    @Getter
    public final StructType argumentTypesDecl;
    public final Optional<TypeDecl> returnTypeDecl;

    public final Optional<ExpressionDecl> pre;

    public final List<Filter> body;
    public final Optional<ExpressionDecl> defaultExpr;

    public final Optional<ExpressionDecl> post;

    /**
     * Builds the declared return type (if there is one) and then the argument types.
     *
     * @param builder used to turn the declared type expressions into {@link Type}s
     * @return the problems encountered while building; empty when nothing went wrong
     */
    public List<Problem> parseTypes(TypeBuilder builder) {
      List<Problem> problems = new ArrayList<>();

      returnTypeDecl.ifPresent(rtd -> rtd.build(problems, builder));
      argumentTypesDecl.build(problems, builder);

      return problems;
    }

    /**
     * Parses the pre, post and default expressions (where present), then every body filter in
     * declaration order.
     *
     * @param parser used to parse each expression's source text
     * @return the problems encountered while parsing; empty when nothing went wrong
     */
    public List<Problem> parseExpressions(ExpressionParser parser) {
      List<Problem> problems = new ArrayList<>();

      pre.ifPresent(expr -> expr.build(problems, parser));
      post.ifPresent(expr -> expr.build(problems, parser));
      defaultExpr.ifPresent(expr -> expr.build(problems, parser));
      // Iterable.forEach guarantees list order, keeping problem ordering deterministic
      body.forEach(filter -> filter.build(problems, parser));

      return problems;
    }

    /**
     * Collects the keys reported by this function's expressions (see
     * {@link ExpressionDecl#scanStructKeys()}), without duplicates and in first-seen order: pre,
     * post, default, then the body filters.
     *
     * @return an ordered set of keys
     */
    public Set<String> scanStructKeys() {
      Set<String> keys = new LinkedHashSet<>();

      pre.ifPresent(p -> keys.addAll(p.scanStructKeys()));
      post.ifPresent(p -> keys.addAll(p.scanStructKeys()));
      defaultExpr.ifPresent(p -> keys.addAll(p.scanStructKeys()));
      body.forEach(b -> keys.addAll(b.scanStructKeys()));

      return keys;
    }

    /**
     * @return the struct built from {@link #argumentTypesDecl}, or null if {@link #parseTypes}
     *         has not been called yet or the argument struct failed to build
     */
    public Struct getInputType() {
      return argumentTypesDecl.built;
    }

    @Override
    public Token getBoundaryToken() {
      return start;
    }
  }

  /**
   * A metadata entry - an identifier token plus a value token - as used for a function's id,
   * description and category.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class Metadata extends AST {
    public final Token identifier;
    public final Token value;

    /**
     * @return the value of the {@link #value} token
     */
    public String value() {
      return value.value;
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }

  }

  /**
   * A declared type that can be built in to a {@link Type}.
   */
  public interface TypeDecl {
    /**
     * @return the token naming this declaration
     */
    Token getIdentifier();

    /**
     * @return the built type - for the implementations in this class this is null until
     *         {@link #build} has succeeded
     */
    Type getBuilt();

    /**
     * Builds this declaration, appending any problems to {@code problems} rather than throwing.
     *
     * @param problems collects any problems found while building
     * @param builder used to turn type expressions into {@link Type}s
     */
    void build(List<Problem> problems, TypeBuilder builder);
  }

  /**
   * A type declared by a single type-expression token.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class SimpleType extends AST implements TypeDecl {

    @Getter
    public final Token identifier;

    @Getter
    public final Token expression;

    @Getter
    public Type built;

    /**
     * Builds {@link #built} from the type expression. If the type can't be built, a problem is added
     * against this declaration's identifier and {@link #built} is left unchanged.
     */
    @Override
    public void build(List<Problem> problems, TypeBuilder builder) {
      try {
        built = builder.build(expression.value);
      } catch (TypeBuildingException | MissingTypeException ex) {
        problems.add(Problems.foundWith(getIdentifier(),
            ExpressionProblems.cannotParse(Type.class, expression.value)
                .withChildren(Problems.caught(ex))));
      }
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }
  }

  /**
   * A struct type whose members are other {@link TypeDecl}s.
   */
  // TODO the AST allows nesting, but this adds quite a bit of complexity to the type inference, so maybe dont?
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class StructType extends AST implements TypeDecl {

    @Getter
    public final Token identifier;
    public final List<TypeDecl> children;

    @Getter
    public Struct built;

    /**
     * @param memberName the member to look for
     * @return the first child declaration whose identifier value equals {@code memberName}, or
     *         empty if there is none
     */
    public Optional<TypeDecl> find(String memberName) {
      for (TypeDecl typeDecl : children) {
        if (typeDecl.getIdentifier().value.equals(memberName)) {
          return Optional.of(typeDecl);
        }
      }

      return Optional.empty();
    }

    /**
     * Builds every child declaration, then assembles the successfully built children in to a
     * {@link Struct}. Children that failed to build are left out; their problems are recorded by
     * the child itself. If the struct can't be assembled, a composite problem is added and
     * {@link #built} is left unchanged.
     */
    @Override
    public void build(List<Problem> problems, TypeBuilder builder) {
      StructBuilder structBuilder = new StructBuilder(children.size());

      for (TypeDecl typeDecl : children) {
        typeDecl.build(problems, builder);

        if (typeDecl.getBuilt() != null) {
          structBuilder.add(typeDecl.getIdentifier().value, typeDecl.getBuilt());
        }
      }

      ResultOrProblems<Struct> builtOr = structBuilder.buildOr();

      if (builtOr.hasProblems()) {
        problems.add(Problem.composite(builtOr.getProblems(), "Failed to build type for %s", identifier));
      } else {
        this.built = builtOr.get();
      }
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }
  }

  /**
   * A filter in a function body: a filter expression, an optional {@link #orElse} expression and
   * any nested child filters.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class Filter extends AST {
    @Getter
    public final Token identifier;
    public final Token filterExpression;
    public final List<Filter> children;
    public final Optional<ExpressionDecl> orElse;
    public Expression built;

    /**
     * Parses the filter expression in to {@link #built}, then the {@link #orElse} expression (if
     * present), then each child filter recursively. Problems are appended to {@code problems}.
     */
    public void build(List<Problem> problems, ExpressionParser parser) {

      SimpleExpression simple = new SimpleExpression(identifier, filterExpression);
      simple.build(problems, parser);
      this.built = simple.built;

      orElse.ifPresent(expr -> expr.build(problems, parser));

      for (Filter filter : children) {
        filter.build(problems, parser);
      }
    }

    /**
     * Collects the keys from the {@link #orElse} expression and from every child filter, in
     * first-seen order. The filter expression itself contributes no keys.
     *
     * @return an ordered set of keys
     */
    public Set<String> scanStructKeys() {
      Set<String> collected = new LinkedHashSet<String>();
      orElse.ifPresent(o -> collected.addAll(o.scanStructKeys()));

      for (Filter filter : children) {
        collected.addAll(filter.scanStructKeys());
      }
      return collected;
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }
  }

  /**
   * A named expression (or group of expressions) that can be parsed in to {@link Expression}s.
   */
  public interface ExpressionDecl {
    /**
     * @return the token naming this declaration
     */
    Token getIdentifier();

    /**
     * @return the identifier values this declaration (and any members) are declared under
     */
    Set<String> scanStructKeys();

    /**
     * Parses this declaration, appending any problems to {@code problems} rather than throwing.
     *
     * @param problems collects any problems found while parsing
     * @param parser used to parse the expression source text
     */
    void build(List<Problem> problems, ExpressionParser parser);
  }

  /**
   * A single named expression whose source text is made up of one or more tokens.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class SimpleExpression extends AST implements ExpressionDecl {

    /**
     * Testing constructor
     */
    public static SimpleExpression create(String ident, Expression expr) {
      SimpleExpression instance = new SimpleExpression(
          Token.token(TokenTypes.IDENTIFIER, ident),
          Lists.newArrayList(Token.token(TokenTypes.EXPRESSION, expr.toSource()))
      );
      instance.built = expr;
      return instance;
    }

    @Getter
    public final Token identifier;

    /**
     * List of {@link Token}s that make up the expression.
     */
    @Getter
    public final List<Token> expressionParts;

    /**
     * The raw string of the expression. This is formed by concatenating the raw value from each
     * {@link Token} from expressionParts.
     */
    @Getter
    public final String expression;

    public Expression built;

    /**
     * Creates an expression whose source is a single token.
     */
    public SimpleExpression(Token identifier, Token exprToken) {
      this(identifier, Lists.newArrayList(exprToken));
    }

    /**
     * Creates an expression whose source is the concatenated raw values of {@code expressionParts}.
     */
    public SimpleExpression(Token identifier, List<Token> expressionParts) {
      this.identifier = identifier;
      this.expressionParts = expressionParts;
      this.expression = expressionParts.stream()
            .map(Token::rawValue)
            .collect(Collectors.joining());
    }

    /**
     * Parses {@link #expression} in to {@link #built}. On a lexing or parse failure a problem is
     * added against this expression's identifier and {@link #built} is left unchanged.
     */
    @Override
    public void build(List<Problem> problems, ExpressionParser parser) {
      try {
        built = parser.parse(expression);
      } catch (LexerException | nz.org.riskscape.dsl.ParseException ex) {
        // include both the token identifier and the expression itself for context
        Problem exprProblem = Problems.foundWith(Expression.class, expression, Problems.caught(ex));
        problems.add(Problems.foundWith(getIdentifier(), exprProblem));
      }
    }

    /**
     * @return a single-element set holding this expression's identifier value
     */
    @Override
    public Set<String> scanStructKeys() {
      return Collections.singleton(identifier.value);
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }
  }

  /**
   * A named group of member {@link ExpressionDecl}s.
   */
  @EqualsAndHashCode(callSuper = false) @ToString @RequiredArgsConstructor
  public static class StructExpression extends AST implements ExpressionDecl {

    /**
     * Testing constructor
     */
    public static StructExpression create(String ident, SimpleExpression... children) {
      return new StructExpression(Token.token(TokenTypes.IDENTIFIER, ident), Arrays.asList(children));
    }

    @Getter
    public final Token identifier;
    public final List<ExpressionDecl> members;

    /**
     * @param memberName the member to look for
     * @return the first member whose identifier value equals {@code memberName}, or empty if there
     *         is none
     */
    public Optional<ExpressionDecl> find(String memberName) {
      for (ExpressionDecl member : members) {
        if (member.getIdentifier().value.equals(memberName)) {
          return Optional.of(member);
        }
      }

      return Optional.empty();
    }

    /**
     * Collects the keys of every member, in member declaration order.
     *
     * @return an ordered set of keys
     */
    @Override
    public Set<String> scanStructKeys() {
      // LinkedHashSet (not HashSet) so member order survives, matching the other scanStructKeys
      // implementations - otherwise key order would depend on string hash codes
      Set<String> childKeys = new LinkedHashSet<>();

      for (ExpressionDecl expressionDecl : members) {
        childKeys.addAll(expressionDecl.scanStructKeys());
      }

      return childKeys;
    }

    /**
     * Parses every member, appending any problems to {@code problems}.
     */
    @Override
    public void build(List<Problem> problems, ExpressionParser parser) {
      for (ExpressionDecl expression : members) {
        expression.build(problems, parser);
      }
    }

    @Override
    public Token getBoundaryToken() {
      return identifier;
    }
  }
}
