/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mappings;

/**
 * Planner rule that pushes an {@link Aggregate} below an inner {@link Join}.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>the aggregate has no aggregate calls (it is a pure grouping);</li>
 *   <li>the join is {@link JoinRelType#INNER};</li>
 *   <li>every column referenced by the join condition is a grouping column;</li>
 *   <li>{@link RelOptUtil#splitJoinCondition} leaves no non-equi conjuncts.</li>
 * </ul>
 *
 * <p>The aggregate is replaced by a join of two aggregates. One groups the left
 * input by the grouping columns that come from the left. The other groups the
 * right input by the grouping columns that come from the right. The original
 * join condition is remapped onto the narrower inputs.
 */
public class AggregateJoinTransposeRule
extends RelOptRule {
    public static final AggregateJoinTransposeRule INSTANCE = new AggregateJoinTransposeRule(LogicalAggregate.class, RelFactories.DEFAULT_AGGREGATE_FACTORY, LogicalJoin.class, RelFactories.DEFAULT_JOIN_FACTORY);
    private final RelFactories.AggregateFactory aggregateFactory;
    private final RelFactories.JoinFactory joinFactory;

    /**
     * Creates a rule matching an aggregate of {@code aggregateClass} that satisfies
     * {@link Aggregate#IS_SIMPLE} directly on top of a join of {@code joinClass}.
     *
     * @param aggregateClass class of the aggregate to match
     * @param aggregateFactory factory used to build the two new aggregates
     * @param joinClass class of the join to match
     * @param joinFactory factory used to build the replacement join
     */
    public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
        super(AggregateJoinTransposeRule.operand(aggregateClass, null, Aggregate.IS_SIMPLE, AggregateJoinTransposeRule.operand(joinClass, AggregateJoinTransposeRule.any()), new RelOptRuleOperand[0]));
        this.aggregateFactory = aggregateFactory;
        this.joinFactory = joinFactory;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        final Join join = (Join) call.rel(1);

        // Only a pure grouping (no aggregate functions) is pushed below the join.
        if (!aggregate.getAggCallList().isEmpty()) {
            return;
        }
        // Only inner joins are handled.
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }

        // The join condition must only read grouping columns.
        // Otherwise grouping below the join would discard values the condition needs.
        final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        if (!aggregateColumns.contains(joinColumns)) {
            return;
        }

        // Bail out if the condition has anything other than equi-join keys.
        final List<Integer> leftKeys = Lists.newArrayList();
        final List<Integer> rightKeys = Lists.newArrayList();
        final RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(),
                join.getCondition(), leftKeys, rightKeys);
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }

        // Partition ALL grouping columns by join side, not just the join keys.
        // Right-side columns are rebased onto the right input. This makes the
        // new join's output (left group columns, then right group columns, each
        // ascending) line up exactly with the aggregate's output row type.
        final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        final List<Integer> leftGroup = Lists.newArrayList();
        final List<Integer> rightGroup = Lists.newArrayList();
        for (int column : aggregateColumns) {
            if (column < leftFieldCount) {
                leftGroup.add(column);
            } else {
                rightGroup.add(column - leftFieldCount);
            }
        }
        // An aggregate with an empty group set produces one row even when its
        // input is empty. It cannot replace one side while the other side is
        // still grouped, because an empty input on the collapsed side would no
        // longer empty the result.
        if (leftGroup.isEmpty() != rightGroup.isEmpty()) {
            return;
        }

        final RelNode newLeftInput = this.aggregateFactory.createAggregate(join.getLeft(), false,
                ImmutableBitSet.of(leftGroup), null, aggregate.getAggCallList());
        final RelNode newRightInput = this.aggregateFactory.createAggregate(join.getRight(), false,
                ImmutableBitSet.of(rightGroup), null, aggregate.getAggCallList());

        // Join field i becomes field (rank of i within the group set) of the new join.
        final Mappings.TargetMapping mapping = Mappings.target(new Function<Integer, Integer>() {

            public Integer apply(Integer a0) {
                return aggregateColumns.indexOf(a0);
            }
        }, join.getRowType().getFieldCount(), aggregateColumns.cardinality());
        final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());

        final RelNode newJoin = this.joinFactory.createJoin(newLeftInput, newRightInput, newCondition,
                join.getJoinType(), join.getVariablesStopped(), join.isSemiJoinDone());
        call.transformTo(newJoin);
    }
}

