/*
 * Decompiled with CFR 0.152.
 */
package nz.org.riskscape.engine.rl;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import lombok.Generated;
import nz.org.riskscape.ReflectionUtils;
import nz.org.riskscape.engine.RiskscapeException;
import nz.org.riskscape.engine.Tuple;
import nz.org.riskscape.engine.function.BinaryOperatorFunction;
import nz.org.riskscape.engine.function.BinaryPredicateFunction;
import nz.org.riskscape.engine.function.NullSafeFunction;
import nz.org.riskscape.engine.function.OperatorResolver;
import nz.org.riskscape.engine.function.RiskscapeFunction;
import nz.org.riskscape.engine.rl.RealizationContext;
import nz.org.riskscape.engine.types.Nullable;
import nz.org.riskscape.engine.types.Struct;
import nz.org.riskscape.engine.types.Type;
import nz.org.riskscape.engine.types.Types;
import nz.org.riskscape.engine.util.Pair;
import nz.org.riskscape.rl.TokenTypes;
import nz.org.riskscape.rl.ast.BinaryOperation;

/**
 * The default {@link OperatorResolver}, providing RiskScape's built-in binary operators.
 *
 * <p>Operators are dispatched by their normalized token type:
 * <ul>
 *   <li>maths ({@link #MATHS_OPERATORS}) via an exact argument-type match in {@link #MATHS_FUNCTIONS},
 *       falling back to the floating point implementation when both operands are non-null numerics;</li>
 *   <li>boolean logic ({@link #BOOLEAN_LOGIC_OPERATORS}) between two boolean operands;</li>
 *   <li>ordering comparisons ({@link #BOOLEAN_COMPARATORS}) between operands sharing a {@link Comparable}
 *       ancestor, or two different numeric types (compared as doubles);</li>
 *   <li>equality ({@link #EQUALITY_OPERATORS}) via {@link Objects#equals}, except numeric operands, which are
 *       compared with {@code compareTo} like the ordering comparisons.</li>
 * </ul>
 * If that finds nothing and either operand is nullable, resolution is retried with nullability stripped and the
 * result is wrapped with {@link NullSafeFunction#wrap}. Otherwise struct operands are handled member by member:
 * struct-to-struct pairs members by key, struct-to-scalar applies the operator between each member and the scalar.
 */
public class DefaultOperators implements OperatorResolver {

    /** Arithmetic operators, implemented by {@link #MATHS_FUNCTIONS}. */
    public static final EnumSet<TokenTypes> MATHS_OPERATORS = EnumSet.of(TokenTypes.PLUS, TokenTypes.MINUS, TokenTypes.MULTIPLY, TokenTypes.DIVIDE, TokenTypes.POW);
    /** Ordering comparisons, implemented with {@link Comparable#compareTo}. */
    public static final EnumSet<TokenTypes> BOOLEAN_COMPARATORS = EnumSet.of(TokenTypes.GREATER_THAN, TokenTypes.GREATER_THAN_EQUAL, TokenTypes.LESS_THAN, TokenTypes.LESS_THAN_EQUAL);
    /** Equality and inequality. */
    public static final EnumSet<TokenTypes> EQUALITY_OPERATORS = EnumSet.of(TokenTypes.EQUALS, TokenTypes.NOT_EQUALS);
    /** Boolean {@code or} / {@code and}. */
    public static final EnumSet<TokenTypes> BOOLEAN_LOGIC_OPERATORS = EnumSet.of(TokenTypes.OR, TokenTypes.AND);

    /**
     * Built-in arithmetic implementations keyed by operator. A function is only selected when its argument types
     * exactly equal the (unwrapped) operand types; note there is no integer {@code DIVIDE} or {@code POW}, so those
     * always go through the floating point fallback for integer operands.
     */
    public static final Map<TokenTypes, List<RiskscapeFunction>> MATHS_FUNCTIONS =
        ImmutableMap.<TokenTypes, List<RiskscapeFunction>>builder()
            .put(TokenTypes.PLUS, Arrays.asList(
                operatorFor(TokenTypes.PLUS, Long.class, (a, b) -> a + b),
                operatorFor(TokenTypes.PLUS, String.class, (a, b) -> a + b),
                operatorFor(TokenTypes.PLUS, Double.class, (a, b) -> a + b)))
            .put(TokenTypes.MINUS, Arrays.asList(
                operatorFor(TokenTypes.MINUS, Long.class, (a, b) -> a - b),
                operatorFor(TokenTypes.MINUS, Double.class, (a, b) -> a - b)))
            .put(TokenTypes.MULTIPLY, Arrays.asList(
                operatorFor(TokenTypes.MULTIPLY, Long.class, (a, b) -> a * b),
                operatorFor(TokenTypes.MULTIPLY, Double.class, (a, b) -> a * b)))
            .put(TokenTypes.DIVIDE, Arrays.asList(
                operatorFor(TokenTypes.DIVIDE, Double.class, (a, b) -> a / b)))
            .put(TokenTypes.POW, Arrays.asList(
                operatorFor(TokenTypes.POW, Double.class, (a, b) -> Math.pow(a, b))))
            .build();

    /** Shared instance; the resolver holds no state of its own. */
    public static final DefaultOperators INSTANCE = new DefaultOperators();

    /**
     * Builds a boolean-returning function that applies {@code predicate} to two arguments of {@code javaType}.
     *
     * @param operator  the operator being implemented (not used by the returned function)
     * @param javaType  java class of both arguments; the argument types are {@code Types.fromJavaType(javaType)}
     * @param predicate the test to apply to the two (cast) arguments
     * @return a function of {@code (type, type) -> Types.BOOLEAN}
     */
    public <T> RiskscapeFunction predicateFor(TokenTypes operator, final Class<T> javaType, final BiPredicate<T, T> predicate) {
        Type type = Types.fromJavaType(javaType);
        final List<Type> argumentTypes = Arrays.asList(type, type);
        return new RiskscapeFunction() {

            @Override
            public Object call(List<Object> args) {
                return predicate.test(javaType.cast(args.get(0)), javaType.cast(args.get(1)));
            }

            @Override
            public List<Type> getArgumentTypes() {
                return argumentTypes;
            }

            @Override
            public Type getReturnType() {
                return Types.BOOLEAN;
            }
        };
    }

    /**
     * Builds the boolean {@code or} function.
     *
     * <p>When {@code nullableInputs} is true, a null operand is ignored: the result is the other operand (null when
     * both are null). When false, both operands are always unboxed, so a null argument fails.
     *
     * @param nullableInputs whether the argument and return types are nullable booleans
     * @return the {@code or} function
     */
    public RiskscapeFunction orFunction(final boolean nullableInputs) {
        final Type type = Nullable.ifTrue(nullableInputs, Types.BOOLEAN);
        final List<Type> argumentTypes = Arrays.asList(type, type);
        return new RiskscapeFunction() {

            @Override
            public Object call(List<Object> args) {
                Object lhs = args.get(0);
                Object rhs = args.get(1);
                if (!nullableInputs || (lhs != null && rhs != null)) {
                    // logicalOr (not ||) so both operands are always evaluated
                    return Boolean.logicalOr(Boolean.class.cast(lhs), Boolean.class.cast(rhs));
                }
                return lhs != null ? lhs : rhs;
            }

            @Override
            public List<Type> getArgumentTypes() {
                return argumentTypes;
            }

            @Override
            public Type getReturnType() {
                return type;
            }
        };
    }

    /**
     * Builds a {@link BinaryOperatorFunction} applying {@code function} to two {@code javaType} operands.
     */
    public static <T> RiskscapeFunction operatorFor(TokenTypes operationToken, Class<T> javaType, BinaryOperator<T> function) {
        return new BinaryOperatorFunction(operationToken, function, javaType);
    }

    /**
     * Resolves {@code operation} for the given operand types.
     *
     * @return the implementing function, or empty if the operator is not supported for these types
     */
    @Override
    public Optional<RiskscapeFunction> resolve(RealizationContext context, BinaryOperation operation, Type inputType, Type lhs, Type rhs) {
        TokenTypes tt = operation.getNormalizedOperator();
        Optional<RiskscapeFunction> result = Optional.empty();
        if (MATHS_OPERATORS.contains(tt)) {
            result = resolveMathsOperator(operation, inputType, lhs, rhs);
        } else if (BOOLEAN_LOGIC_OPERATORS.contains(tt)) {
            result = resolveLogicOperator(operation, inputType, lhs, rhs);
        } else if (BOOLEAN_COMPARATORS.contains(tt)) {
            result = resolveComparators(operation, inputType, lhs, rhs);
        } else if (EQUALITY_OPERATORS.contains(tt)) {
            result = resolveEqualityOperator(operation, inputType, lhs, rhs);
        }
        if (result.isPresent()) {
            return result;
        }

        // retry against the non-nullable types, letting NullSafeFunction deal with null arguments
        if (Nullable.any(new Type[] {lhs, rhs})) {
            return resolve(context, operation, inputType, Nullable.strip(lhs), Nullable.strip(rhs))
                .map(function -> NullSafeFunction.wrap(function));
        }

        Struct lhsStruct = lhs.find(Struct.class).orElse(null);
        Struct rhsStruct = rhs.find(Struct.class).orElse(null);
        if (lhsStruct != null && rhsStruct != null) {
            return buildStructCompositeFunction(context, operation, inputType, lhsStruct, rhsStruct)
                .map(pairs -> new BinaryStructCompositeFunction(context, lhs, rhs, pairs));
        }
        if (lhsStruct != null || rhsStruct != null) {
            return buildOneSidedStructCompositeFunction(context, operation, inputType, lhs, rhs, lhsStruct != null);
        }
        return result;
    }

    /**
     * Wraps {@code function} with {@link NullSafeFunction#wrapIgnoringArgs} when either operand type is nullable,
     * otherwise returns it unchanged.
     */
    private static RiskscapeFunction nullSafeIfNullable(RiskscapeFunction function, Type lhs, Type rhs) {
        return Nullable.any(new Type[] {lhs, rhs}) ? NullSafeFunction.wrapIgnoringArgs(function) : function;
    }

    /**
     * Builds a function for a struct operand combined with a scalar operand, applying the operator between each
     * struct member and the scalar (keeping the original operand order).
     *
     * @param lhsIsStruct true if {@code lhs} is the struct side, false if {@code rhs} is
     * @return the composite function, or empty if any member cannot be combined with the scalar
     */
    private Optional<RiskscapeFunction> buildOneSidedStructCompositeFunction(RealizationContext context, BinaryOperation operation, Type inputType, Type lhs, Type rhs, boolean lhsIsStruct) {
        Struct struct = (lhsIsStruct ? lhs : rhs).find(Struct.class).orElseThrow();
        Type scalarType = lhsIsStruct ? rhs : lhs;
        Struct returnType = Struct.EMPTY_STRUCT;
        List<RiskscapeFunction> functions = new ArrayList<>(struct.size());
        for (Struct.StructMember member : struct.getMembers()) {
            Type memberType = member.getType();
            Type argLhs = lhsIsStruct ? memberType : scalarType;
            Type argRhs = lhsIsStruct ? scalarType : memberType;
            Optional<RiskscapeFunction> built = resolve(context, operation, inputType, argLhs, argRhs);
            if (built.isEmpty()) {
                return Optional.empty();
            }
            functions.add(built.get());
            returnType = returnType.add(member.getKey(), built.get().getReturnType());
        }
        returnType = context.normalizeStruct(returnType);
        return Optional.of(new OneSidedStructCompositeFunction(Arrays.asList(lhs, rhs), returnType, lhsIsStruct,
            functions.toArray(new RiskscapeFunction[0])));
    }

    /**
     * Resolves {@code =} / {@code !=}. Two numeric operands are delegated to {@link #resolveComparators} (so that
     * e.g. integer and floating values compare numerically); anything else uses {@link Objects#equals}.
     */
    private Optional<RiskscapeFunction> resolveEqualityOperator(BinaryOperation operation, Type inputType, Type lhs, Type rhs) {
        Type unwrappedLhs = Nullable.unwrap(lhs);
        Type unwrappedRhs = Nullable.unwrap(rhs);
        if (unwrappedLhs.isNumeric() && unwrappedRhs.isNumeric()) {
            return resolveComparators(operation, inputType, lhs, rhs);
        }
        boolean negate = operation.getNormalizedOperator() == TokenTypes.NOT_EQUALS;
        RiskscapeFunction function = BinaryPredicateFunction.untyped(operation.getNormalizedOperator(), unwrappedLhs, unwrappedRhs,
            (l, r) -> Objects.equals(l, r) ^ negate);
        return Optional.of(nullSafeIfNullable(function, lhs, rhs));
    }

    /**
     * Resolves ordering (and numeric equality) comparisons using {@link Comparable#compareTo}.
     *
     * <p>Two different numeric types are compared as doubles. Otherwise the operands' java types must share a
     * {@link Comparable} ancestor, else nothing is resolved.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Optional<RiskscapeFunction> resolveComparators(BinaryOperation operation, Type inputType, Type lhs, Type rhs) {
        Type rawLhs = Nullable.unwrap(lhs);
        Type rawRhs = Nullable.unwrap(rhs);
        TokenTypes tt = operation.getNormalizedOperator();

        BiPredicate<Comparable, Comparable> predicate = (l, r) -> {
            // compareTo only guarantees the SIGN of its result (String.compareTo("c", "a") is 2, not 1), so test
            // the sign rather than comparing against exactly 1 / -1
            int cmp = l.compareTo(r);
            return switch (tt) {
                case GREATER_THAN -> cmp > 0;
                case LESS_THAN -> cmp < 0;
                case GREATER_THAN_EQUAL -> cmp >= 0;
                case LESS_THAN_EQUAL -> cmp <= 0;
                case EQUALS -> cmp == 0;
                case NOT_EQUALS -> cmp != 0;
                default -> throw new RiskscapeException("Unexpected token - " + String.valueOf(tt));
            };
        };

        boolean mixedNumerics = !rawLhs.equals(rawRhs) && rawLhs.isNumeric() && rawRhs.isNumeric();
        if (mixedNumerics) {
            // e.g. integer vs floating: compare both sides as doubles
            BiPredicate<Comparable, Comparable> exact = predicate;
            predicate = (l, r) -> exact.test(Double.valueOf(((Number) l).doubleValue()), Double.valueOf(((Number) r).doubleValue()));
        } else if (!ReflectionUtils.findCommonAncestorOfType((Class) rawLhs.internalType(), (Class) rawRhs.internalType(), Comparable.class).isPresent()) {
            return Optional.empty();
        }

        RiskscapeFunction function = new BinaryPredicateFunction(tt, Comparable.class, rawLhs, Comparable.class, rawRhs, predicate);
        return Optional.of(nullSafeIfNullable(function, lhs, rhs));
    }

    /**
     * Resolves {@code or} / {@code and} between two (possibly nullable) boolean operands.
     */
    private Optional<RiskscapeFunction> resolveLogicOperator(BinaryOperation operation, Type inputType, Type lhs, Type rhs) {
        Type rawLhs = Nullable.unwrap(lhs);
        Type rawRhs = Nullable.unwrap(rhs);
        if (!rawLhs.equals(Types.BOOLEAN) || !rawRhs.equals(Types.BOOLEAN)) {
            return Optional.empty();
        }
        TokenTypes normalizedOperator = operation.getNormalizedOperator();
        boolean nullable = Nullable.any(new Type[] {lhs, rhs});
        return switch (normalizedOperator) {
            // or has its own null handling (a null operand yields the other operand)
            case OR -> Optional.of(orFunction(nullable));
            case AND -> Optional.of(nullSafeIfNullable(
                predicateFor(normalizedOperator, Boolean.class, (l, r) -> l && r), lhs, rhs));
            default -> Optional.empty();
        };
    }

    /**
     * Resolves an arithmetic operator. An exact match on the unwrapped operand types is preferred; failing that,
     * two non-null numeric operands use the floating point implementation with their arguments coerced to doubles.
     */
    private Optional<RiskscapeFunction> resolveMathsOperator(BinaryOperation functionCall, Type inputType, Type lhs, Type rhs) {
        Type rawLhs = lhs.getUnwrappedType();
        Type rawRhs = rhs.getUnwrappedType();
        Optional<RiskscapeFunction> found = rslvRawMathsOperator(functionCall, inputType, rawLhs, rawRhs);
        boolean bothNonNullNumeric = rawLhs.isNumeric() && !rawLhs.isNullable() && rawRhs.isNumeric() && !rawRhs.isNullable();
        if (found.isEmpty() && bothNonNullNumeric) {
            found = rslvRawMathsOperator(functionCall, inputType, Types.FLOATING, Types.FLOATING)
                .map(DefaultOperators::coerceIntArgsForOperator);
        }
        return found;
    }

    /**
     * Pairs each member of {@code rawLhs} with the {@code rawRhs} member of the same key and resolves the operator
     * for each pair.
     *
     * @return the pairs in {@code rawLhs} member order, or empty if any key is missing from {@code rawRhs} or any
     *         pair cannot be resolved
     */
    private Optional<List<StructPair>> buildStructCompositeFunction(RealizationContext context, BinaryOperation functionCall, Type inputType, Struct rawLhs, Struct rawRhs) {
        // match up members by key first, giving up before any resolution if a key is missing on the rhs
        List<Struct.StructMember> lefts = new ArrayList<>();
        List<Struct.StructMember> rights = new ArrayList<>();
        for (Struct.StructMember lhsMember : rawLhs.getMembers()) {
            Optional<Struct.StructMember> rhsMember = rawRhs.getMembers().stream()
                .filter(rm -> rm.getKey().equals(lhsMember.getKey()))
                .findFirst();
            if (rhsMember.isEmpty()) {
                return Optional.empty();
            }
            lefts.add(lhsMember);
            rights.add(rhsMember.get());
        }

        List<StructPair> pairs = new ArrayList<>(lefts.size());
        for (int i = 0; i < lefts.size(); i++) {
            Struct.StructMember left = lefts.get(i);
            Struct.StructMember right = rights.get(i);
            Optional<RiskscapeFunction> function = resolve(context, functionCall, inputType, left.getType(), right.getType());
            if (function.isEmpty()) {
                return Optional.empty();
            }
            pairs.add(new StructPair(left, right, function.get()));
        }
        return Optional.of(pairs);
    }

    /**
     * Finds the {@link #MATHS_FUNCTIONS} entry for the operation whose argument types exactly equal
     * {@code (rawLhs, rawRhs)}.
     */
    private Optional<RiskscapeFunction> rslvRawMathsOperator(BinaryOperation functionCall, Type inputType, Type rawLhs, Type rawRhs) {
        List<Type> args = Arrays.asList(rawLhs, rawRhs);
        // use the normalized operator, consistent with the dispatch in resolve()
        return MATHS_FUNCTIONS.getOrDefault(functionCall.getNormalizedOperator(), Collections.emptyList())
            .stream()
            .filter(function -> function.getArgumentTypes().equals(args))
            .findFirst();
    }

    /**
     * Adapts a {@code (floating, floating)} operator so it accepts any {@link Number} arguments, converting them
     * with {@link Number#doubleValue()} before calling {@code operator}.
     */
    private static RiskscapeFunction coerceIntArgsForOperator(final RiskscapeFunction operator) {
        return new RiskscapeFunction() {

            @Override
            public Type getReturnType() {
                return Types.FLOATING;
            }

            @Override
            public List<Type> getArgumentTypes() {
                return Arrays.asList(Types.FLOATING, Types.FLOATING);
            }

            @Override
            public Object call(List<Object> args) {
                Number lhs = (Number) args.get(0);
                Number rhs = (Number) args.get(1);
                return operator.call(Arrays.asList(lhs.doubleValue(), rhs.doubleValue()));
            }
        };
    }

    /**
     * Applies a scalar to every member of a struct operand, producing a tuple of the per-member results.
     */
    private static final class OneSidedStructCompositeFunction implements RiskscapeFunction {
        private final List<Type> argumentTypes;
        private final Struct returnType;
        /** true when the struct is argument 0 and the scalar argument 1; false for the reverse. */
        private final boolean lhsIsStruct;
        /** One function per struct member, in member order. */
        private final RiskscapeFunction[] memberFunctions;

        OneSidedStructCompositeFunction(List<Type> argumentTypes, Struct returnType, boolean lhsIsStruct, RiskscapeFunction[] memberFunctions) {
            this.argumentTypes = argumentTypes;
            this.returnType = returnType;
            this.lhsIsStruct = lhsIsStruct;
            this.memberFunctions = memberFunctions;
        }

        @Override
        public Object call(List<Object> args) {
            int structIndex = lhsIsStruct ? 0 : 1;
            int scalarIndex = 1 - structIndex;
            Tuple tuple = (Tuple) args.get(structIndex);
            Tuple toReturn = new Tuple(returnType);
            // one argument array reused for every member call; only the struct-side slot changes
            Object[] memberArgs = new Object[2];
            List<Object> memberArgsList = Arrays.asList(memberArgs);
            memberArgs[scalarIndex] = args.get(scalarIndex);
            for (int i = 0; i < toReturn.size(); i++) {
                memberArgs[structIndex] = tuple.fetch(i);
                toReturn.set(i, memberFunctions[i].call(memberArgsList));
            }
            return toReturn;
        }

        @Override
        public List<Type> getArgumentTypes() {
            return argumentTypes;
        }

        @Override
        public Struct getReturnType() {
            return returnType;
        }
    }

    /** A lhs member, its same-keyed rhs member, and the function combining their values. */
    private record StructPair(Struct.StructMember lhs, Struct.StructMember rhs, RiskscapeFunction applyTo) {
    }

    /**
     * Combines two struct operands member by member (members paired by key), producing a tuple keyed like the lhs.
     */
    private static final class BinaryStructCompositeFunction implements RiskscapeFunction {
        private final List<Type> argumentTypes;
        private final Struct returnType;
        private final List<StructPair> pairs;

        BinaryStructCompositeFunction(RealizationContext context, Type lhs, Type rhs, List<StructPair> pairs) {
            this.pairs = pairs;
            this.argumentTypes = Arrays.asList(lhs, rhs);
            Struct.StructBuilder builder = Struct.builder();
            for (StructPair pair : pairs) {
                builder.add(pair.lhs().getKey(), pair.applyTo().getReturnType());
            }
            this.returnType = context.normalizeStruct(builder.build());
        }

        @Override
        public Object call(List<Object> args) {
            Tuple lhs = (Tuple) args.get(0);
            Tuple rhs = (Tuple) args.get(1);
            Tuple result = new Tuple(returnType);
            // reused for every member call
            List<Object> argsList = Arrays.asList(new Object[2]);
            int index = 0;
            for (StructPair pair : pairs) {
                argsList.set(0, lhs.fetch(pair.lhs()));
                argsList.set(1, rhs.fetch(pair.rhs()));
                result.set(index++, pair.applyTo().call(argsList));
            }
            return result;
        }

        @Override
        public List<Type> getArgumentTypes() {
            return argumentTypes;
        }

        @Override
        public Struct getReturnType() {
            return returnType;
        }
    }
}

