/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.wizard.model2.smp;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.query.TupleUtils;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.pipeline.ast.PipelineDeclaration;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.wizard.Question;
import nz.org.riskscape.wizard.QuestionSet;
import nz.org.riskscape.wizard.Survey;
import nz.org.riskscape.wizard.bld.IncrementalBuildState;
import nz.org.riskscape.wizard.bld.PipelineChange;
import nz.org.riskscape.wizard.bld.change.AppendJoinChange;
import nz.org.riskscape.wizard.bld.dsl.IncompletePipelineChange;
import nz.org.riskscape.wizard.bld.loc.EndOfBranchWith;
import nz.org.riskscape.wizard.model2.input.InputDataPhase;
import nz.org.riskscape.wizard.model2.smp.SampleType;
import nz.org.riskscape.wizard.survey2.BasePhase;
import nz.org.riskscape.wizard.survey2.Choices;
import nz.org.riskscape.wizard.survey2.DefaultQuestionSet2;
import org.locationtech.jts.geom.Point;

/**
 * Wizard phase that builds the "sample" question set.
 *
 * <p>For every pipeline end step whose dataset is listed by
 * {@link InputDataPhase#getDatasetsToSample(IncrementalBuildState)} (hazard datasets first, then by
 * chain index), it:
 * <ul>
 *   <li>asks how to sample (hazard datasets only, and only when more than one {@link SampleType}
 *       is available),</li>
 *   <li>optionally asks for a buffer distance (skipped when sampling by centroid),</li>
 *   <li>appends a {@code join(on: true)} with the dataset branch followed by a sampling
 *       {@code select} step.</li>
 * </ul>
 * Finally a {@code select({*}) as sampled} step is chained so later phases have a well-known step
 * name to refer to.
 */
public class SamplePhase extends BasePhase {
    private static final String QS_SAMPLING = "sample";
    private static final String SAMPLE_TYPE_QID = "-by";

    /**
     * Attribute holding the raw ALL_INTERSECTIONS sample result for the hazard layer (a list whose
     * elements expose a {@code sampled} member).
     */
    public static final String SAMPLED_HAZARD = InputDataPhase.HAZARD_ATTRIBUTE + "_sampled";

    /**
     * Attribute holding, per intersecting hazard piece, {@code measure(piece) / measure(exposure)}.
     * Only added for un-buffered ALL_INTERSECTIONS sampling of the hazard layer.
     */
    public static final String EXPOSED_RATIO = "exposed_ratio";

    /** Name of the final step appended by this phase. */
    public static final String WELL_KNOWN_STEP = "sampled";

    /**
     * Looks up question text using a generic key shared by all datasets.
     *
     * <p>Question names are built as {@code <dataset>-<generic>} (e.g. {@code hazard-by},
     * {@code hazard-buffer}). The key used is {@code question.sample.<generic>.<suffix>}, and the
     * dataset name is passed as message argument {0}.
     */
    private static final Question.I18nLookupFunction I18N_LOOKUP = (q, suffix, locale) -> {
        String questionName = q.getName();
        // Split on the LAST '-': the generic part ("by", "buffer") never contains '-', whereas a
        // dataset name might. Plain substring also avoids treating the dataset name as a regex.
        int separator = questionName.lastIndexOf('-');
        String datasetName = separator < 0 ? questionName : questionName.substring(0, separator);
        String genericName = separator < 0 ? questionName : questionName.substring(separator + 1);
        String i18nKey = "question.sample." + genericName + "." + suffix;
        String i18nMsg = q.getMessageSource().getMessage(i18nKey, new Object[]{datasetName}, locale);
        return Optional.ofNullable(i18nMsg);
    };

    /** The sole option offered when the exposure has no geometry member. */
    private static final List<SampleType> ONLY_CENTROID = Arrays.asList(SampleType.CENTROID);

    public SamplePhase(Survey survey) {
        super(survey);
    }

    @Override
    protected QuestionSet buildQuestionSet(IncrementalBuildState buildState) {
        DefaultQuestionSet2 questions = new DefaultQuestionSet2(QS_SAMPLING, this);
        questions.setDefaultLocation(InputDataPhase.MAIN_BRANCH);

        List<RealizedStep> startSteps = this.applicableStepStream(buildState).collect(Collectors.toList());
        this.sortSamplingInputs(buildState, startSteps);

        Class<?> exposureGeometryType = this.findExposureGeometryType(buildState);

        for (RealizedStep step : startSteps) {
            String datasetName = InputDataPhase.getDatasetName(step);
            String coverage = InputDataPhase.getAttributeName(datasetName);

            if (InputDataPhase.isHazardDataset(datasetName)) {
                List<SampleType> sampleTypeOptions = this.getSampleTypeOptions(exposureGeometryType, datasetName);
                questions.addQuestion(datasetName + SAMPLE_TYPE_QID, SampleType.class)
                    // only worth asking when the user actually has a choice to make
                    .customizeQuestion(q -> q.askWhen(ibs -> sampleTypeOptions.size() > 1))
                    .customizeQuestion(q -> q.withI18nLookup(I18N_LOOKUP))
                    .customizeQuestion(q -> q.withChoices(ibs -> Choices.forEnums(sampleTypeOptions)))
                    .thenNoChange();
            }

            // Optional buffer distance; a buffer is meaningless for centroid sampling.
            questions.addQuestion(datasetName + "-buffer", Expression.class)
                .customizeQuestion(q -> q.withI18nLookup(I18N_LOOKUP).optionalOne())
                .customizeQuestion(q -> q.askWhen(
                    ibs -> this.getSampleType((IncrementalBuildState) ibs, datasetName) != SampleType.CENTROID))
                .then((input, distance) -> {
                    SampleType type = this.getSampleType(input.getBuildState(), datasetName);
                    return this.joinBranchWithSample(datasetName,
                        this.getSampleStep(coverage, type, (Expression) distance));
                });

            // Fallback that appends the un-buffered sampling step when no buffer was given.
            questions.addHiddenQuestion(datasetName + "-finish")
                .ifNotAnswered(datasetName + "-buffer")
                .then(input -> {
                    SampleType type = this.getSampleType(input.getBuildState(), datasetName);
                    return this.joinBranchWithSample(datasetName, this.getSampleStep(coverage, type));
                });
        }

        questions.addHiddenQuestion("append-name")
            .then(input -> PipelineChange.chainStep("select({*}) as %s", WELL_KNOWN_STEP));
        return questions;
    }

    /**
     * Returns the Java type of the exposure's geometry member on the main branch, or {@code null}
     * when the exposure struct is absent or has no geometry member.
     */
    private Class<?> findExposureGeometryType(IncrementalBuildState buildState) {
        Struct exposureType = buildState.getInputStruct(InputDataPhase.MAIN_BRANCH)
            .getMember(InputDataPhase.EXPOSURE_ATTRIBUTE)
            .flatMap(sm -> sm.getType().findAllowNull(Struct.class))
            .orElse(Struct.EMPTY_STRUCT);
        return Optional.ofNullable(TupleUtils.findGeometryMember(exposureType, TupleUtils.FindOption.OPTIONAL))
            .map(member -> member.getType().internalType())
            .orElse(null);
    }

    /**
     * Lists the sampling options available for a dataset.
     *
     * @param geometryType exposure geometry type, or {@code null} if the exposure has no geometry
     * @param datasetName  dataset being sampled
     * @return {@link #ONLY_CENTROID} when there is no geometry; CENTROID/CLOSEST for point exposures
     *         or non-hazard datasets; otherwise CENTROID/CLOSEST/ALL_INTERSECTIONS
     */
    private List<SampleType> getSampleTypeOptions(Class<?> geometryType, String datasetName) {
        if (geometryType == null) {
            return ONLY_CENTROID;
        }
        if (geometryType.equals(Point.class) || !InputDataPhase.isHazardDataset(datasetName)) {
            return Arrays.asList(SampleType.CENTROID, SampleType.CLOSEST);
        }
        return Arrays.asList(SampleType.CENTROID, SampleType.CLOSEST, SampleType.ALL_INTERSECTIONS);
    }

    /**
     * Returns the user's sample-type answer for the dataset, defaulting to CLOSEST when unanswered.
     *
     * <p>NOTE(review): the question is skipped when only CENTROID is offered, in which case this
     * falls back to CLOSEST rather than CENTROID. Confirm whether the question framework supplies a
     * default answer in that case.
     */
    private SampleType getSampleType(IncrementalBuildState buildState, String datasetName) {
        return buildState.getResponse(QS_SAMPLING, datasetName + SAMPLE_TYPE_QID, SampleType.class)
            .orElse(SampleType.CLOSEST);
    }

    /**
     * Creates a change that appends {@code join(on: true).lhs as exposures_join_<dataset> -> <sampleStep>}
     * and joins the dataset's branch end into the join's rhs.
     */
    private IncompletePipelineChange joinBranchWithSample(String datasetName, String sampleStep) {
        String joinStepName = String.format("exposures_join_%s", datasetName);
        return answer -> AppendJoinChange.builder(answer)
            .append("join(on: true).lhs as %s -> %s", joinStepName, sampleStep)
            .joins(EndOfBranchWith.stepNamed(datasetName), joinStepName + ".rhs", new Object[0])
            .build();
    }

    /**
     * Builds the (un-buffered) sampling pipeline snippet for a coverage attribute.
     *
     * <p>For ALL_INTERSECTIONS the raw result is kept (in {@link #SAMPLED_HAZARD} for the hazard
     * layer, otherwise overwritten in place), the coverage attribute becomes the list of
     * {@code sampled} values, and for the hazard layer {@link #EXPOSED_RATIO} is also added.
     */
    private String getSampleStep(String coverageAttribute, SampleType sampleType) {
        String sampleFunctionCall = sampleType.getFunctionCall(InputDataPhase.EXPOSURE_ATTRIBUTE, coverageAttribute);
        if (sampleType == SampleType.ALL_INTERSECTIONS) {
            boolean isHazardLayer = InputDataPhase.HAZARD_ATTRIBUTE.equals(coverageAttribute);
            String sampledGeom = isHazardLayer ? SAMPLED_HAZARD : coverageAttribute;
            String exposedRatioStep = "";
            if (isHazardLayer) {
                // use the shared constant rather than a hard-coded attribute name, consistent with
                // the sample function call above
                exposedRatioStep = String.format(
                    "-> select({*, map(%s, geom -> measure(geom) / measure(%s)) as %s})",
                    SAMPLED_HAZARD, InputDataPhase.EXPOSURE_ATTRIBUTE, EXPOSED_RATIO);
            }
            return String.format(
                "select({*, %s as %s}) as \"sample_%s_layer\" -> select({*, map(%s, h -> h.sampled) as %s}) %s",
                sampleFunctionCall, sampledGeom, coverageAttribute, sampledGeom, coverageAttribute, exposedRatioStep);
        }
        return String.format("select({*, %s as %s}) as \"sample_%s_layer\"",
            sampleFunctionCall, coverageAttribute, coverageAttribute);
    }

    /**
     * Builds the buffered sampling pipeline snippet for a coverage attribute.
     *
     * <p>For ALL_INTERSECTIONS the exposure geometry is buffered before sampling; other sample types
     * receive the buffer distance as an extra function-call argument.
     *
     * <p>NOTE(review): unlike the un-buffered variant this neither names the step nor adds
     * {@link #EXPOSED_RATIO}; confirm this is intentional.
     */
    private String getSampleStep(String coverageAttribute, SampleType sampleType, Expression bufferDistance) {
        if (sampleType == SampleType.ALL_INTERSECTIONS) {
            String bufferedGeom = String.format("buffer(%s, %s)", InputDataPhase.EXPOSURE_ATTRIBUTE, bufferDistance.toSource());
            boolean isHazardLayer = InputDataPhase.HAZARD_ATTRIBUTE.equals(coverageAttribute);
            String sampledGeom = isHazardLayer ? SAMPLED_HAZARD : coverageAttribute;
            return String.format("select({*, %s as %s}) -> select({*, map(%s, h -> h.sampled) as %s})",
                sampleType.getFunctionCall(bufferedGeom, coverageAttribute), sampledGeom, sampledGeom, coverageAttribute);
        }
        return String.format("select({*, %s as %s})",
            sampleType.getFunctionCall(InputDataPhase.EXPOSURE_ATTRIBUTE, coverageAttribute, bufferDistance.toSource()),
            coverageAttribute);
    }

    /**
     * Sorts sampling inputs in place: hazard datasets first, then by their chain index in the
     * pipeline declaration.
     *
     * @throws IllegalStateException if a step has no definition in the pipeline AST
     */
    private void sortSamplingInputs(IncrementalBuildState buildState, List<RealizedStep> startSteps) {
        PipelineDeclaration decl = buildState.getAst();
        // Boolean natural order puts false first, so negate "is hazard" to put hazards first.
        startSteps.sort(
            Comparator.comparing((RealizedStep step) -> !InputDataPhase.isHazardDataset(InputDataPhase.getDatasetName(step)))
                .thenComparingInt(step -> chainIndex(decl, step)));
    }

    /** Returns the chain index of the step's definition, failing loudly if it is missing. */
    private static int chainIndex(PipelineDeclaration decl, RealizedStep step) {
        String stepName = step.getStepName();
        return ((PipelineDeclaration.Found) decl.findDefinition(stepName)
            .orElseThrow(() -> new IllegalStateException(
                "No definition found in pipeline for sampling input step '" + stepName + "'")))
            .getChainIndex();
    }

    /** Streams the pipeline end steps whose names are in the list of datasets to sample. */
    private Stream<RealizedStep> applicableStepStream(IncrementalBuildState buildState) {
        List<String> toSample = InputDataPhase.getDatasetsToSample(buildState);
        return buildState.getRealizedPipeline().getEndSteps().stream()
            .filter(step -> toSample.contains(step.getStepName()));
    }
}

