/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.engine.steps;

import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import lombok.Generated;
import nz.org.riskscape.dsl.Token;
import nz.org.riskscape.engine.Engine;
import nz.org.riskscape.engine.IdentifiedException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.bind.ParameterField;
import nz.org.riskscape.engine.function.IdentifiedFunction;
import nz.org.riskscape.engine.pipeline.Collector;
import nz.org.riskscape.engine.pipeline.RealizationInput;
import nz.org.riskscape.engine.pipeline.Realized;
import nz.org.riskscape.engine.pipeline.RealizedStep;
import nz.org.riskscape.engine.relation.TupleIterator;
import nz.org.riskscape.engine.rl.ExpressionRealizer;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.rl.RealizedExpression;
import nz.org.riskscape.engine.rl.agg.Accumulator;
import nz.org.riskscape.engine.rl.agg.AggregateExpressionRealizer;
import nz.org.riskscape.engine.rl.agg.RealizedAggregateExpression;
import nz.org.riskscape.engine.steps.BaseStep;
import nz.org.riskscape.engine.steps.Input;
import nz.org.riskscape.engine.types.DuplicateKeysException;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.TypeProblems;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.problem.Problem;
import nz.org.riskscape.problem.ProblemException;
import nz.org.riskscape.problem.Problems;
import nz.org.riskscape.problem.ResultOrProblems;
import nz.org.riskscape.rl.ExpressionParser;
import nz.org.riskscape.rl.ast.BinaryOperation;
import nz.org.riskscape.rl.ast.BracketedExpression;
import nz.org.riskscape.rl.ast.Expression;
import nz.org.riskscape.rl.ast.ExpressionConverter;
import nz.org.riskscape.rl.ast.ExpressionProblems;
import nz.org.riskscape.rl.ast.ExpressionVisitor;
import nz.org.riskscape.rl.ast.FunctionCall;
import nz.org.riskscape.rl.ast.Lambda;
import nz.org.riskscape.rl.ast.ListDeclaration;
import nz.org.riskscape.rl.ast.MinimalVisitor;
import nz.org.riskscape.rl.ast.PropertyAccess;
import nz.org.riskscape.rl.ast.SelectAllExpression;
import nz.org.riskscape.rl.ast.StructDeclaration;

public class GroupByStep
extends BaseStep<Params> {
    // Struct declaration with no members; realize() falls back to it when Params.by is absent.
    private static final StructDeclaration GROUP_BY_NOTHING = new StructDeclaration(Collections.emptyList(), Optional.empty());

    /**
     * Returns the identifier of this step.
     *
     * @return the constant string {@code "group"}
     */
    @Override
    public String getId() {
        return "group";
    }

    /**
     * Creates the step, handing the engine straight to the {@link BaseStep} constructor.
     *
     * @param engine the engine passed on to the superclass
     */
    public GroupByStep(Engine engine) {
        super(engine);
    }

    /**
     * Builds the collector that implements this step for the given parameters.
     *
     * <p>The work happens in the following order:
     * <ol>
     *   <li>The optional {@code by} expression is converted to a struct declaration (an empty one
     *       when absent) and realized against the input type.</li>
     *   <li>Sub-expressions of {@code select} that equal a group-by member are replaced with a
     *       property access on that member's key.</li>
     *   <li>Calls to aggregation functions are pulled out of {@code select}, given a name, and
     *       replaced with a property access on that name; the extracted calls are realized
     *       together as one aggregate struct expression.</li>
     *   <li>The group-by result type and the aggregate result type are joined into the type the
     *       rewritten {@code select} is finally realized against.</li>
     *   <li>Remaining property accesses that only resolve against the original input type are
     *       reported as being outside of an aggregation function.</li>
     * </ol>
     *
     * @param parameters the bound step parameters
     * @return the realized collector, or the problems that were thrown as a
     *         {@link ProblemException} while building it
     */
    @Override
    public ResultOrProblems<? extends Realized> realize(Params parameters) {
        return ProblemException.catching(() -> {
            Struct inputType = parameters.input.getProduces();
            RealizationContext context = parameters.rInput.getExecutionContext().getRealizationContext();
            ExpressionRealizer realizer = context.getExpressionRealizer();

            // No "by" expression means every tuple lands in the same (empty) group key.
            StructDeclaration groupBy = parameters.by.map(byExpression -> ExpressionParser.INSTANCE.toStruct(byExpression)).orElse(GROUP_BY_NOTHING);
            RealizedExpression rGroupBy = (RealizedExpression)realizer.realize((Type)inputType, (Expression)groupBy).getOrThrow(ExpressionProblems.get().failedToRealize((Expression)groupBy, (Type)inputType));
            Struct groupType = (Struct)rGroupBy.getResultType().find(Struct.class).get();

            // Insertion order is kept so the aggregate struct members follow the order in which
            // the aggregate calls were encountered.
            Map<FunctionCall, String> aggregations = new LinkedHashMap<>();
            List<Problem> findAndReplaceErrors = new ArrayList<>();
            Expression withGroupsReplaced = (Expression)parameters.select.accept((ExpressionVisitor)new FindAndReplaceExpressions(rGroupBy), null);
            Expression rewritten = (Expression)withGroupsReplaced.accept((ExpressionVisitor)new FindAndReplaceAggregations(context, aggregations, findAndReplaceErrors), null);
            if (!findAndReplaceErrors.isEmpty()) {
                throw new ProblemException((Problems)ExpressionProblems.get().failedToRealize(parameters.select, (Type)inputType).withChildren(findAndReplaceErrors));
            }

            // One struct member per extracted aggregate call, keyed by the name chosen for it.
            StructDeclaration toAggregate = new StructDeclaration(aggregations.entrySet().stream().map(entry -> StructDeclaration.jsonStyleMember(entry.getValue(), (Expression)entry.getKey())).collect(Collectors.toList()), Optional.empty());
            RealizedAggregateExpression rAggExpression = (RealizedAggregateExpression)realizer.realizeAggregate((Type)inputType, (Expression)toAggregate).getOrThrow();
            Struct aggregateResult = (Struct)rAggExpression.getResultType().find(Struct.class).get();

            Struct processInputType;
            try {
                processInputType = context.normalizeStruct(groupType.and(aggregateResult));
            }
            catch (DuplicateKeysException ex) {
                throw new ProblemException((Problems)TypeProblems.get().duplicateKeys(ex.getDuplicates()));
            }

            // Typed list: the visitor appends PropertyAccess nodes, and the stream below needs the
            // element type to build the child problems.
            List<PropertyAccess> ungrouped = new ArrayList<>();
            rewritten.accept((ExpressionVisitor)new CheckForUngroupedPropertyAccess(inputType, processInputType), ungrouped);
            rewritten = (Expression)rewritten.accept((ExpressionVisitor)new WildcardReplacer(groupType.getMembers()), null);
            if (!ungrouped.isEmpty()) {
                throw new ProblemException((Problems)ExpressionProblems.get().failedToRealize(parameters.select, (Type)inputType).withChildren(ungrouped.stream().map(access -> ExpressionProblems.get().propertyOutsideOfAggregationFunction(access)).collect(Collectors.toList())));
            }

            RealizedExpression processExpression = realizer.asStruct(context, (RealizedExpression)realizer.realize((Type)processInputType, rewritten).getOrThrow());
            return new Instance(inputType, processExpression, () -> new AccumInstance(rGroupBy, rAggExpression));
        });
    }

    /**
     * Parameters consumed by {@link GroupByStep#realize(Params)}.
     */
    public static class Params {
        // Expression describing the output; realize() rewrites it before realizing it against
        // the combined group-by + aggregate struct type.
        @ParameterField
        public Expression select;
        // Optional grouping expression; when absent realize() groups by an empty struct declaration.
        @ParameterField
        public Optional<Expression> by;
        // Upstream step; its produced struct type is what the group-by and aggregate expressions
        // are realized against.
        @Input
        public RealizedStep input;
        // Source of the execution context (and from it the realization context) used in realize().
        // NOTE(review): carries no annotation here — presumably populated by the framework; confirm.
        public RealizationInput rInput;
    }

    /**
     * Expression converter that swaps any sub-expression equal to one of the group-by member
     * expressions for a property access on the key that member has in the group-by result struct.
     * Expressions that do not match are handed to the superclass conversion unchanged.
     */
    private static class FindAndReplaceExpressions extends ExpressionConverter<StructDeclaration> {
        /** Pairs of (group-by member expression, key of that member in the group-by result). */
        private final List<Pair<Expression, String>> references;

        /**
         * Collects the replaceable group-by member expressions.
         *
         * @param groupBy the realized group-by expression; its dependencies are walked alongside
         *                the members of its struct result type
         */
        FindAndReplaceExpressions(RealizedExpression groupBy) {
            Iterator<Struct.StructMember> members = ((Struct)groupBy.getResultType().find(Struct.class).get()).getMembers().iterator();
            this.references = new ArrayList<>();
            for (RealizedExpression childExpression : groupBy.getDependencies()) {
                Expression child = childExpression.getExpression();
                boolean selectsAll = child.isA(SelectAllExpression.class).isPresent()
                    || child.isA(PropertyAccess.class).map(pa -> pa.isTrailingSelectAll()).orElse(false);
                if (selectsAll) {
                    // A wildcard is not itself replaceable; step the member iterator forward once
                    // per member of the wildcard's struct type so later keys stay aligned.
                    ((Struct)childExpression.getResultType().find(Struct.class).get()).getMembers().forEach(m -> members.next());
                    continue;
                }
                this.references.add(Pair.of(child, members.next().getKey()));
            }
        }

        @Override
        public Expression visit(BinaryOperation expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(BracketedExpression expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(FunctionCall expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(Lambda expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(ListDeclaration expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(PropertyAccess expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        @Override
        public Expression visit(StructDeclaration expression, StructDeclaration data) {
            return this.replaceOrDescend(expression, () -> super.visit(expression, data));
        }

        /**
         * Returns the group-by replacement for {@code expression} when there is one, otherwise
         * the result of {@code descend}, which is only invoked in that second case.
         */
        private Expression replaceOrDescend(Expression expression, Supplier<Expression> descend) {
            PropertyAccess found = this.findInGroup(expression);
            if (found != null) {
                return found;
            }
            return descend.get();
        }

        /**
         * Looks for a group-by member expression equal to {@code expression}.
         *
         * @return a property access on the first matching member's key, or {@code null} if none
         *         of the collected expressions is equal to the argument
         */
        private PropertyAccess findInGroup(Expression expression) {
            for (Pair<Expression, String> pair : this.references) {
                if (expression.equals(pair.getLeft())) {
                    return PropertyAccess.of(new String[]{pair.getRight()});
                }
            }
            return null;
        }
    }

    /**
     * Expression converter that pulls calls to aggregation functions out of an expression.
     * Each such call is recorded in {@code collected} under a name and replaced in the converted
     * expression with a property access on that name. The visitor data is the struct declaration
     * currently being converted, if any, so a member's declared identifier can supply the name.
     */
    private static class FindAndReplaceAggregations extends ExpressionConverter<StructDeclaration> {
        private final RealizationContext context;
        private final Map<FunctionCall, String> collected;
        private final List<Problem> errors;

        @Generated
        public FindAndReplaceAggregations(RealizationContext context, Map<FunctionCall, String> collected, List<Problem> errors) {
            this.context = context;
            this.collected = collected;
            this.errors = errors;
        }

        /**
         * Replaces an aggregation function call with a property access on its recorded name.
         * Calls to other functions go through the superclass conversion. A failed function
         * lookup is added to {@code errors} and the call is returned untouched.
         */
        public Expression visit(FunctionCall expression, StructDeclaration parent) {
            IdentifiedFunction function;
            try {
                function = (IdentifiedFunction)this.context.getProject().getFunctionSet().get(expression.getIdentifier().getValue(), this.context.getProblemSink());
            }
            catch (IdentifiedException ex) {
                this.errors.add(Problems.caught((Throwable)ex));
                return expression;
            }

            if (!function.getAggregationFunction().isPresent()) {
                return super.visit(expression, parent);
            }

            String name = this.collected.get(expression);
            if (name == null) {
                name = AggregateExpressionRealizer.getImplicitName(expression, this.collected.values());
                if (parent != null) {
                    // Identity comparison: looks for the member holding this exact call node.
                    Optional<String> declaredName = parent.getMembers().stream()
                        .filter(member -> member.getExpression() == expression)
                        .findFirst()
                        .flatMap(member -> Optional.ofNullable(member.getIdentifier()).map(Token::getValue));
                    name = declaredName.orElse(name);
                }
                name = ExpressionRealizer.makeUnique((String)name, this.collected.values());
            }
            this.collected.put(expression, name);
            return PropertyAccess.of((String[])new String[]{name});
        }

        /**
         * Converts a struct declaration, passing the declaration itself down as the visitor data
         * for its members.
         */
        public Expression visit(StructDeclaration expression, StructDeclaration data) {
            StructDeclaration parentForMembers = expression;
            return super.visit(expression, parentForMembers);
        }
    }

    /**
     * Visitor that gathers property accesses whose first identifier is a member of the original
     * input type but not of the type the rewritten select expression is realized against.
     * Matches are appended to the list passed around as visitor data.
     */
    private static class CheckForUngroupedPropertyAccess extends MinimalVisitor<List<PropertyAccess>> {
        private final Struct inputType;
        private final Struct processType;

        @Generated
        public CheckForUngroupedPropertyAccess(Struct inputType, Struct processType) {
            this.inputType = inputType;
            this.processType = processType;
        }

        /**
         * Records the access when it has no receiver and its first identifier resolves only
         * against the input type, then continues with the superclass traversal.
         */
        public List<PropertyAccess> visit(PropertyAccess expression, List<PropertyAccess> data) {
            if (!expression.getReceiver().isPresent()) {
                String rootName = expression.getFirstIdentifier().getValue();
                boolean knownAfterGrouping = this.processType.hasMember(rootName);
                if (!knownAfterGrouping && this.inputType.hasMember(rootName)) {
                    data.add(expression);
                }
            }
            return (List)super.visit(expression, data);
        }

        /**
         * Lambdas are not descended into; the collected list is returned as-is.
         */
        public List<PropertyAccess> visit(Lambda expression, List<PropertyAccess> data) {
            return data;
        }
    }

    /**
     * Expression converter that rebuilds a struct declaration, replacing each select-all member
     * with one explicit member per entry of {@link #members}.
     */
    private static class WildcardReplacer extends ExpressionConverter<Void> {
        final List<Struct.StructMember> members;

        @Generated
        public WildcardReplacer(List<Struct.StructMember> members) {
            this.members = members;
        }

        /**
         * Returns a new struct declaration with the same members in the same order, except that
         * every select-all member is expanded to {@code key: key} entries for the configured
         * struct members. Member expressions are copied across without being converted.
         */
        public Expression visit(StructDeclaration expression, Void data) {
            ArrayList<StructDeclaration.Member> expanded = new ArrayList<StructDeclaration.Member>();
            for (StructDeclaration.Member declared : expression.getMembers()) {
                if (!declared.isSelectAll()) {
                    expanded.add(StructDeclaration.jsonStyleMember((Token)declared.getIdentifier(), (Expression)declared.getExpression()));
                } else {
                    for (Struct.StructMember groupMember : this.members) {
                        String key = groupMember.getKey();
                        Expression keyAccess = PropertyAccess.of((String[])new String[]{key});
                        expanded.add(StructDeclaration.jsonStyleMember((String)key, (Expression)keyAccess));
                    }
                }
            }
            return new StructDeclaration(expanded, Optional.empty());
        }
    }

    /**
     * Collector produced by {@link GroupByStep#realize(Params)}. Accumulation and combination are
     * delegated to {@link AccumInstance}; processing turns every accumulated group into one
     * output tuple by evaluating {@code processExpression}.
     */
    private static class Instance implements Collector<AccumInstance> {
        private final Class<AccumInstance> accumulatorClass = AccumInstance.class;
        private final Struct sourceType;
        private final RealizedExpression processExpression;
        private final Set<Collector.Characteristic> characteristics = EnumSet.of(Collector.Characteristic.PARALLELIZABLE);
        private final Supplier<AccumInstance> newAccumulator;

        @Generated
        public Instance(Struct sourceType, RealizedExpression processExpression, Supplier<AccumInstance> newAccumulator) {
            this.sourceType = sourceType;
            this.processExpression = processExpression;
            this.newAccumulator = newAccumulator;
        }

        /** Returns the struct result type of the process expression. */
        public Struct getProducedType() {
            Type resultType = this.processExpression.getResultType();
            return (Struct)resultType.find(Struct.class).get();
        }

        /** Obtains a fresh accumulator from the supplier given at construction. */
        public AccumInstance newAccumulator() {
            return this.newAccumulator.get();
        }

        /** Feeds one tuple into the given accumulator. */
        public void accumulate(AccumInstance accumulator, Tuple tuple) {
            accumulator.accumulate(tuple);
        }

        /** Merges {@code rhs} into {@code lhs} and returns the result of that merge. */
        public AccumInstance combine(AccumInstance lhs, AccumInstance rhs) {
            return lhs.combine(rhs);
        }

        /** Reports the number of groups currently held by the accumulator. */
        public Optional<Long> size(AccumInstance accumulator) {
            int groupCount = accumulator.groups.size();
            return Optional.of(Long.valueOf(groupCount));
        }

        /**
         * Iterates over the accumulated groups, producing one evaluated output tuple per group.
         */
        public TupleIterator process(AccumInstance accumulator) {
            Struct processInputType = (Struct)this.processExpression.getInputType().find(Struct.class).get();
            return TupleIterator.wrappedAndMapped(
                accumulator.groups.entrySet().iterator(),
                entry -> this.emit(processInputType, (Tuple)entry.getKey(), (Accumulator)entry.getValue()));
        }

        /**
         * Builds the process-expression input for one group (group key values first, followed by
         * the processed accumulator values) and evaluates the process expression on it.
         */
        private Tuple emit(Struct processInputType, Tuple group, Accumulator groupAccumulator) {
            Tuple processedResult = (Tuple)groupAccumulator.process();
            Tuple processedInput = new Tuple(processInputType);
            processedInput.setAll(group);
            // The aggregate values start right after the last group key value.
            processedInput.setAll(group.size(), processedResult);
            return (Tuple)this.processExpression.evaluate((Object)processedInput);
        }

        @Generated
        public Class<AccumInstance> getAccumulatorClass() {
            return this.accumulatorClass;
        }

        @Generated
        public Struct getSourceType() {
            return this.sourceType;
        }

        @Generated
        public RealizedExpression getProcessExpression() {
            return this.processExpression;
        }

        @Generated
        public Set<Collector.Characteristic> getCharacteristics() {
            return this.characteristics;
        }
    }

    /**
     * Mutable accumulation state for the group step: one aggregate accumulator per distinct
     * result of the group-by expression.
     */
    private static class AccumInstance {
        private final RealizedExpression groupBy;
        private final RealizedAggregateExpression aggregateExpression;
        private Map<Tuple, Accumulator> groups = new HashMap<Tuple, Accumulator>();

        @Generated
        public AccumInstance(RealizedExpression groupBy, RealizedAggregateExpression aggregateExpression) {
            this.groupBy = groupBy;
            this.aggregateExpression = aggregateExpression;
        }

        /**
         * Evaluates the group-by expression on the tuple and feeds the tuple to that group's
         * accumulator, creating the accumulator the first time the group is seen.
         */
        public void accumulate(Tuple tuple) {
            Tuple key = (Tuple)this.groupBy.evaluate((Object)tuple);
            Accumulator groupAccumulator = this.groups.computeIfAbsent(key, unused -> this.aggregateExpression.newAccumulator());
            groupAccumulator.accumulate((Object)tuple);
        }

        /**
         * Folds the groups of {@code rhs} into this instance: groups present on both sides have
         * their accumulators combined, groups only in {@code rhs} are added as they are.
         *
         * @return this instance
         */
        public AccumInstance combine(AccumInstance rhs) {
            rhs.groups.forEach((key, theirs) ->
                this.groups.merge(key, theirs, (mine, other) -> mine.combine(other)));
            return this;
        }
    }
}

