/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.engine.sched;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import nz.org.riskscape.engine.join.LopsidedJoiner;
import nz.org.riskscape.engine.output.PipelineJobContext;
import nz.org.riskscape.engine.output.SinkParameters;
import nz.org.riskscape.engine.pipeline.Collector;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedPipeline;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.pipeline.Sink;
import nz.org.riskscape.engine.pipeline.SinkConstructor;
import nz.org.riskscape.engine.pipeline.TupleInput;
import nz.org.riskscape.engine.projection.AsyncProjector;
import nz.org.riskscape.engine.projection.FlatProjector;
import nz.org.riskscape.engine.projection.Projector;
import nz.org.riskscape.engine.relation.Relation;
import nz.org.riskscape.engine.restriction.Restrictor;
import nz.org.riskscape.engine.sched.AccumulatorSink;
import nz.org.riskscape.engine.sched.CombinedPageBuffer;
import nz.org.riskscape.engine.sched.LinkedListBuffer;
import nz.org.riskscape.engine.sched.LopsidedJoinAdapter;
import nz.org.riskscape.engine.sched.PageAllocator;
import nz.org.riskscape.engine.sched.SchedulerParams;
import nz.org.riskscape.engine.task.AccumulatorProcessorTask;
import nz.org.riskscape.engine.task.ChainTask;
import nz.org.riskscape.engine.task.LinkedSteps;
import nz.org.riskscape.engine.task.PageBuffer;
import nz.org.riskscape.engine.task.ReadPageBuffer;
import nz.org.riskscape.engine.task.SinkTask;
import nz.org.riskscape.engine.task.TaskSpec;
import nz.org.riskscape.engine.task.TupleInputTask;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.util.ListUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Turns a {@link RealizedPipeline} into the {@link TaskSpec}s that the scheduler runs.
 *
 * <p>This happens in two passes:
 * <ol>
 * <li>{@link #decompose(RealizedPipeline)} walks the pipeline and groups its steps into
 * {@link LinkedSteps}, merging runs of chainable steps (projectors, flat projectors and restrictors)
 * into a single group.</li>
 * <li>{@link #convertToTasks(PipelineJobContext, List)} converts each group into one or more tasks and
 * connects them together with page buffers.</li>
 * </ol>
 */
public class TaskBuilder {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(TaskBuilder.class);
    /** Readable name for the {@code parallelizable} flag passed to {@link TaskSpec}: run on one thread. */
    private static final boolean SINGLE_THREADED = false;
    /** Readable name for the {@code parallelizable} flag passed to {@link TaskSpec}: may use many threads. */
    private static final boolean MULTI_THREADED = true;
    /** Scheduler tuning: threads per task, page size, tuples per task and max tuples queued. */
    private final SchedulerParams params;

    /**
     * Collects {@code step} plus any following steps that can be chained onto it into one group.
     *
     * <p>Grouping continues while the current step is chainable, has exactly one dependency (the first
     * step is exempt from this rule) and the step just added has exactly one dependent.
     *
     * @return a non-empty list of steps in pipeline order; just {@code step} if it can't be chained
     */
    private ArrayList<RealizedStep> groupSteps(RealizedPipeline pipeline, RealizedStep step) {
        ArrayList<RealizedStep> grouped = new ArrayList<>();
        RealizedStep current = step;
        // the first step may have several dependencies; later steps must only depend on the previous one
        while (canChain(current) && (current.getDependencies().size() == 1 || grouped.isEmpty())) {
            grouped.add(current);
            Set<? extends RealizedStep> dependents = pipeline.getDependents(current);
            if (dependents.size() != 1) {
                // the end of the pipeline, or a fan-out, ends the chain
                break;
            }
            current = dependents.iterator().next();
        }
        if (grouped.isEmpty()) {
            // step can't be chained, so it forms a group on its own
            grouped.add(current);
        }
        return grouped;
    }

    /**
     * @return true if the step is a {@link Projector}, {@link FlatProjector} or {@link Restrictor} step,
     * i.e. a step that {@link #groupSteps(RealizedPipeline, RealizedStep)} may merge into a chain
     */
    private boolean canChain(RealizedStep step) {
        Class<?> stepType = step.getStepType();
        return stepType == Projector.class
            || stepType == FlatProjector.class
            || stepType == Restrictor.class;
    }

    /**
     * @return true if every one of {@code steps} can be chained, in which case the whole group is
     * converted into a single {@link ChainTask}
     */
    public boolean canChain(List<RealizedStep> steps) {
        return steps.stream().allMatch(step -> canChain(step));
    }

    /**
     * Splits the pipeline into groups of steps.
     *
     * <p>A group is only created once every dependency of its first step already belongs to an earlier
     * group, so the returned list is ordered with predecessors before the groups that consume them.
     *
     * @param pipeline the realized pipeline to split up
     * @return the groups, each linked to the groups that feed it
     */
    List<LinkedSteps> decompose(RealizedPipeline pipeline) {
        List<LinkedSteps> decomposed = new ArrayList<>();
        LinkedList<RealizedStep> toVisit = new LinkedList<>();
        Set<RealizedStep> seen = new HashSet<>();
        toVisit.addAll(pipeline.getStartSteps());

        while (!toVisit.isEmpty()) {
            RealizedStep step = toVisit.removeFirst();
            if (!seen.containsAll(step.getDependencies())) {
                // not ready yet: visit the dependencies first, then revisit this step from the back
                toVisit.addAll(0, step.getDependencies());
                toVisit.add(step);
                continue;
            }
            if (seen.contains(step)) {
                // already placed in a group (e.g. merged into a chain, or reached via another path)
                continue;
            }
            List<RealizedStep> grouped = groupSteps(pipeline, step);
            List<LinkedSteps> predecessors = findPredecessors(step, decomposed);
            decomposed.add(new LinkedSteps(predecessors, grouped));
            seen.addAll(grouped);

            RealizedStep lastStep = grouped.get(grouped.size() - 1);
            toVisit.addAll(pipeline.getDependents(lastStep));
        }
        return decomposed;
    }

    /**
     * @return the already-decomposed groups whose last step is a dependency of {@code step}, in the
     * order they were decomposed
     */
    private List<LinkedSteps> findPredecessors(RealizedStep step, List<LinkedSteps> decomposed) {
        List<LinkedSteps> collected = new ArrayList<>();
        for (LinkedSteps linkedSteps : decomposed) {
            if (step.getDependencies().contains(linkedSteps.getLastStep())) {
                collected.add(linkedSteps);
            }
        }
        return collected;
    }

    /**
     * Decomposes the context's pipeline and converts it into tasks.
     *
     * @param context the job whose pipeline is converted
     * @return the tasks to schedule
     */
    public List<TaskSpec> convertToTasks(PipelineJobContext context) {
        return convertToTasks(context, decompose(context.getPipeline()));
    }

    /**
     * Converts already-decomposed groups of steps into tasks.
     *
     * <p>{@code steps} must list each group after its predecessors (as {@link #decompose(RealizedPipeline)}
     * does), because building a group's tasks looks up the tasks already built for its predecessors.
     * Any group with no descendants that isn't itself a sink is capped with an extra sink task that
     * writes to the output container.
     *
     * @param context the job being scheduled
     * @param steps the decomposed groups, predecessors first
     * @return every task, in build order
     */
    List<TaskSpec> convertToTasks(PipelineJobContext context, List<LinkedSteps> steps) {
        TaskBuildingState state = new TaskBuildingState(context, steps);
        List<TaskSpec> tasks = new ArrayList<>(steps.size());
        for (LinkedSteps step : steps) {
            List<TaskSpec> newTasks = toTasks(state, step);
            state.built.put(step, newTasks);
            tasks.addAll(newTasks);

            boolean isTerminal = step.getDescendants().isEmpty();
            if (isTerminal && !step.containsOnly(SinkConstructor.class)) {
                tasks.add(toSinkTask(state, capWithDefaultSink(step)));
            }
        }
        if (tasks.size() > 0) {
            log.info("Scheduler adding {} tasks, {} tuples per page, {} tuples per task",
                tasks.size(), state.getPageSize(), state.numTuplesPerBuffer());
        }
        return tasks;
    }

    /**
     * Creates a synthetic "{@code <name>-capped}" group that consumes {@code step}'s output and writes it
     * to a sink created by the job's output container for the terminal step.
     */
    private LinkedSteps capWithDefaultSink(LinkedSteps step) {
        final RealizedStep terminalStep = step.getLastStep();
        RealizedStep cappedStep = RealizedStep.named(terminalStep.getStepName() + "-capped")
            .withResult((Realized) new SinkConstructor() {
                @Override
                public ResultOrProblems<Sink> newInstance(PipelineJobContext context) {
                    return context.getOutputContainer().createSinkForStep(
                        new SinkParameters(terminalStep.getStepName(), terminalStep.getProduces()));
                }
            })
            .withDependencies(new RealizedStep[] {terminalStep});
        return new LinkedSteps(Arrays.asList(step), Arrays.asList(cappedStep));
    }

    /**
     * Builds the task(s) for a single group, dispatching on the kind of steps it contains.
     *
     * <p>When several tasks are returned, the one producing the group's output buffer is last;
     * {@link TaskBuildingState#getOutputBuffer(LinkedSteps)} relies on this.
     *
     * @throws RuntimeException if the group doesn't match any supported kind of step
     */
    List<TaskSpec> toTasks(TaskBuildingState state, LinkedSteps convert) {
        if (convert.containsOnly(Relation.class)) {
            return Arrays.asList(toRelationTask(state, convert));
        }
        if (canChain(convert.getGrouped())) {
            return Arrays.asList(toChainTask(state, convert));
        }
        if (convert.containsOnly(Collector.class)) {
            return toCollectorTasks(state, convert);
        }
        if (convert.containsOnly(SinkConstructor.class)) {
            return Arrays.asList(toSinkTask(state, convert));
        }
        if (convert.containsOnly(LopsidedJoiner.class)) {
            return toJoinerTasks(state, convert);
        }
        if (convert.containsOnly(AsyncProjector.class)) {
            return toAsyncProjectorTask(state, convert);
        }
        throw new RuntimeException("Could not convert steps into tasks: " + String.valueOf(convert));
    }

    /**
     * Splits an {@link AsyncProjector} into two tasks: a sink task that feeds the group's input into
     * the projector's output {@link SinkConstructor}, and a {@link TupleInputTask} that reads from the
     * projector's {@link TupleInput} into a new output buffer.
     *
     * @return the sink task followed by the tuple input task (which owns the output buffer)
     */
    private List<TaskSpec> toAsyncProjectorTask(TaskBuildingState state, LinkedSteps steps) {
        RealizedStep projectorStep = steps.getFirstStep();
        AsyncProjector asyncProjector = (AsyncProjector) projectorStep.getRealized(AsyncProjector.class).get();

        SinkConstructor projectorSink = asyncProjector.getOutput();
        ReadPageBuffer inputBuffer = state.getInputFor(steps);
        TaskSpec sinkTask = new TaskSpec(SinkTask.class, List.of(projectorStep.withResult((Realized) projectorSink)),
            inputBuffer, null, SINGLE_THREADED, state.context);

        TupleInput projectorOutput = asyncProjector.getInput();
        TaskSpec inputTask = new TaskSpec(TupleInputTask.class,
            List.of(projectorStep.withResult((Realized) projectorOutput)),
            null, newOutputBuffer(state), SINGLE_THREADED, state.context);

        return List.of(sinkTask, inputTask);
    }

    /**
     * Creates a fresh output buffer sized from the scheduler settings.
     *
     * <p>The buffer's {@link PageAllocator} starts at 1, is configured with
     * {@code doubleEvery(maxThreadsPerTask)} and is capped at {@link TaskBuildingState#getPageSize()};
     * see {@code PageAllocator} for the exact growth rules. The buffer is given
     * {@link TaskBuildingState#numTuplesPerBuffer()} as its tuple limit.
     */
    private LinkedListBuffer newOutputBuffer(TaskBuildingState state) {
        PageAllocator allocator = new PageAllocator(1)
            .doubleEvery(params.getMaxThreadsPerTask())
            .withMaxPageSize(state.getPageSize());
        return new LinkedListBuffer(allocator, state.numTuplesPerBuffer());
    }

    /**
     * Splits a {@link LopsidedJoiner} into two tasks.
     *
     * <p>The right-hand side (input 1) is fed into a sink task built from the adapter's sink
     * constructor. The left-hand side (input 0) goes through a multi-threaded chain task that applies
     * the adapter's flat projector. The chain task depends on the sink task, so it runs after it.
     *
     * @return the sink task followed by the chain task (which owns the output buffer)
     */
    private List<TaskSpec> toJoinerTasks(TaskBuildingState state, LinkedSteps steps) {
        LopsidedJoiner joiner = (LopsidedJoiner) steps.getFirstStep().getResult().get();
        LopsidedJoinAdapter adapter = new LopsidedJoinAdapter(joiner);

        SinkConstructor rhsSink = adapter.newSinkConstructor();
        ReadPageBuffer rhsInputBuffer = state.findInputForSteps(steps, 1);
        RealizedStep sinkStep = steps.getFirstStep().withResult((Realized) rhsSink);
        TaskSpec sinkTask = new TaskSpec(SinkTask.class, List.of(sinkStep), rhsInputBuffer, null,
            SINGLE_THREADED, state.context);

        FlatProjector lhsProjector = adapter.newFlatProjector();
        ReadPageBuffer lhsInputBuffer = state.findInputForSteps(steps, 0);
        LinkedListBuffer outputBuffer = newOutputBuffer(state);
        RealizedStep projectorStep = steps.getFirstStep().withResult((Realized) lhsProjector);
        TaskSpec chainTask = new TaskSpec(ChainTask.class, List.of(projectorStep), lhsInputBuffer, outputBuffer,
            MULTI_THREADED, state.context);
        chainTask.addDependency(sinkTask);

        return Arrays.asList(sinkTask, chainTask);
    }

    /** Builds a single-threaded sink task that reads from the group's first input. */
    private TaskSpec toSinkTask(TaskBuildingState state, LinkedSteps steps) {
        return new TaskSpec(SinkTask.class, steps.getGrouped(), state.getInputFor(steps), null,
            SINGLE_THREADED, state.context);
    }

    /**
     * Builds the tasks for a {@link Collector} group.
     *
     * <p>If the predecessor was built as a single task that already ends in an
     * {@link AccumulatorSink.Constructor} (see {@link #toChainTask(TaskBuildingState, LinkedSteps)}),
     * only an {@link AccumulatorProcessorTask} is needed, and it depends on that task. Otherwise an
     * accumulating sink task is added in front of the processor. That sink task is multi-threaded only
     * when the collector is {@code PARALLELIZABLE}.
     *
     * @throws IllegalStateException if the group has more than one predecessor
     */
    private List<TaskSpec> toCollectorTasks(TaskBuildingState state, LinkedSteps steps) {
        Collector<?> collector = (Collector<?>) steps.getFirstStep().getRealized(Collector.class).get();
        if (steps.getPredecessors().size() != 1) {
            throw new IllegalStateException("A collector should never be preceded by more than one step");
        }
        LinkedSteps predecessor = steps.getPredecessors().get(0);
        TaskSpec processorSpec = new TaskSpec(AccumulatorProcessorTask.class, steps.getGrouped(), null,
            newOutputBuffer(state), SINGLE_THREADED, state.context);

        List<TaskSpec> builtForPredecessor = state.built.get(predecessor);
        if (builtForPredecessor.size() == 1) {
            TaskSpec singleTask = builtForPredecessor.get(0);
            if (singleTask.getLastStep().getRealized(AccumulatorSink.Constructor.class).isPresent()) {
                // the predecessor already accumulates into this collector, so only processing is left
                processorSpec.addDependency(singleTask);
                return Arrays.asList(processorSpec);
            }
        }

        boolean parallelizable = collector.getCharacteristics().contains(Collector.Characteristic.PARALLELIZABLE);
        AccumulatorSink.Constructor sinkConstructor =
            new AccumulatorSink.Constructor(Collections.singletonList(collector));
        RealizedStep accumulatorStep = steps.getFirstStep().withResult((Realized) sinkConstructor);
        TaskSpec accumulatorSpec = new TaskSpec(SinkTask.class, Collections.singletonList(accumulatorStep),
            state.getInputFor(steps), null, parallelizable, state.context);
        processorSpec.addDependency(accumulatorSpec);
        // processor last: it owns the output buffer
        return Arrays.asList(accumulatorSpec, processorSpec);
    }

    /** Builds a single-threaded {@link TupleInputTask} that reads the relation into a new output buffer. */
    private TaskSpec toRelationTask(TaskBuildingState state, LinkedSteps steps) {
        Relation relation = (Relation) steps.getFirstStep().getRealized(Relation.class).get();
        RealizedStep asTupleInputStep = steps.getFirstStep().withResult((Realized) relation.toTupleInput());
        return new TaskSpec(TupleInputTask.class, List.of(asTupleInputStep), null, newOutputBuffer(state),
            SINGLE_THREADED, state.context);
    }

    /**
     * @return the chain's input: the single input buffer, or a {@link CombinedPageBuffer} over every
     * input (in dependency order) when the first step has several dependencies
     */
    private ReadPageBuffer getChainInput(TaskBuildingState state, LinkedSteps steps) {
        int numInputs = steps.getFirstStep().getDependencies().size();
        if (numInputs <= 1) {
            return state.getInputFor(steps);
        }
        List<ReadPageBuffer> inputs = new ArrayList<>(numInputs);
        for (int i = 0; i < numInputs; ++i) {
            inputs.add(state.findInputForSteps(steps, i));
        }
        return new CombinedPageBuffer(inputs);
    }

    /**
     * Builds a multi-threaded {@link ChainTask} for a group of chainable steps.
     *
     * <p>If every descendant is a parallelizable collector, an accumulating sink step is appended to
     * the chain so it collects directly. In that case the task gets no output buffer. Otherwise the
     * chain writes to a new output buffer.
     */
    private TaskSpec toChainTask(TaskBuildingState state, LinkedSteps steps) {
        List<Collector<?>> parallelCollectors = findFollowedOnlyByParallelizableCollectors(steps);
        List<RealizedStep> groupedSteps;
        LinkedListBuffer outputBuffer;
        if (parallelCollectors.isEmpty()) {
            groupedSteps = steps.getGrouped();
            outputBuffer = newOutputBuffer(state);
        } else {
            RealizedStep sinkStep = createParallelizedCollectorStep(steps.getLastStep(), parallelCollectors);
            groupedSteps = ListUtils.append(steps.getGrouped(), sinkStep);
            outputBuffer = null;
        }
        return new TaskSpec(ChainTask.class, groupedSteps, getChainInput(state, steps), outputBuffer,
            MULTI_THREADED, state.context);
    }

    /**
     * @return a synthetic "{@code <name>-sink}" step that depends on {@code follows} and accumulates
     * into all of the given collectors
     */
    private RealizedStep createParallelizedCollectorStep(RealizedStep follows, List<Collector<?>> collectors) {
        AccumulatorSink.Constructor constructor = new AccumulatorSink.Constructor(collectors);
        return RealizedStep.named(follows.getStepName() + "-sink")
            .withResult((Realized) constructor)
            .withDependencies(new RealizedStep[] {follows});
    }

    /**
     * @return the descendants' collectors if every descendant is a single-step group holding a
     * parallelizable {@link Collector}; otherwise (including when there are no descendants) an empty list
     */
    private List<Collector<?>> findFollowedOnlyByParallelizableCollectors(LinkedSteps steps) {
        List<LinkedSteps> descendants = steps.getDescendants();
        List<Collector<?>> collectors = descendants.stream()
            .filter(ls -> ls.getGrouped().size() == 1)
            .map(ls -> ls.getGrouped().get(0).getRealized(Collector.class))
            .filter(Optional::isPresent)
            .map(opt -> (Collector) opt.get())
            .filter(Collector::isParallelizable)
            .collect(Collectors.toList());
        // anything that wasn't a lone parallelizable collector got filtered out, shrinking the list
        return collectors.size() == descendants.size() ? collectors : Collections.emptyList();
    }

    @Generated
    public TaskBuilder(SchedulerParams params) {
        this.params = params;
    }

    /** Bookkeeping shared across a single {@code convertToTasks} run. */
    private class TaskBuildingState {
        /** Tasks built so far, keyed by the group they were built from. */
        Map<LinkedSteps, List<TaskSpec>> built = new HashMap<LinkedSteps, List<TaskSpec>>();
        /** Number of groups that are not sinks, used to share the queued-tuple budget between buffers. */
        final int estimatedBuffers;
        final PipelineJobContext context;

        TaskBuildingState(PipelineJobContext context, List<LinkedSteps> steps) {
            this.context = context;
            int numSinks = (int) steps.stream().filter(step -> step.containsOnly(SinkConstructor.class)).count();
            this.estimatedBuffers = steps.size() - numSinks;
        }

        /**
         * @return the output buffer of the last task built for {@code steps}; the group must already
         * have been built, and its last task must have an output buffer
         */
        private PageBuffer getOutputBuffer(LinkedSteps steps) {
            List<TaskSpec> builtTasks = this.built.get(steps);
            TaskSpec lastTask = builtTasks.get(builtTasks.size() - 1);
            return (PageBuffer) lastTask.getOutput().get();
        }

        /**
         * Finds the buffer to read the {@code index}th input of {@code steps} from.
         *
         * <p>When the predecessor fans out, only its first descendant reads the buffer directly. Every
         * other descendant gets {@code newReaderClone()}, presumably an independent reader over the
         * same pages.
         */
        public ReadPageBuffer findInputForSteps(LinkedSteps steps, int index) {
            LinkedSteps predecessor = steps.getOrderedPredecessor(index);
            PageBuffer buffer = getOutputBuffer(predecessor);
            boolean isFirstDescendant = predecessor.getDescendants().indexOf(steps) == 0;
            if (predecessor.hasFanOut() && !isFirstDescendant) {
                return buffer.newReaderClone();
            }
            return buffer;
        }

        /** @return the buffer for the group's first (index 0) input */
        public ReadPageBuffer getInputFor(LinkedSteps steps) {
            return findInputForSteps(steps, 0);
        }

        /**
         * @return the configured tuples-per-task if set; otherwise the max tuples queued shared evenly
         * across the estimated buffers (treated as at least one to avoid dividing by zero)
         */
        public int numTuplesPerBuffer() {
            // orElseGet: only compute the fallback when no explicit value is configured
            return TaskBuilder.this.params.getTuplesPerTask()
                .orElseGet(() -> TaskBuilder.this.params.getMaxTuplesQueued() / Math.max(1, this.estimatedBuffers));
        }

        /**
         * @return the configured page size if set; otherwise the tuples per buffer divided by twice the
         * max threads per task, but never less than 1
         */
        public int getPageSize() {
            int defaultPageSize = Math.max(1, numTuplesPerBuffer() / (TaskBuilder.this.params.getMaxThreadsPerTask() * 2));
            return TaskBuilder.this.params.getPageSize().orElse(defaultPageSize);
        }

        @Generated
        public TaskBuildingState(int estimatedBuffers, PipelineJobContext context) {
            this.estimatedBuffers = estimatedBuffers;
            this.context = context;
        }
    }
}

