/*
 * RiskScape™ Copyright New Zealand Institute for Earth Science Limited
 * (Earth Sciences New Zealand) is distributed for research purposes only
 * under the terms of AGPLv3.
 *
 * RiskScape™ Copyright 2025 New Zealand Institute for Earth Science
 * Limited (Earth Sciences New Zealand). All rights reserved. Source code
 * available under the AGPLv3.
 * 
 * This program is free software: you can redistribute it and/or modify it under
 *  the terms of the GNU Affero General Public License as published by the Free
 *  Software Foundation, either version 3 of the License, or (at your option) any
 *  later version.
 * 
 * This program is distributed for RESEARCH PURPOSES ONLY, in the hope that it will
 * be useful for research and education initiatives.
 * 
 * If you are not a researcher, or you are a researcher who wishes to use this
 * program on terms other than AGPLv3 (including those who wish to restrict the
 * distribution of any source code created using this program), please contact:
 * https://riskscape.org.nz
 * 
 * This program is distributed WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.  You should have received a copy
 * of the GNU Affero General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 * 
 * By way of summary only, under the AGPLv3:
 *     • Permissions of this strongest copyleft license are conditioned
 *       on making available complete source code of licensed works and
 *       modifications, which include larger works using a licensed work,
 *       under the same license.
 *     • Copyright and license notices must be preserved.
 *     • Contributors provide an express grant of patent rights.
 *     • When a modified version is used to provide a service over a
 *       network, the complete source code of the modified version must be made
 *       available.
 */
package nz.org.riskscape.defaults.interp;

import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;

import com.google.common.primitives.Doubles;

import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import nz.org.riskscape.engine.bind.ParameterField;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.ScopedLambdaExpression;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.TypeVisitor;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.Expression;

/**
 * A two-dimensional {@link StackableContinuousFunctionType} whose functions are sampled on a grid of
 * {@code xValues} by {@code yValues} points and bilinearly interpolated between those samples.
 *
 * Grid values are computed lazily, by evaluating the value expression at a grid point the first time it is
 * needed, and may be short-circuited to a configured zero loss value.
 *
 * Marked final to (possibly) help the JVM to optimize this code.  Can be removed if a future refactor deems it
 * awkward.
 */
// NOTE(review): Options does not override equals/hashCode, so "options" is compared by identity here - confirm
// this is intended.
@EqualsAndHashCode(callSuper = true, of = {"xValues", "yValues", "options"})
public final class BilinearContinuousFunctionType extends StackableContinuousFunctionType<BilinearContinuousFunction> {

  /**
   * A single grid point, identified by its x and y sample values.
   */
  @RequiredArgsConstructor @EqualsAndHashCode @ToString
  static final class XY {
    final double x;
    final double y;
  }

  /**
   * A simple container for the zero loss value.
   *
   * Required to allow a null zero loss to be possible when put in an optional.
   */
  @RequiredArgsConstructor @EqualsAndHashCode @ToString
  static class ZeroLoss {
    final Object value;
  }

  /**
   * An instance with no sample points and no expressions that returns {@link Types#FLOATING}.
   */
  public static final BilinearContinuousFunctionType ANY_BILINEAR = new BilinearContinuousFunctionType();

  /**
   * User-supplied options for a bilinear function.
   */
  public static class Options {
    /**
     * When true, x values are interpolated on a natural-log scale.
     */
    @ParameterField
    public boolean applyLogToX = false;
    /**
     * When true, y values are interpolated on a natural-log scale.
     */
    @ParameterField
    public boolean applyLogToY = false;
    /**
     * The raw zero loss value as supplied by the user. Not used directly at evaluation time; the builder is
     * expected to supply a version coerced to the value expression's return type.
     */
    @ParameterField
    public Optional<Object> zeroLoss = Optional.empty();
    /**
     * Passed through to the {@link StackableContinuousFunctionType} constructor.
     */
    @ParameterField
    public boolean compress = false;
  }

  /**
   * Collects the inputs needed to construct a {@link BilinearContinuousFunctionType}.
   */
  @RequiredArgsConstructor
  public static class Builder {

    public final RealizationContext context;

    public RealizedExpression valueExpression;
    public double[] xValues;
    public double[] yValues;
    public Optional<ZeroLoss> zeroLoss;
    public Options options;

    /**
     * Populates this builder by passing it to {@code setupVisitor}, then realizes the x/y interpolation
     * expressions and constructs the type.
     *
     * @param setupVisitor callback that fills in this builder's fields
     * @return the new type, or the problems encountered while realizing expressions or constructing the type
     */
    public ResultOrProblems<BilinearContinuousFunctionType> build(Consumer<Builder> setupVisitor) {
      setupVisitor.accept(this);

      return ProblemException.catching(() -> {
        // we need to build expressions to do the x and y axis interpolations
        // both of these expressions have the same formulae so it may seem that we'd only need one.
        // but we need both to cover the case that valueExpression returns integers because that will
        // produce a floating interpolated result to feed into yInterpExpr.
        Struct xInterpExprInput = Struct.of(
            "ratio1", Types.FLOATING, "value1", valueExpression.getResultType(),
            "ratio2", Types.FLOATING, "value2", valueExpression.getResultType()
        );
        RealizedExpression xInterpExpr = context.getExpressionRealizer()
            .realize(xInterpExprInput, "(ratio1 * value1) + (ratio2 * value2)").getOrThrow();

        // yInterpExpr appears identical to xInterpExpr, but it has xInterpExpr's result type
        // as its input (this is necessary as valueExpression may return integers that xInterpExpr
        // will turn into floats)
        Struct yInterpExprInput = Struct.of(
            "ratio1", Types.FLOATING, "value1", xInterpExpr.getResultType(),
            "ratio2", Types.FLOATING, "value2", xInterpExpr.getResultType()
        );
        RealizedExpression yInterpExpr = context.getExpressionRealizer()
            .realize(yInterpExprInput, "(ratio1 * value1) + (ratio2 * value2)").getOrThrow();

        try {
          return new BilinearContinuousFunctionType(valueExpression, xInterpExpr, yInterpExpr,
            xValues, yValues, zeroLoss, options);
        } catch (IllegalArgumentException ex) {
          // the compress option doesn't return nice errors - it's an advanced feature
          throw new ProblemException(Problems.caught(ex));
        }
      });
    }

    /**
     * Parses {@code expr} and realizes it against {@code inputType} using this builder's context.
     *
     * @throws ProblemException if the expression could not be realized
     */
    public RealizedExpression realize(Struct inputType, String expr) throws ProblemException {
      Expression parsed = ExpressionParser.parseString(expr);
      return context.getExpressionRealizer().realize(inputType, parsed).getOrThrow();
    }
  }

  /**
   * Expression to interpolate the {@link #valueExpression} results across the x axis
   */
  private final RealizedExpression xInterpExpr;

  /**
   * Expression to interpolate the {@link #xInterpExpr} results across the y axis
   */
  private final RealizedExpression yInterpExpr;

  /**
   * The sampled x points. Searched with {@link Arrays#binarySearch}, so assumed to be sorted ascending.
   */
  private final double[] xValues;
  /**
   * The sampled y points. Searched with {@link Arrays#binarySearch}, so assumed to be sorted ascending.
   */
  private final double[] yValues;
  /**
   * The zeroLoss to be returned when required.
   *
   * This should be coerced if necessary to match return type of the value expression. (which is why we
   * don't use zeroLoss straight out of options)
   */
  private final Optional<ZeroLoss> zeroLoss;
  private final Options options;

  /**
   * Constructs the {@link #ANY_BILINEAR} instance: no expressions, no sample points, default options.
   */
  private BilinearContinuousFunctionType() {
    super(TWO_DIMENSIONAL, null, Types.FLOATING, false);
    this.xInterpExpr = null;
    this.yInterpExpr = null;
    this.xValues = new double[] {};
    this.yValues = new double[] {};
    this.zeroLoss = Optional.empty();
    this.options = new Options();
  }

  /**
   * Constructs a fully configured type. The return type is the result type of {@code yInterpExpr}.
   */
  private BilinearContinuousFunctionType(RealizedExpression valueExpression,
      RealizedExpression xInterpExpr, RealizedExpression yInterpExpr, double[] xValues, double[] yValues,
      Optional<ZeroLoss> zeroLoss, Options options) {
    super(TWO_DIMENSIONAL, valueExpression, yInterpExpr.getResultType(), options.compress);
    this.xInterpExpr = xInterpExpr;
    this.yInterpExpr = yInterpExpr;
    this.xValues = xValues;
    this.yValues = yValues;
    this.zeroLoss = zeroLoss;
    this.options = options;
  }

  /**
   * Evaluates {@code func} at the point ({@code dimensionValues[0]}, {@code dimensionValues[1]}) by bilinear
   * interpolation between the surrounding grid points.
   *
   * Coordinates outside the sampled range are clamped to the nearest sampled edge. When the log options are
   * set, the interpolation ratios are computed on the natural log of the coordinates.
   *
   * @param func a {@link BilinearContinuousFunction} belonging to this type
   * @param dimensionValues the x and y coordinates, in that order
   * @return the interpolated value, as produced by {@link #yInterpExpr}
   */
  @Override
  public Object applyTo(Object func, double... dimensionValues) {
    // let's extract the args to a more convenient form.
    BilinearContinuousFunction function = (BilinearContinuousFunction) func;
    double x = dimensionValues[0];
    double y = dimensionValues[1];

    // does a bilinear interpolation as described at: https://en.wikipedia.org/wiki/Bilinear_interpolation
    // variables are given names matching the formulae as close as possible
    Pair<Double, Double> x1x2 = surrounding(x, xValues);
    double x1 = x1x2.getLeft();
    double x2 = x1x2.getRight();
    Pair<Double, Double> y1y2 = surrounding(y, yValues);
    double y1 = y1y2.getLeft();
    double y2 = y1y2.getRight();

    // if the required x or y are outside of the sampled range then we mung it to be either the
    // smallest or largest.
    x = withinRange(x, x1, x2);
    y = withinRange(y, y1, y2);

    // we duplicate our x/y values possibly with a log scale to use when interpolating
    double gx = options.applyLogToX ? Math.log(x) : x;
    double gx1 = options.applyLogToX ? Math.log(x1) : x1;
    double gx2 = options.applyLogToX ? Math.log(x2) : x2;
    double gy = options.applyLogToY ? Math.log(y) : y;
    double gy1 = options.applyLogToY ? Math.log(y1) : y1;
    double gy2 = options.applyLogToY ? Math.log(y2) : y2;

    // the below formulae match wikipedia, but each axis has an extra branch for the single value case
    // (x1 == x2 or y1 == y2) where no interpolation is needed.
    Object fxy1;
    Object fxy2;
    if (x1 == x2) {
      // x1/x2 are the same no need to interpolate, but we still need to run the expression to ensure
      // we get the right types to put into the next step. So we use fixed ratios of 1 and 0
      Object x1y1 = computeIfNecessary(x1, y1, function);
      Object x1y2 = computeIfNecessary(x1, y2, function);
      fxy1 = xInterpExpr.evaluateValues(1D, x1y1, 0D, x1y1);
      fxy2 = xInterpExpr.evaluateValues(1D, x1y2, 0D, x1y2);
    } else {
      // we need to get the values at the vertex points. we do this in highest x/y values first
      // to allow any zero loss short circuits a chance to apply.
      Object x2y2 = computeIfNecessary(x2, y2, function);
      Object x2y1 = computeIfNecessary(x2, y1, function);
      Object x1y2 = computeIfNecessary(x1, y2, function);
      Object x1y1 = computeIfNecessary(x1, y1, function);

      // the same x weights apply to both the y1 and y2 rows
      double x1Ratio = (gx2 - gx) / (gx2 - gx1);
      double x2Ratio = (gx - gx1) / (gx2 - gx1);
      fxy1 = xInterpExpr.evaluateValues(x1Ratio, x1y1, x2Ratio, x2y1);
      fxy2 = xInterpExpr.evaluateValues(x1Ratio, x1y2, x2Ratio, x2y2);
    }

    if (y1 == y2) {
      // y1/y2 are the same so no need to interpolate, but we still need to run the expression to ensure
      // we get the right types to put into the next step. So we use fixed ratios of 1 and 0
      return yInterpExpr.evaluateValues(1D, fxy1, 0D, fxy1);
    } else {
      return yInterpExpr.evaluateValues(
          (gy2 - gy) / (gy2 - gy1), fxy1,
          (gy - gy1) / (gy2 - gy1), fxy2
      );
    }
  }

  @Override
  BilinearContinuousFunction newFunction(ScopedLambdaExpression scope) {
    return new BilinearContinuousFunction(scope, this);
  }

  /**
   * @return the number of grid points, i.e. the number of x values multiplied by the number of y values
   */
  @Override
  public int getSize() {
    // TODO memoize?
    return xValues.length * yValues.length;
  }

  @Override
  public <T, U> U visit(TypeVisitor<T, U> tv, T data) {
    throw new UnsupportedOperationException();
  }

  @Override
  public String toString() {
    return String.format("BilinearContinuousCurve(xvalues=%s, yvalues=%s, returnType=%s)",
        Doubles.asList(xValues), Doubles.asList(yValues), returnType);
  }

  /**
   * Returns the value of {@code function} at the grid point (x, y), computing it if it isn't already stored.
   */
  private Object computeIfNecessary(double x, double y, BilinearContinuousFunction function) {
    return getOrComputeValue(function, xyToIndex(x, y));
  }

  /**
   * Returns the stored value for grid index {@code key}, or computes and stores it.
   *
   * When a zero loss is configured and the function reports that zero loss applies at this point, the zero
   * loss value is stored and returned without evaluating the value expression. Otherwise the value expression
   * is evaluated, and if the result equals the zero loss value the point is recorded with the function.
   */
  @Override
  Object getOrComputeValue(BilinearContinuousFunction function, int key) {
    Object existing = getValue(function, key);

    if (existing != null) {
      return existing;
    }
    XY xy = indexToXY(key);

    if (zeroLoss.isPresent() && function.useZeroLoss(xy)) {
      // we can apply the zero loss and short circuit evaluating the function to speed things up.
      // we also set the value which is important when stacking the function.
      Object zeroLossValue = zeroLoss.get().value;

      setValue(function, key, zeroLossValue);

      return zeroLossValue;
    } else {
      Object value = function.evaluate(valueExpression, xy.x, xy.y);
      if (zeroLoss.isPresent() && Objects.equals(zeroLoss.get().value, value)) {
        function.addZeroLoss(xy);
      }

      setValue(function, key, value);

      return value;
    }
  }

  /**
   * Converts a known index into the x/y values it represents
   *
   * The index layout is row-major: {@code index = yIndex * xValues.length + xIndex}.
   *
   * package private for test access.
   */
  XY indexToXY(int index) {
    double x = xValues[index % xValues.length];
    double y = yValues[index / xValues.length];
    return new XY(x, y);
  }

  /**
   * Converts x/y values to an index into the values array.
   *
   * The index layout is row-major: {@code index = yIndex * xValues.length + xIndex}.
   *
   * package private for test access.
   *
   * @throws RuntimeException if x or y are not known. This is a programmer error.
   */
  int xyToIndex(double x, double y) {
    int xIdx = Arrays.binarySearch(xValues, x);
    int yIdx = Arrays.binarySearch(yValues, y);
    // binarySearch returns -(insertion point) - 1 for a missing key, so ANY negative result means the point
    // is unknown - not just -1 (which only covers keys that would sort before the first element)
    if (xIdx < 0 || yIdx < 0) {
      // this is a programmer error, the x/y values should exist
      throw new RuntimeException("Not a known point");
    }
    return (yIdx * xValues.length) + xIdx;
  }

  /**
   * Returns the values surrounding value to base interpolation on.
   *
   * {@code values} is assumed to be sorted in ascending order.
   *
   * There are cases where the result will contain a single value (i.e. left == right). They are:
   * - values has only one entry
   * - exact match for value. no need to interpolate then
   * - value less than or equal to the first item in values (the first item is returned)
   * - value greater than or equal to the last item in values (the last item is returned)
   *
   * @throws IllegalArgumentException if values is empty
   */
  // static for easy access from tests
  static Pair<Double, Double> surrounding(double value, double[] values) {
    if (values.length == 0) {
      throw new IllegalArgumentException("values must not be empty");
    }
    double firstPoint = values[0];
    double lastPoint = values[values.length - 1];
    if (values.length == 1) {
      // this is to deal with the case that there is only a single entry in the values list.
      return Pair.of(firstPoint, firstPoint);
    }
    if (value <= firstPoint) {
      // value is at or below the lowest option, so we clamp to the first point
      return Pair.of(firstPoint, firstPoint);
    } else if (value >= lastPoint) {
      // value is at or above the largest option, so we clamp to the last point
      return Pair.of(lastPoint, lastPoint);
    }
    for (int i = 1; i < values.length; i++) {
      double v1 = values[i - 1];
      double v2 = values[i];
      if (v1 == value || v2 == value) {
        // if the required value matches a known point there is no need to interpolate.
        return Pair.of(value, value);
      }
      if (value >= v1 && value <= v2) {
        return Pair.of(v1, v2);
      }
    }
    throw new IllegalStateException("should never get here");
  }

  /**
   * @return value if it is within the range of min-max, else min or max
   */
  double withinRange(double value, double min, double max) {
    if (value < min) {
      return min;
    } else if (value > max) {
      return max;
    }
    return value;
  }
}
