/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.rl;

import static nz.org.riskscape.dsl.ConditionalParse.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import org.geotools.api.filter.Filter;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import nz.org.riskscape.dsl.ConditionalParse;
import nz.org.riskscape.dsl.Lexer;
import nz.org.riskscape.dsl.LexerException;
import nz.org.riskscape.dsl.ParseException;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.dsl.UnexpectedTokenException;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.BinaryOperation;
import nz.org.riskscape.rl.ast.BracketedExpression;
import nz.org.riskscape.rl.ast.Constant;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.Lambda;
import nz.org.riskscape.rl.ast.ListDeclaration;
import nz.org.riskscape.rl.ast.MinimalVisitor;
import nz.org.riskscape.rl.ast.ParameterToken;
import nz.org.riskscape.rl.ast.PropertyAccess;
import nz.org.riskscape.rl.ast.SelectAllExpression;
import nz.org.riskscape.rl.ast.StructDeclaration;
import nz.org.riskscape.rl.ast.StructDeclaration.Member;

/**
 * Builds a Riskscape Language Expression AST from a string.  At the moment, the language is a mostly equivalent
 * version of ECQL that simplifies the language to make it more flexible for use throughout the riskscape engine.
 * Current plans are to extend its use to all uses of {@link org.geotools.api.filter.expression.Expression} and
 * {@link Filter} and possibly also taking in type expressions and pipeline expressions.
 *
 */
public class ExpressionParser {

  /**
   * Shared parser instance; {@link #parseString(String)} delegates to it.
   */
  public static final ExpressionParser INSTANCE = new ExpressionParser();

  /**
   * Parses {@code source} using the shared {@link #INSTANCE}, i.e. shorthand for
   * {@code ExpressionParser.INSTANCE.parse(source)}.
   *
   * @param source the expression text to parse
   * @return the parsed expression, exactly as returned by {@link #parse(String)}
   */
  public static Expression parseString(String source) {
    Expression parsed = INSTANCE.parse(source);
    return parsed;
  }

  /**
   * Token types accepted where a plain or quoted identifier is expected.  Intended for use with
   * {@link ConditionalParse} lookahead rules.
   */
  public static final EnumSet<TokenTypes> IDENTIFIERS =
      EnumSet.of(TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER);

  /**
   * Everything in {@link #IDENTIFIERS} plus {@link TokenTypes#KEY_IDENTIFIER}.  Within this class it is used
   * for struct member keys and for function keyword-argument names.  Intended for use with
   * {@link ConditionalParse} lookahead rules.
   */
  public static final EnumSet<TokenTypes> KEY_IDENTIFIERS =
      EnumSet.of(TokenTypes.KEY_IDENTIFIER, TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER);

  /**
   * The token types that may back a {@link Constant}: strings, the numeric forms, booleans and null.
   */
  public static final EnumSet<TokenTypes> LITERALS = EnumSet.of(
      TokenTypes.STRING,
      TokenTypes.INTEGER, TokenTypes.DECIMAL, TokenTypes.SCIENTIFIC_NOTATION,
      TokenTypes.KEYWORD_FALSE, TokenTypes.KEYWORD_TRUE,
      TokenTypes.KEYWORD_NULL
  );

  /**
   * The token types that typically begin an expression.
   *
   * NB this set is put together by hand rather than being computed from the parse rules, which is not ideal,
   * but it is only handed to exceptions to describe what was expected (debugging/error messages).
   *
   * NOTE(review): {@link #parseExpression(Lexer)} also accepts expressions starting with
   * {@link TokenTypes#LPAREN} and {@link TokenTypes#PARAMETER_IDENTIFIER}, neither of which is listed here -
   * confirm whether that omission is deliberate.
   */
  public static final EnumSet<TokenTypes> LEADING_TOKENS = EnumSet.copyOf(
      ImmutableSet.<TokenTypes>builder()
          .addAll(IDENTIFIERS)
          .addAll(LITERALS)
          .add(TokenTypes.LBRACE, TokenTypes.LBRACK, TokenTypes.MULTIPLY)
          .build()
  );


  /**
   * Token types treated as infix binary operators.  {@link #parseExpression(Lexer)} checks the token that
   * follows an operand against this set to decide whether a {@link BinaryOperation} should be built.
   */
  Set<TokenTypes> binaryOperators = Sets.newHashSet(
      // arithmetic
      TokenTypes.PLUS, TokenTypes.MINUS, TokenTypes.MULTIPLY, TokenTypes.DIVIDE, TokenTypes.POW,
      // logical, in both keyword and symbol form
      TokenTypes.KEYWORD_OR, TokenTypes.OR, TokenTypes.KEYWORD_AND, TokenTypes.AND,
      // comparison
      TokenTypes.LESS_THAN, TokenTypes.GREATER_THAN, TokenTypes.LESS_THAN_EQUAL, TokenTypes.GREATER_THAN_EQUAL,
      TokenTypes.EQUALS, TokenTypes.NOT_EQUALS
  );

  /**
   * Turns the given string in to an {@link Expression}, rejecting any {@link ParameterToken}s it contains.
   * Use {@link #parseAllowParameters(String)} when parameters should be permitted.
   *
   * @param source the expression text to parse
   * @return an Expression built from the given string
   * @throws ParseException if source couldn't be parsed or if the expression contains {@link ParameterToken}s.
   * @throws LexerException if source couldn't be lexed in to tokens
   */
  public Expression parse(String source) {
    final boolean allowParameters = false;
    return parse(source, allowParameters);
  }

  /**
   * Variant of {@link #parse(String)} that never throws {@link ParseException} or {@link LexerException};
   * instead the caught exception is reported as a problem inside the returned {@link ResultOrProblems}.
   *
   * @param source the expression text to parse
   * @return the parsed expression, or a failed result wrapping the caught exception
   */
  public ResultOrProblems<Expression> parseOr(String source) {
    Expression parsed;
    try {
      parsed = parse(source);
    } catch (ParseException | LexerException ex) {
      return ResultOrProblems.failed(Problems.caught(ex));
    }
    return ResultOrProblems.of(parsed);
  }

  /**
   * Same as {@link #parse(String)}, except that the expression is allowed to contain {@link ParameterToken}s.
   *
   * @param source the expression text to parse
   * @return the parsed Expression, possibly containing parameter tokens
   */
  public Expression parseAllowParameters(String source) {
    final boolean allowParameters = true;
    return parse(source, allowParameters);
  }

  /**
   * Shared implementation behind {@link #parse(String)} and {@link #parseAllowParameters(String)}.
   *
   * @param source the expression text to parse
   * @param allowTokens when false, the parsed expression is passed to {@link #checkForParameters(Expression)}
   * @return the parsed expression
   */
  private Expression parse(String source, boolean allowTokens) {
    Lexer<TokenTypes> lexer = lex(source);

    // nothing but EOF means there is no expression at all - this gets its own problem
    boolean nothingToParse = lexer.peekType() == TokenTypes.EOF;
    if (nothingToParse) {
      Token eof = lexer.expect(TokenTypes.EOF);
      throw new UnexpectedTokenException(ExpressionProblems.get().emptyStringNotValid(), LEADING_TOKENS, eof);
    }

    Expression parsed;
    try {
      parsed = parseExpression(lexer);
      // the whole of source must have been consumed by the expression
      lexer.expect(TokenTypes.EOF);
    } catch (UnexpectedTokenException ex) {
      // re-throw as something slightly more meaningful to the user
      throw new MalformedExpressionException(source, ex);
    }

    if (allowTokens) {
      return parsed;
    }

    checkForParameters(parsed);
    return parsed;
  }

  /**
   * Visits the given expression and fails on the first {@link ParameterToken} (a $foo style parameter) that
   * the visitor is given.
   *
   * @param parsed the expression to inspect
   * @throws ParseException if the given expression contains $foo style parameters
   */
  public void checkForParameters(Expression parsed) {
    MinimalVisitor<Object> rejectParameters = new MinimalVisitor<Object>() {

      @Override
      public Object visit(ParameterToken parameterToken, Object data) {
        Token offending = parameterToken.getToken();
        throw new UnexpectedTokenException(
            ExpressionProblems.get().parametersNotAllowed(parameterToken), LEADING_TOKENS, offending);
      }

    };

    parsed.accept(rejectParameters, null);
  }

  /**
   * Coerces the given expression to a {@link StructDeclaration}.  An expression that already is a struct
   * declaration is handed back untouched; anything else becomes the sole anonymous member of a new struct
   * declaration that has no boundary tokens.
   */
  public StructDeclaration toStruct(Expression expr) {
    return expr.isA(StructDeclaration.class).orElseGet(() -> {
      List<Member> soleMember = Arrays.asList(StructDeclaration.anonMember(expr));
      return new StructDeclaration(soleMember, Optional.empty());
    });
  }

  /**
   * Coerces the given expression to a {@link ListDeclaration}.  An expression that already is a
   * ListDeclaration is handed back untouched; anything else becomes the single item of a new ListDeclaration
   * that has no boundary tokens.
   */
  public ListDeclaration toList(Expression expr) {
    return expr.isA(ListDeclaration.class).orElseGet(() -> {
      List<Expression> soleItem = Arrays.asList(expr);
      return new ListDeclaration(soleItem, Optional.empty());
    });
  }

  /**
   * Creates a lexer over source that uses the token definitions from {@link TokenTypes#tokens()}.
   *
   * @param source the text to lex
   * @return a new lexer over source
   */
  protected Lexer<TokenTypes> lex(String source) {
    Lexer<TokenTypes> lexer = new Lexer<>(TokenTypes.tokens(), source);
    return lexer;
  }

  /**
   * Parses a single expression from the lexer's current position.
   *
   * A leading operand is parsed first by offering the lexer a list of conditional parse rules (lambdas,
   * structs, lists, constants, bracketed expressions, function calls, property access, select-all and
   * parameter tokens).  The rules are listed so that more specific lookaheads come before more general ones
   * that start with the same token (e.g. the lambda forms before the bracketed expression) - this presumably
   * relies on {@code Lexer#tryThese} trying the rules in the order given; confirm against the Lexer.
   *
   * Once the operand has been parsed, the following trailing forms are handled:
   * <ul>
   * <li>{@code <operand>.*} - the result of the operand is splatted, e.g. {@code foo().*}</li>
   * <li>{@code <operand>.a.b} - property access using the operand as the receiver</li>
   * <li>{@code <operand> <binary operator> <expression>} - a {@link BinaryOperation}.  The left hand side may
   * itself be a property access on a receiver, e.g. {@code foo().bar + 1}</li>
   * </ul>
   *
   * @param lexer the lexer to read tokens from
   * @return the parsed expression
   * @throws UnexpectedTokenException where this method raises it directly (duplicate select-all in a struct);
   *         presumably the lexer's expect/tryThese calls raise it too - confirm against the Lexer
   */
  public Expression parseExpression(Lexer<TokenTypes> lexer) {
    Expression start = lexer.tryThese(
        parseIfIs(
            "lambda-noarg",
            Arrays.asList(
                TokenTypes.LPAREN,
                TokenTypes.RPAREN,
                TokenTypes.CHAIN
            ),
            () -> parseLambdaExpression(lexer)),
        parseIf(
            "lambda-unary",
            Arrays.asList(
                Collections.singleton(TokenTypes.LPAREN),
                IDENTIFIERS,
                Collections.singleton(TokenTypes.RPAREN),
                EnumSet.of(TokenTypes.CHAIN)
            ),
            () -> parseLambdaExpression(lexer)),
        parseIf(
            "lambda-unary-bracketed",
            Arrays.asList(
                IDENTIFIERS,
                EnumSet.of(TokenTypes.CHAIN)
            ),
            () -> parseLambdaExpression(lexer)),
        parseIf(
            "lambda",
            Arrays.asList(
                Collections.singleton(TokenTypes.LPAREN),
                IDENTIFIERS,
                Collections.singleton(TokenTypes.COMMA)
            ),
            () -> parseLambdaExpression(lexer)),
        parseIf(
            "empty-struct",
            Arrays.asList(
              EnumSet.of(TokenTypes.LBRACE),
              EnumSet.of(TokenTypes.RBRACE)
            ),
            () -> parseStructExpression(lexer)),
        parseIf(
            "struct-expression",
            Arrays.asList(
              EnumSet.of(TokenTypes.LBRACE),
              KEY_IDENTIFIERS,
              EnumSet.of(TokenTypes.COLON)
            ),
            () -> parseStructExpression(lexer)),
        parseIf(
            "struct-expression-leading-select-all",
            Arrays.asList(
              EnumSet.of(TokenTypes.LBRACE),
              EnumSet.of(TokenTypes.MULTIPLY),
              EnumSet.of(TokenTypes.COMMA),
              KEY_IDENTIFIERS,
              EnumSet.of(TokenTypes.COLON)
            ),
            () -> parseStructExpression(lexer)),
        // any other struct is assumed to be using the `<expression> as <name>` member syntax
        parseIfIs("struct-expression-as-syntax",
            TokenTypes.LBRACE,
            () -> parseStructExpressionAsSyntax(lexer)),
        parseIfIs("list-expression", TokenTypes.LBRACK, () -> parseListExpression(lexer)),
        parseIfOneOf("constant-expression", LITERALS,
            () -> parseConstantExpression(lexer)),
        parseIfIs("brackets-expression", TokenTypes.LPAREN,
            () -> parseBracketedExpression(lexer)),
        // NOTE(review): the lookahead admits KEY_IDENTIFIER but parseFunctionExpression only expects
        // IDENTIFIER or QUOTED_IDENTIFIER for the function name - confirm this mismatch is intended
        parseIf("function-expression", Arrays.asList(KEY_IDENTIFIERS, EnumSet.of(TokenTypes.LPAREN)),
            () -> parseFunctionExpression(lexer)),
        parseIfOneOf("property-expression", IDENTIFIERS,
          () -> parsePropertyExpression(lexer, Optional.empty())),
        parseIfIs("select-all", TokenTypes.MULTIPLY, () -> parseSplatExpression(lexer)),
        parseIfIs("token", TokenTypes.PARAMETER_IDENTIFIER, () -> parseParameterExpression(lexer))
    );

    Expression operand = start;
    if (lexer.peekType() == TokenTypes.INDEX) {
      lexer.next();
      // allow the result of another expression to be splatted, e.g. foo().*
      if (lexer.peekType() == TokenTypes.MULTIPLY) {
        return new PropertyAccess(Optional.of(start), Collections.singletonList(lexer.next()));
      }
      // property access on the result of another expression, e.g. foo().bar
      operand = parsePropertyExpression(lexer, Optional.of(start));
    }

    // checked after the property access so that foo().bar + 1 parses, rather than the operator being left
    // behind for the caller to trip over
    if (binaryOperators.contains(lexer.peekType())) {
      return parseBinaryExpression(lexer, operand);
    }

    return operand;
  }

  /**
   * Consumes the next token and wraps it in a {@link ParameterToken}.  The caller is expected to have
   * already checked that the next token is a parameter identifier.
   */
  private Expression parseParameterExpression(Lexer<TokenTypes> lexer) {
    Token parameter = lexer.next();
    return new ParameterToken(parameter);
  }

  /**
   * Parses a bare select-all, which is nothing more than a single {@link TokenTypes#MULTIPLY} token.
   */
  private SelectAllExpression parseSplatExpression(Lexer<TokenTypes> lexer) {
    Token splat = lexer.expect(TokenTypes.MULTIPLY);
    return new SelectAllExpression(splat);
  }

  /**
   * Parses a lambda.  The argument list is either a single identifier without brackets, or a bracketed,
   * comma separated list of zero or more identifiers.  It must be followed by a {@link TokenTypes#CHAIN}
   * token and then the expression that forms the lambda's body.
   *
   * The first token of the lambda (the opening bracket, or the lone identifier) is passed to the
   * {@link Lambda} along with the arguments and body.
   */
  private Expression parseLambdaExpression(Lexer<TokenTypes> lexer) {
    final Token left;
    final List<Token> args;

    if (lexer.peekType() != TokenTypes.LPAREN) {
      // un-bracketed form - exactly one argument, which doubles as the leading token
      left = lexer.expect(IDENTIFIERS);
      args = Collections.singletonList(left);
    } else {
      left = lexer.next();

      if (lexer.consumeIf(TokenTypes.RPAREN).isPresent()) {
        args = Collections.emptyList();
      } else {
        args = new LinkedList<>();
        Token separator;
        do {
          args.add(lexer.expect(IDENTIFIERS));
          // a comma means another argument follows, a closing bracket ends the list
          separator = lexer.expect(TokenTypes.COMMA, TokenTypes.RPAREN);
        } while (separator.type != TokenTypes.RPAREN);
      }
    }

    lexer.expect(TokenTypes.CHAIN);
    Expression body = parseExpression(lexer);

    return new Lambda(left, args, body);
  }

  /**
   * Parses a {@link StructDeclaration} whose members use {@code key: expression} syntax.  Members are comma
   * separated and a select-all ({@code *}) may appear among them, but only once.
   *
   * @throws UnexpectedTokenException if a second select-all is found
   */
  private Expression parseStructExpression(Lexer<TokenTypes> lexer) {
    Token opening = lexer.expect(TokenTypes.LBRACE);

    List<Member> members = new ArrayList<>();
    boolean splatAlreadySeen = false;

    while (lexer.peekType() != TokenTypes.RBRACE) {
      if (lexer.peekType() != TokenTypes.MULTIPLY) {
        Token key = lexer.expect(KEY_IDENTIFIERS);
        lexer.expect(TokenTypes.COLON);
        members.add(StructDeclaration.jsonStyleMember(key, parseExpression(lexer)));
      } else {
        // the select-all can be at any position among the members, but there can be only one
        Token splat = lexer.next();
        if (splatAlreadySeen) {
          throw new UnexpectedTokenException(
              ExpressionProblems.get().duplicateSelectAll(splat), KEY_IDENTIFIERS, splat);
        }

        splatAlreadySeen = true;
        members.add(StructDeclaration.selectAllMember(splat));
      }

      // members must be separated by commas; only the last member may be followed directly by the brace
      if (lexer.peekType() != TokenTypes.RBRACE) {
        lexer.expect(TokenTypes.COMMA);
      }
    }

    Token closing = lexer.expect(TokenTypes.RBRACE);
    return new StructDeclaration(members, Optional.of(Pair.of(opening, closing)));
  }

  /**
   * Parses a {@link StructDeclaration} written with the alternate 'as' syntax, where each member is an
   * expression that is optionally followed by {@code as} and an identifier, e.g.
   *
   * <pre>{@code
   * {
   *   <some expression> [as ident],
   *   calc_loss(asset, hazard) as loss
   * }
   * }</pre>
   *
   * Members are comma separated and a select-all ({@code *}) may appear among them, but only once.
   *
   * @return the parsed struct declaration
   * @throws UnexpectedTokenException if a second select-all is found
   */
  private Expression parseStructExpressionAsSyntax(Lexer<TokenTypes> lexer) {
    Token opening = lexer.expect(TokenTypes.LBRACE);

    List<Member> members = new ArrayList<>();
    boolean splatAlreadySeen = false;

    while (lexer.peekType() != TokenTypes.RBRACE) {
      if (lexer.peekType() != TokenTypes.MULTIPLY) {
        Expression expr = parseExpression(lexer);
        Optional<Token> asKeyword = lexer.consumeIf(TokenTypes.KEYWORD_AS);

        if (!asKeyword.isPresent()) {
          // without an `as` the member is anonymous and its name is left to be inferred from the expression
          members.add(StructDeclaration.anonMember(expr));
        } else {
          Token alias = lexer.expect(TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER);
          members.add(StructDeclaration.sqlStyleMember(alias, expr, asKeyword.get()));
        }
      } else {
        // the select-all can be at any position among the members, but there can be only one
        Token splat = lexer.next();
        if (splatAlreadySeen) {
          throw new UnexpectedTokenException(
              ExpressionProblems.get().duplicateSelectAll(splat), LEADING_TOKENS, splat);
        }

        splatAlreadySeen = true;
        members.add(StructDeclaration.selectAllMember(splat));
      }

      // members must be separated by commas; only the last member may be followed directly by the brace
      if (lexer.peekType() != TokenTypes.RBRACE) {
        lexer.expect(TokenTypes.COMMA);
      }
    }

    Token closing = lexer.expect(TokenTypes.RBRACE);
    return new StructDeclaration(members, Optional.of(Pair.of(opening, closing)));
  }

  /**
   * Parses a {@link ListDeclaration}: square brackets around zero or more comma separated expressions.
   */
  private Expression parseListExpression(Lexer<TokenTypes> lexer) {
    Token opening = lexer.expect(TokenTypes.LBRACK);

    List<Expression> items = new ArrayList<>();
    while (lexer.peekType() != TokenTypes.RBRACK) {
      items.add(parseExpression(lexer));

      if (lexer.peekType() == TokenTypes.RBRACK) {
        break;
      }
      // anything other than the closing bracket has to be a comma separating the next item
      lexer.expect(TokenTypes.COMMA);
    }

    Token closing = lexer.expect(TokenTypes.RBRACK);
    return new ListDeclaration(items, Optional.of(Pair.of(opening, closing)));
  }

  /**
   * Parses an expression wrapped in round brackets, keeping hold of both bracket tokens.
   */
  private Expression parseBracketedExpression(Lexer<TokenTypes> lexer) {
    Token opening = lexer.expect(TokenTypes.LPAREN);
    Expression inner = parseExpression(lexer);
    Token closing = lexer.expect(TokenTypes.RPAREN);

    return new BracketedExpression(inner, Optional.of(Pair.of(opening, closing)));
  }

  /**
   * Parses a function call: an identifier followed by a bracketed, comma separated list of arguments.  Each
   * argument is an expression that may be prefixed with a keyword and a colon, e.g. {@code foo(bar, baz: 1)}.
   *
   * @param lexer the lexer to read tokens from, positioned at the function's identifier
   * @return the parsed function call
   */
  public FunctionCall parseFunctionExpression(Lexer<TokenTypes> lexer) {
    Token identifier = lexer.expect(TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER);
    lexer.expect(TokenTypes.LPAREN);

    // no arguments at all
    if (lexer.peekType() == TokenTypes.RPAREN) {
      return new FunctionCall(identifier, Collections.emptyList(), lexer.next());
    }

    List<FunctionCall.Argument> args = Lists.newArrayList();
    boolean moreArguments = true;
    while (moreArguments) {
      Optional<Token> keyword = lexer.consumeIf(KEY_IDENTIFIERS);
      if (keyword.isPresent()) {
        if (lexer.peekType() == TokenTypes.COLON) {
          lexer.next();
        } else {
          // not a keyword after all - put it back so it gets parsed as part of the argument's expression
          lexer.rewind(keyword.get());
          keyword = Optional.empty();
        }
      }

      Expression value = parseExpression(lexer);
      args.add(new FunctionCall.Argument(value, keyword));

      Token separator = lexer.expect(TokenTypes.COMMA, TokenTypes.RPAREN);
      if (separator.type == TokenTypes.RPAREN) {
        // leave the closing bracket to be consumed below
        lexer.rewind(separator);
        moreArguments = false;
      }
    }

    return new FunctionCall(identifier, args, lexer.expect(TokenTypes.RPAREN));
  }

  /**
   * Parses a {@link Constant} from a single token, which must be one of the {@link #LITERALS}.
   */
  private Constant parseConstantExpression(Lexer<TokenTypes> lexer) {
    Token literal = lexer.expect(LITERALS);
    return new Constant(literal);
  }

  /**
   * Parses a chain of identifiers separated by {@link TokenTypes#INDEX} tokens, e.g. {@code foo.bar.baz}.  The
   * chain may finish with a select-all in place of an identifier, e.g. {@code foo.bar.*}, after which no
   * further segments are consumed.
   *
   * @param lexer the lexer to read tokens from, positioned at the first identifier
   * @param receiver the expression the properties are being accessed on, if any
   * @return a property access made up of the receiver and all the tokens in the chain
   */
  public PropertyAccess parsePropertyExpression(Lexer<TokenTypes> lexer, Optional<Expression> receiver) {
    List<Token> path = Lists.newArrayList();
    path.add(lexer.expect(TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER));

    boolean splatted = false;
    while (!splatted && lexer.peekType() == TokenTypes.INDEX) {
      lexer.next();

      // a select-all is always the last segment
      splatted = lexer.peekType() == TokenTypes.MULTIPLY;
      path.add(splatted
          ? lexer.next()
          : lexer.expect(TokenTypes.IDENTIFIER, TokenTypes.QUOTED_IDENTIFIER));
    }

    return new PropertyAccess(receiver, path);
  }

  /**
   * Parses the operator and right hand side of a binary operation, given an already parsed left hand side.
   * The right hand side is a whole expression, so may itself be another binary operation.
   */
  private BinaryOperation parseBinaryExpression(Lexer<TokenTypes> lexer, Expression lhs) {
    TokenTypes[] allowedOperators = binaryOperators.toArray(new TokenTypes[0]);
    Token operator = lexer.expect(allowedOperators);
    Expression rhs = parseExpression(lexer);

    return new BinaryOperation(lhs, operator, rhs);
  }
}
