/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SemiJoinType;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;

/**
 * Planner rule that rewrites a {@link LogicalJoin} as a {@link LogicalCorrelate}.
 *
 * <p>Only {@code INNER} and {@code LEFT} joins are handled. In the join
 * condition, references to fields of the left input are replaced by field
 * accesses on a freshly allocated correlation variable. References to fields
 * of the right input are shifted down by the left input's field count. The
 * rewritten (and flattened) condition is then applied as a filter on top of
 * the right input, which becomes the right side of the correlate.
 */
public class JoinToCorrelateRule extends RelOptRule {
    /** Shared instance that builds filters with {@link RelFactories#DEFAULT_FILTER_FACTORY}. */
    public static final JoinToCorrelateRule INSTANCE =
        new JoinToCorrelateRule(RelFactories.DEFAULT_FILTER_FACTORY);

    /** Factory used to create the filter placed over the right input. */
    protected final RelFactories.FilterFactory filterFactory;

    /**
     * Creates a rule that matches any {@link LogicalJoin}.
     *
     * @param filterFactory factory for the filter built on the right input;
     *     must not be null (verified only when assertions are enabled)
     */
    protected JoinToCorrelateRule(RelFactories.FilterFactory filterFactory) {
        super(operand(LogicalJoin.class, any()));
        this.filterFactory = filterFactory;
        assert filterFactory != null : "Filter factory should not be null";
    }

    /**
     * Decides whether the matched join can be converted.
     *
     * @param call rule call whose first relational expression is the join
     * @return {@code true} for inner and left joins, {@code false} for full
     *     and right joins; any other join type causes the result of
     *     {@link Util#unexpected} to be thrown
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final LogicalJoin join = (LogicalJoin) call.rel(0);
        switch (join.getJoinType()) {
        case FULL:
        case RIGHT:
            return false;
        case INNER:
        case LEFT:
            return true;
        default:
            throw Util.unexpected(join.getJoinType());
        }
    }

    /**
     * Replaces the matched join with a correlate. The correlate's right input
     * is the original right input filtered by the rewritten join condition.
     *
     * @param call rule call whose first relational expression is the join
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        assert matches(call);
        final LogicalJoin join = (LogicalJoin) call.rel(0);
        final RelNode leftInput = join.getLeft();
        final RelOptCluster cluster = join.getCluster();
        final RexBuilder builder = cluster.getRexBuilder();

        // Allocate a fresh correlation variable whose type is the left input's row type.
        final CorrelationId correlId =
            new CorrelationId(cluster.getQuery().createCorrel());
        final RexNode correlVar =
            builder.makeCorrel(leftInput.getRowType(), correlId.getName());

        // Rewrite the condition so it only references the right input (plus
        // the correlation variable), collecting which left columns it needs.
        final ImmutableBitSet.Builder usedLeftFields = ImmutableBitSet.builder();
        final RexNode rewritten = join.getCondition().accept(
            new LeftFieldToCorrelShuttle(builder, correlVar,
                leftInput.getRowType().getFieldCount(), usedLeftFields));
        final RexNode condition = RexUtil.flatten(builder, rewritten);

        final RelNode newRight =
            RelOptUtil.createFilter(join.getRight(), condition, filterFactory);
        call.transformTo(
            LogicalCorrelate.create(leftInput, newRight, correlId,
                usedLeftFields.build(),
                SemiJoinType.of(join.getJoinType())));
    }

    /**
     * Rewrites input references of a join condition. References below
     * {@code leftFieldCount} become field accesses on the correlation variable,
     * and their indexes are recorded in {@code usedLeftFields}. All other
     * references are re-based onto the right input by subtracting
     * {@code leftFieldCount}.
     */
    private static final class LeftFieldToCorrelShuttle extends RexShuttle {
        private final RexBuilder rexBuilder;
        private final RexNode correlVar;
        private final int leftFieldCount;
        private final ImmutableBitSet.Builder usedLeftFields;

        LeftFieldToCorrelShuttle(RexBuilder rexBuilder, RexNode correlVar,
                int leftFieldCount, ImmutableBitSet.Builder usedLeftFields) {
            this.rexBuilder = rexBuilder;
            this.correlVar = correlVar;
            this.leftFieldCount = leftFieldCount;
            this.usedLeftFields = usedLeftFields;
        }

        @Override
        public RexNode visitInputRef(RexInputRef input) {
            final int index = input.getIndex();
            if (index < leftFieldCount) {
                usedLeftFields.set(index);
                return rexBuilder.makeFieldAccess(correlVar, index);
            }
            return rexBuilder.makeInputRef(input.getType(), index - leftFieldCount);
        }
    }
}

