/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.defaults.curves;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import nz.org.riskscape.defaults.curves.CannotFitCurveException;
import nz.org.riskscape.defaults.curves.CurveFitter;
import nz.org.riskscape.defaults.curves.FitCurve;
import nz.org.riskscape.defaults.curves.ObservedPoints;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.BindingContext;
import nz.org.riskscape.engine.function.ArgumentList;
import nz.org.riskscape.engine.problem.GeneralProblems;
import nz.org.riskscape.engine.problem.ProblemFactory;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregationFunction;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.types.FunctionType;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.SegmentedList;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.FunctionCall;

public class FitCurveFunction
implements AggregationFunction {
    public static final LocalProblems PROBLEMS = (LocalProblems)Problems.get(LocalProblems.class);
    private static final Struct FIT_RESULT = Struct.of((String)"function", (Type)new FunctionType(Collections.singletonList(Types.FLOATING), (Type)Types.FLOATING), (String)"fit", (Type)Types.TEXT, (String)"score", (Type)Types.FLOATING);
    private final ArgumentList arguments = ArgumentList.create((String)"x-value", (Type)Types.FLOATING, (String)"y-value", (Type)Types.FLOATING, (String)"fitters", (Type)Nullable.of((Type)Struct.EMPTY_STRUCT));
    private final Type returnType = FIT_RESULT;
    private final List<CurveFitter<?>> fitters;

    public ResultOrProblems<RealizedAggregateExpression> realize(RealizationContext context, Type inputType, FunctionCall fc) {
        return ProblemException.catching(() -> {
            List fittersAndParams;
            RealizedExpression xExpression = this.realize(context, inputType, fc, "x-value");
            RealizedExpression yExpression = this.realize(context, inputType, fc, "y-value");
            BindingContext bindingContext = context.getProject().newBindingContext(context);
            Optional fittersArgs = this.arguments.getArgument(fc, "fitters");
            if (fittersArgs.isPresent()) {
                Tuple tuple = (Tuple)this.arguments.evaluateConstant(context, fc, "fitters", Tuple.class, (Type)Struct.EMPTY_STRUCT).getOrThrow();
                fittersAndParams = new ArrayList(tuple.size());
                for (String fitterId : tuple.getStruct().getMemberKeys()) {
                    Map<Object, Object> unbound;
                    Object params = tuple.fetch(fitterId);
                    if (params instanceof String) {
                        unbound = Collections.emptyMap();
                        fitterId = (String)params;
                    } else {
                        unbound = params instanceof Tuple ? this.tupleToParamMap((Tuple)params) : null;
                    }
                    String finalFitterId = fitterId;
                    CurveFitter found = this.fitters.stream().filter(cf -> cf.getId().equals(finalFitterId)).findAny().orElse(null);
                    if (found == null) {
                        throw new ProblemException((Problems)GeneralProblems.get().noSuchObjectExists(fitterId, CurveFitter.class));
                    }
                    Object boundObject = found.bind(bindingContext, unbound).orElseThrow(problems -> new ProblemException((Problems)Problems.foundWith((Object)(finalFitterId + " options"), (List)problems)));
                    fittersAndParams.add(new FitterAndParams(found, boundObject));
                }
            } else {
                fittersAndParams = this.fitters.stream().map(fitter -> new FitterAndParams((CurveFitter)fitter, fitter.getBindingClass().cast(fitter.bind(bindingContext, Collections.emptyMap()).get()))).collect(Collectors.toList());
            }
            Struct resultType = context.normalizeStruct(FIT_RESULT);
            return RealizedAggregateExpression.create((Type)inputType, (Type)resultType, (Expression)fc, () -> new AccumImpl(resultType, fittersAndParams, xExpression, yExpression));
        });
    }

    private Map<String, List<?>> tupleToParamMap(Tuple tuple) {
        HashMap optionsMap = new HashMap(tuple.size());
        for (String key : tuple.getStruct().getMemberKeys()) {
            Object value = tuple.fetch(key);
            optionsMap.put(key, Collections.singletonList(value.toString()));
        }
        return optionsMap;
    }

    private RealizedExpression realize(RealizationContext context, Type inputType, FunctionCall fc, String argName) throws ProblemException {
        RealizedExpression realized = (RealizedExpression)context.getExpressionRealizer().realize(inputType, ((FunctionCall.Argument)this.arguments.getRequiredArgument(fc, argName).getOrThrow()).getExpression()).getOrThrow(Problems.foundWith((Object)this.arguments.get(argName), (Problem[])new Problem[0]));
        if (!realized.getResultType().isNumeric()) {
            throw new ProblemException((Problems)TypeProblems.get().mismatch((Object)this.arguments.get(argName), (Type)Types.FLOATING, realized.getResultType()));
        }
        return realized;
    }

    @Generated
    public FitCurveFunction(List<CurveFitter<?>> fitters) {
        this.fitters = fitters;
    }

    @Generated
    public ArgumentList getArguments() {
        return this.arguments;
    }

    @Generated
    public Type getReturnType() {
        return this.returnType;
    }

    private static class FitterAndParams {
        final CurveFitter fitter;
        final Object params;

        public <T> FitCurve<T> fit(ObservedPoints points) {
            Object castParams = this.fitter.getBindingClass().cast(this.params);
            return this.fitter.fit(castParams, points);
        }

        @Generated
        public FitterAndParams(CurveFitter fitter, Object params) {
            this.fitter = fitter;
            this.params = params;
        }

        @Generated
        public CurveFitter getFitter() {
            return this.fitter;
        }

        @Generated
        public Object getParams() {
            return this.params;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof FitterAndParams)) {
                return false;
            }
            FitterAndParams other = (FitterAndParams)o;
            if (!other.canEqual(this)) {
                return false;
            }
            CurveFitter this$fitter = this.getFitter();
            CurveFitter other$fitter = other.getFitter();
            if (this$fitter == null ? other$fitter != null : !this$fitter.equals(other$fitter)) {
                return false;
            }
            Object this$params = this.getParams();
            Object other$params = other.getParams();
            return !(this$params == null ? other$params != null : !this$params.equals(other$params));
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof FitterAndParams;
        }

        @Generated
        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            CurveFitter $fitter = this.getFitter();
            result = result * 59 + ($fitter == null ? 43 : $fitter.hashCode());
            Object $params = this.getParams();
            result = result * 59 + ($params == null ? 43 : $params.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "FitCurveFunction.FitterAndParams(fitter=" + String.valueOf(this.getFitter()) + ", params=" + String.valueOf(this.getParams()) + ")";
        }
    }

    private static final class AccumImpl
    implements Accumulator {
        final Struct resultType;
        final List<FitterAndParams> fittersAndParams;
        final RealizedExpression xExpression;
        final RealizedExpression yExpression;
        List<Double> xValues = SegmentedList.forClass(Double.class);
        List<Double> yValues = SegmentedList.forClass(Double.class);

        public Accumulator combine(Accumulator rhs) {
            AccumImpl impl = (AccumImpl)rhs;
            this.xValues.addAll(impl.xValues);
            this.yValues.addAll(impl.yValues);
            return this;
        }

        public void accumulate(Object input) {
            Number xValue = (Number)this.xExpression.evaluate(input);
            Number yValue = (Number)this.yExpression.evaluate(input);
            if (xValue == null || yValue == null) {
                return;
            }
            this.xValues.add(xValue.doubleValue());
            this.yValues.add(yValue.doubleValue());
        }

        public Object process() {
            ObservedPoints points = new ObservedPoints(this.xValues, this.yValues);
            FitCurve<?> winning = null;
            ArrayList<Problem> problems = new ArrayList<Problem>();
            for (FitterAndParams fitter : this.fittersAndParams) {
                FitCurve<?> curve;
                try {
                    curve = this.doFit(fitter, points);
                }
                catch (CannotFitCurveException e) {
                    problems.add(PROBLEMS.fittingFailed(fitter.fitter.getId()).withChildren(new Problems[]{Problems.caught((Throwable)((Object)e))}));
                    continue;
                }
                if (winning != null && !(curve.getFitScore() > winning.getFitScore())) continue;
                winning = curve;
            }
            if (winning == null) {
                throw new CannotFitCurveException(PROBLEMS.nothingFits().withChildren(problems));
            }
            return Tuple.ofValues((Struct)this.resultType, (Object[])new Object[]{winning.getFunction(), winning.getFitter().getId(), winning.getFitScore()});
        }

        private <T> FitCurve<?> doFit(FitterAndParams fitter, ObservedPoints points) {
            return fitter.fit(points);
        }

        public boolean isEmpty() {
            return this.xValues.isEmpty();
        }

        @Generated
        public AccumImpl(Struct resultType, List<FitterAndParams> fittersAndParams, RealizedExpression xExpression, RealizedExpression yExpression) {
            this.resultType = resultType;
            this.fittersAndParams = fittersAndParams;
            this.xExpression = xExpression;
            this.yExpression = yExpression;
        }
    }

    public static interface LocalProblems
    extends ProblemFactory {
        public Problem fittingFailed(String var1);

        public Problem nothingFits();
    }
}

