/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.defaults.function;

import com.google.common.collect.Range;
import java.util.ArrayList;
import java.util.List;
import lombok.Generated;
import nz.org.riskscape.defaults.curves.CannotFitCurveException;
import nz.org.riskscape.defaults.curves.ObservedPoints;
import nz.org.riskscape.engine.ArgsProblems;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.function.BaseRealizableFunction;
import nz.org.riskscape.engine.function.FunctionArgument;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.RSList;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.engine.util.SegmentedList;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.FunctionCall;

/**
 * Computes an average annual loss (AAL) from a set of (loss, exceedance probability) observations.
 *
 * <p>The AAL is estimated by trapezoidal integration of loss over exceedance probability (EP),
 * sorted by EP. A rectangle is then added between EP 0 and the smallest observed EP, using the
 * loss observed at that smallest EP. At least {@link #MIN_NUM_EVENTS} distinct EPs are required.
 * Each EP must lie strictly between 0 and 1. A null loss is treated as 0, and a negative loss is
 * rejected.
 *
 * <p>The function can be used in two ways:
 * <ul>
 * <li>as a scalar function over two equal-length lists, {@code loss} and {@code ep};</li>
 * <li>as an aggregate function, where {@code loss} and {@code ep} are expressions evaluated
 * against each input row.</li>
 * </ul>
 */
public class AALHazardBasedFunction
extends BaseRealizableFunction
implements AggregationFunction {
    /** The closed range [0, 1] of probabilities. Note that EP values of exactly 0 or 1 are rejected. */
    public static final Range<Double> PROBABILITY_RANGE = Range.closed(0.0, 1.0);
    /** The minimum number of distinct EP observations needed to compute an AAL. */
    public static final int MIN_NUM_EVENTS = 3;
    public static final LocalProblems PROBLEMS = Problems.get(LocalProblems.class);
    /** The range of EP values actually accepted (strictly between 0 and 1); used when reporting a bad EP. */
    private static final Range<Double> VALID_EP_RANGE = Range.open(0.0, 1.0);
    /** Arguments for the aggregate form: one loss and one EP per row. */
    private static final ArgumentList ARGUMENTS =
        ArgumentList.create("loss", Types.FLOATING, "ep", Types.FLOATING);
    /** Arguments for the scalar form: a list of losses and a list of EPs. */
    static final ArgumentList SCALAR_ARGUMENTS =
        ArgumentList.create("loss", RSList.create(Types.FLOATING), "ep", RSList.create(Types.FLOATING));

    public AALHazardBasedFunction() {
        super(ARGUMENTS, Types.FLOATING);
    }

    /**
     * Realizes the scalar form of the function, which takes a list of losses and a list of EPs.
     *
     * <p>Loss list items must be numeric but may be null. EP list items must be numeric and not
     * nullable.
     */
    public ResultOrProblems<RiskscapeFunction> realize(RealizationContext context, FunctionCall fc, List<Type> givenTypes) {
        return ProblemException.catching(() -> {
            if (givenTypes.size() != this.arguments.size()) {
                throw new ProblemException(ArgsProblems.get().wrongNumber(this.arguments.size(), givenTypes.size()));
            }
            RSList lossType = SCALAR_ARGUMENTS.getRequiredAs(givenTypes, 0, RSList.class).getOrThrow();
            RSList epType = SCALAR_ARGUMENTS.getRequiredAs(givenTypes, 1, RSList.class).getOrThrow();
            if (!lossType.getContainedType().isNumeric()) {
                throw new ProblemException(ArgsProblems.mismatch(SCALAR_ARGUMENTS.get(0), givenTypes.get(0)));
            }
            if (!epType.getContainedType().isNumeric() || epType.getContainedType().isNullable()) {
                throw new ProblemException(ArgsProblems.mismatch(SCALAR_ARGUMENTS.get(1), givenTypes.get(1)));
            }
            return RiskscapeFunction.create(this, givenTypes, Types.FLOATING, args -> {
                List<?> losses = (List<?>) args.get(0);
                List<?> eps = (List<?>) args.get(1);
                if (losses.size() != eps.size()) {
                    throw new CannotFitCurveException(Problems.foundWith(fc, GeneralProblems.get().differentListLengths(
                        SCALAR_ARGUMENTS.get(0), losses.size(), SCALAR_ARGUMENTS.get(1), eps.size())));
                }
                AALCalculator calculator = new AALCalculator(fc);
                for (int i = 0; i < losses.size(); i++) {
                    calculator.addLossDataPoint((Number) losses.get(i), (Number) eps.get(i));
                }
                return calculator.calculateAAL();
            }, new AutoCloseable[0]);
        });
    }

    /** Returns this function with its aggregate form attached. */
    @Override
    public RiskscapeFunction asFunction() {
        return AggregationFunction.addAggregationTo(this, super.asFunction());
    }

    /**
     * Realizes the aggregate form of the function against rows of {@code inputType}.
     *
     * <p>Both the {@code ep} and {@code loss} expressions must be numeric. The {@code ep}
     * expression must also be non-nullable.
     */
    public ResultOrProblems<RealizedAggregateExpression> realize(RealizationContext context, Type inputType, FunctionCall fc) {
        return ProblemException.catching(() -> {
            RealizedExpression xExpression = this.realizeArg(context, inputType, fc, "ep");
            RealizedExpression yExpression = this.realizeArg(context, inputType, fc, "loss");
            if (xExpression.getResultType().isNullable()) {
                throw new ProblemException(
                    TypeProblems.get().mismatch(this.arguments.get("ep"), Types.FLOATING, xExpression.getResultType()));
            }
            return RealizedAggregateExpression.create(inputType, Types.FLOATING, fc,
                () -> new AccumImpl(fc, xExpression, yExpression));
        });
    }

    /**
     * Realizes the named (required) argument of {@code fc} against {@code inputType}.
     *
     * @throws ProblemException if the argument is missing, fails to realize, or is not numeric
     */
    private RealizedExpression realizeArg(RealizationContext context, Type inputType, FunctionCall fc, String argName)
            throws ProblemException {
        Expression argExpression = this.arguments.getRequiredArgument(fc, argName).getOrThrow().getExpression();
        RealizedExpression realized = context.getExpressionRealizer()
            .realize(inputType, argExpression)
            .getOrThrow(Problems.foundWith(this.arguments.get(argName)));
        if (!realized.getResultType().isNumeric()) {
            throw new ProblemException(
                TypeProblems.get().mismatch(this.arguments.get(argName), Types.FLOATING, realized.getResultType()));
        }
        return realized;
    }

    /**
     * Accumulator for the aggregate form: evaluates the EP and loss expressions for each row and
     * collects them as data points.
     */
    private static final class AccumImpl
    extends AALCalculator
    implements Accumulator {
        /** Evaluates the EP for a row. */
        final RealizedExpression xExpression;
        /** Evaluates the loss for a row. */
        final RealizedExpression yExpression;

        AccumImpl(FunctionCall fc, RealizedExpression xExpression, RealizedExpression yExpression) {
            super(fc);
            this.yExpression = yExpression;
            this.xExpression = xExpression;
        }

        /**
         * Merges {@code rhs}'s data points into this accumulator and returns this accumulator.
         * Duplicate EPs across the two accumulators are detected later, in {@code calculateAAL()}.
         */
        public Accumulator combine(Accumulator rhs) {
            AccumImpl other = (AccumImpl) rhs;
            this.xValues.addAll(other.xValues);
            this.yValues.addAll(other.yValues);
            return this;
        }

        public void accumulate(Object input) {
            Number probability = (Number) this.xExpression.evaluate(input);
            Number loss = (Number) this.yExpression.evaluate(input);
            this.addLossDataPoint(loss, probability);
        }

        public Object process() {
            return this.calculateAAL();
        }

        public boolean isEmpty() {
            return this.xValues.isEmpty();
        }
    }

    /**
     * Collects (EP, loss) data points and computes an AAL from them.
     */
    private static class AALCalculator {
        final FunctionCall fc;
        /** Observed EPs, index-aligned with {@link #yValues}. */
        List<Double> xValues = SegmentedList.forClass(Double.class);
        /** Observed losses, index-aligned with {@link #xValues}. */
        List<Double> yValues = SegmentedList.forClass(Double.class);

        @Generated
        public AALCalculator(FunctionCall fc) {
            this.fc = fc;
        }

        /**
         * Validates and records a single data point.
         *
         * @param loss the loss; null is treated as 0, and it must not be negative (or NaN)
         * @param probability the EP; it must be non-null and strictly between 0 and 1
         * @throws CannotFitCurveException if either value is invalid
         */
        public void addLossDataPoint(Number loss, Number probability) {
            // Written as a negated range test so that NaN fails the check as well.
            if (probability == null || !(probability.doubleValue() > 0.0 && probability.doubleValue() < 1.0)) {
                throw new CannotFitCurveException(Problems.foundWith(this.fc,
                    GeneralProblems.get().badValue(probability, ARGUMENTS.get("ep"), VALID_EP_RANGE)));
            }
            if (loss == null) {
                loss = 0;
            }
            // Negated so that a NaN loss is rejected rather than silently poisoning the AAL.
            if (!(loss.doubleValue() >= 0.0)) {
                throw new CannotFitCurveException(Problems.foundWith(this.fc,
                    GeneralProblems.get().badValue(loss, ARGUMENTS.get("loss"), Range.atLeast(0))));
            }
            this.xValues.add(probability.doubleValue());
            this.yValues.add(loss.doubleValue());
        }

        /**
         * Computes the AAL from all recorded data points.
         *
         * @throws CannotFitCurveException if an EP is repeated or there are too few data points
         */
        public double calculateAAL() {
            // Checked here rather than per-add: it covers points merged via combine() and avoids
            // an O(n) list scan for every added point.
            checkNoDuplicateProbabilities();
            ObservedPoints points = new ObservedPoints(this.xValues, this.yValues);
            return this.calculateAAL(points.asListOfPairs());
        }

        /** Throws if any EP value appears more than once among the recorded data points. */
        private void checkNoDuplicateProbabilities() {
            double[] sorted = this.xValues.stream().mapToDouble(Double::doubleValue).sorted().toArray();
            for (int i = 1; i < sorted.length; i++) {
                if (sorted[i] == sorted[i - 1]) {
                    throw new CannotFitCurveException(Problems.foundWith(this.fc,
                        PROBLEMS.duplicateProbability(sorted[i])));
                }
            }
        }

        /**
         * Integrates loss over EP using the trapezoidal rule. It then adds the area between EP 0
         * and the smallest EP, using the loss observed at that smallest EP.
         *
         * @param datapoints (EP, loss) pairs in any order
         * @throws CannotFitCurveException if there are fewer than {@link #MIN_NUM_EVENTS} points
         */
        private double calculateAAL(List<Pair<Double, Double>> datapoints) {
            if (datapoints.size() < MIN_NUM_EVENTS) {
                throw new CannotFitCurveException(Problems.foundWith(this.fc,
                    PROBLEMS.tooFewEvents(datapoints.size(), MIN_NUM_EVENTS)));
            }
            List<Pair<Double, Double>> byEp = new ArrayList<>(datapoints);
            byEp.sort((a, b) -> Double.compare(a.getLeft(), b.getLeft()));
            double aal = 0.0;
            for (int i = 0; i + 1 < byEp.size(); i++) {
                Pair<Double, Double> lower = byEp.get(i);
                Pair<Double, Double> upper = byEp.get(i + 1);
                // trapezoid between adjacent EPs: width * mean loss
                aal += (upper.getLeft() - lower.getLeft()) * ((lower.getRight() + upper.getRight()) / 2.0);
            }
            // Extend from EP 0 up to the smallest EP, using the loss observed at the smallest EP.
            Pair<Double, Double> smallestEp = byEp.get(0);
            aal += smallestEp.getLeft() * smallestEp.getRight();
            return aal;
        }
    }

    /** Problems specific to this function. */
    public static interface LocalProblems
    extends ProblemFactory {
        /** The given EP value appears more than once in the data. */
        public Problem duplicateProbability(Double var1);

        /** Too few data points were given; {@code var1} is the count given, {@code var2} the minimum. */
        public Problem tooFewEvents(int var1, int var2);
    }
}

