package org.aspectj.weaver.patterns;

import org.aspectj.weaver.Shadow;
import org.aspectj.weaver.patterns.Pointcut.MatchesNothingPointcut;

import java.util.Iterator;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * Rewrites a pointcut expression into a normalized form.
 * <p>
 * {@link #rewrite(Pointcut, boolean)} applies these stages in order:
 * <ol>
 * <li>push negations down towards the leaves using De Morgan's laws, dropping double negations;</li>
 * <li>distribute ANDs over ORs so that ORs end up above ANDs (disjunctive normal form);</li>
 * <li>flatten each AND and order its operands with {@link PointcutEvaluationExpenseComparator},
 * replacing the whole AND with a matches-nothing pointcut when it provably cannot match;</li>
 * <li>remove matches-nothing operands from ANDs and ORs;</li>
 * <li>flatten the top-level ORs and order them with {@link PointcutEvaluationExpenseComparator}.</li>
 * </ol>
 * Stages 1 and 2 are skipped when the input already passes the DNF check and a rewrite is not forced.
 * Because stages 3 and 5 collect operands into a {@link TreeSet}, operands that the comparator treats
 * as equal are collapsed into one.
 */
public class PointcutRewriter {

    /** Debug switch: when {@code true}, every stage prints its intermediate result to stdout. */
    private static final boolean WATCH_PROGRESS = false;

    /**
     * Rewrites {@code pc} into normalized form.
     *
     * @param pc the pointcut to rewrite
     * @param forceRewrite if {@code true}, negations are distributed and disjunctions pulled up even when
     *        {@code pc} already passes the DNF check
     * @return the rewritten pointcut; it may share subtrees with, or be, {@code pc}
     */
    public Pointcut rewrite(Pointcut pc, boolean forceRewrite) {
        Pointcut result = pc;
        if (forceRewrite || !isDNF(pc)) {
            if (WATCH_PROGRESS) {
                System.out.println("Initial pointcut is        ==> " + format(pc));
            }
            result = distributeNot(result);
            if (WATCH_PROGRESS) {
                System.out.println("Distributing NOT gives     ==> " + format(result));
            }
            result = pullUpDisjunctions(result);
            if (WATCH_PROGRESS) {
                System.out.println("Pull up disjunctions gives ==> " + format(result));
            }
        } else {
            if (WATCH_PROGRESS) {
                System.out.println("Not distributing NOTs or pulling up disjunctions, already DNF ==> " + format(pc));
            }
        }
        result = simplifyAnds(result);
        if (WATCH_PROGRESS) {
            System.out.println("Simplifying ANDs gives     ==> " + format(result));
        }
        result = removeNothings(result);
        if (WATCH_PROGRESS) {
            System.out.println("Removing nothings gives    ==> " + format(result));
        }
        result = sortOrs(result);
        if (WATCH_PROGRESS) {
            System.out.println("Sorting ORs gives          ==> " + format(result));
        }
        return result;
    }

    /**
     * Rewrites {@code pc}, skipping NOT distribution and disjunction pull-up when it already
     * passes the DNF check. Equivalent to {@code rewrite(pc, false)}.
     */
    public Pointcut rewrite(Pointcut pc) {
        return rewrite(pc, false);
    }

    /** Returns {@code true} if no OR appears beneath an AND anywhere in {@code pc}. */
    private boolean isDNF(Pointcut pc) {
        return isDNFHelper(pc, true);
    }

    /**
     * Recursive DNF check.
     *
     * @param pc the (sub)expression to inspect
     * @param canStillHaveOrs {@code false} once an AND has been entered, because an OR below an AND
     *        violates DNF
     * @return {@code true} if no OR is found beneath an AND
     */
    // NOTE(review): NOT nodes are looked through rather than rejected, so e.g. !(a && b) or
    // a && !(b && c) is reported as DNF even though the negation is not on a leaf. rewrite() then
    // skips distributeNot/pullUpDisjunctions for such input; the result stays equivalent but is not
    // strictly normalized. Confirm whether that is intentional before tightening this check.
    private boolean isDNFHelper(Pointcut pc, boolean canStillHaveOrs) {
        if (isAnd(pc)) {
            AndPointcut ap = (AndPointcut) pc;
            return isDNFHelper(ap.getLeft(), false) && isDNFHelper(ap.getRight(), false);
        } else if (isOr(pc)) {
            if (!canStillHaveOrs) {
                return false;
            }
            OrPointcut op = (OrPointcut) pc;
            return isDNFHelper(op.getLeft(), true) && isDNFHelper(op.getRight(), true);
        } else if (isNot(pc)) {
            return isDNFHelper(((NotPointcut) pc).getNegatedPointcut(), canStillHaveOrs);
        } else {
            return true;
        }
    }

    /** Renders a pointcut for the progress trace; currently just {@code p.toString()}. */
    public static String format(Pointcut p) {
        return p.toString();
    }

    /**
     * Pushes negations down to the leaves:
     * {@code !!a -> a}, {@code !(a && b) -> !a || !b}, {@code !(a || b) -> !a && !b}.
     * Rebuilt NOT nodes keep the start position of the original {@code !}.
     */
    private Pointcut distributeNot(Pointcut pc) {
        if (isNot(pc)) {
            NotPointcut npc = (NotPointcut) pc;
            Pointcut notBody = distributeNot(npc.getNegatedPointcut());
            if (isNot(notBody)) {
                return ((NotPointcut) notBody).getNegatedPointcut();
            } else if (isAnd(notBody)) {
                AndPointcut apc = (AndPointcut) notBody;
                Pointcut newLeft = distributeNot(new NotPointcut(apc.getLeft(), npc.getStart()));
                Pointcut newRight = distributeNot(new NotPointcut(apc.getRight(), npc.getStart()));
                return new OrPointcut(newLeft, newRight);
            } else if (isOr(notBody)) {
                OrPointcut opc = (OrPointcut) notBody;
                Pointcut newLeft = distributeNot(new NotPointcut(opc.getLeft(), npc.getStart()));
                Pointcut newRight = distributeNot(new NotPointcut(opc.getRight(), npc.getStart()));
                return new AndPointcut(newLeft, newRight);
            } else {
                return new NotPointcut(notBody, npc.getStart());
            }
        } else if (isAnd(pc)) {
            AndPointcut apc = (AndPointcut) pc;
            Pointcut left = distributeNot(apc.getLeft());
            Pointcut right = distributeNot(apc.getRight());
            return new AndPointcut(left, right);
        } else if (isOr(pc)) {
            OrPointcut opc = (OrPointcut) pc;
            Pointcut left = distributeNot(opc.getLeft());
            Pointcut right = distributeNot(opc.getRight());
            return new OrPointcut(left, right);
        } else {
            return pc;
        }
    }

    /**
     * Distributes ANDs over ORs until every OR sits above every AND (disjunctive normal form).
     * Expected to run after {@link #distributeNot(Pointcut)}, so NOTs only wrap leaves.
     */
    private Pointcut pullUpDisjunctions(Pointcut pc) {
        if (isNot(pc)) {
            NotPointcut npc = (NotPointcut) pc;
            // Keep the position of the '!' the same way distributeNot does.
            return new NotPointcut(pullUpDisjunctions(npc.getNegatedPointcut()), npc.getStart());
        } else if (isAnd(pc)) {
            AndPointcut apc = (AndPointcut) pc;
            Pointcut left = pullUpDisjunctions(apc.getLeft());
            Pointcut right = pullUpDisjunctions(apc.getRight());
            if (isOr(left) && !isOr(right)) {
                // (A || B) && R  ==>  (A && R) || (B && R)
                Pointcut leftLeft = ((OrPointcut) left).getLeft();
                Pointcut leftRight = ((OrPointcut) left).getRight();
                return pullUpDisjunctions(new OrPointcut(new AndPointcut(leftLeft, right), new AndPointcut(leftRight, right)));
            } else if (isOr(right) && !isOr(left)) {
                // L && (A || B)  ==>  (L && A) || (L && B)
                Pointcut rightLeft = ((OrPointcut) right).getLeft();
                Pointcut rightRight = ((OrPointcut) right).getRight();
                return pullUpDisjunctions(new OrPointcut(new AndPointcut(left, rightLeft), new AndPointcut(left, rightRight)));
            } else if (isOr(right) && isOr(left)) {
                // (A || B) && (C || D)  ==>  ((A && C) || (A && D)) || ((B && C) || (B && D))
                // left and right were normalized just above, so their operands need no further
                // rewriting here; the recursive call below normalizes the rebuilt tree as a whole.
                Pointcut a = ((OrPointcut) left).getLeft();
                Pointcut b = ((OrPointcut) left).getRight();
                Pointcut c = ((OrPointcut) right).getLeft();
                Pointcut d = ((OrPointcut) right).getRight();
                Pointcut newLeft = new OrPointcut(new AndPointcut(a, c), new AndPointcut(a, d));
                Pointcut newRight = new OrPointcut(new AndPointcut(b, c), new AndPointcut(b, d));
                return pullUpDisjunctions(new OrPointcut(newLeft, newRight));
            } else {
                return new AndPointcut(left, right);
            }
        } else if (isOr(pc)) {
            OrPointcut opc = (OrPointcut) pc;
            return new OrPointcut(pullUpDisjunctions(opc.getLeft()), pullUpDisjunctions(opc.getRight()));
        } else {
            return pc;
        }
    }

    /**
     * Returns the negation of {@code p}, unwrapping instead of double-wrapping when {@code p} is
     * already a NOT.
     */
    public Pointcut not(Pointcut p) {
        if (isNot(p)) {
            return ((NotPointcut) p).getNegatedPointcut();
        }
        return new NotPointcut(p);
    }

    /**
     * Combines the given pointcuts into a right-nested conjunction:
     * {@code [a, b, c]} becomes {@code a && (b && c)}.
     *
     * @param ps the operands; must contain at least one element
     * @return {@code ps[0]} for a single operand, otherwise the right-nested AND of all operands
     * @throws IllegalArgumentException if {@code ps} is empty
     */
    public Pointcut createAndsFor(Pointcut[] ps) {
        if (ps.length == 0) {
            throw new IllegalArgumentException("createAndsFor requires at least one pointcut");
        }
        // Fold from the right: this builds the same shape as the recursive definition
        // without copying the array at every level.
        Pointcut result = ps[ps.length - 1];
        for (int i = ps.length - 2; i >= 0; i--) {
            result = new AndPointcut(ps[i], result);
        }
        return result;
    }

    /**
     * Walks the OR/NOT structure, eliminating double negations and simplifying every AND it
     * reaches via {@link #simplifyAnd(AndPointcut)}.
     */
    private Pointcut simplifyAnds(Pointcut pc) {
        if (isNot(pc)) {
            NotPointcut npc = (NotPointcut) pc;
            Pointcut notBody = npc.getNegatedPointcut();
            if (isNot(notBody)) {
                return simplifyAnds(((NotPointcut) notBody).getNegatedPointcut());
            } else {
                // Keep the position of the '!' the same way distributeNot does.
                return new NotPointcut(simplifyAnds(notBody), npc.getStart());
            }
        } else if (isOr(pc)) {
            OrPointcut opc = (OrPointcut) pc;
            return new OrPointcut(simplifyAnds(opc.getLeft()), simplifyAnds(opc.getRight()));
        } else if (isAnd(pc)) {
            return simplifyAnd((AndPointcut) pc);
        } else {
            return pc;
        }
    }

    /**
     * Flattens a tree of ANDs and rebuilds it as a left-nested AND whose operands follow
     * {@link PointcutEvaluationExpenseComparator} order. Operands the comparator treats as equal
     * are collapsed.
     * <p>
     * Returns a pointcut that matches nothing instead when:
     * <ul>
     * <li>both {@code X} and {@code !X} are operands;</li>
     * <li>an operand is an always-false {@code if()};</li>
     * <li>an operand, or the AND as a whole, can match no shadow kinds.</li>
     * </ul>
     */
    private Pointcut simplifyAnd(AndPointcut apc) {
        SortedSet<Pointcut> nodes = new TreeSet<>(new PointcutEvaluationExpenseComparator());
        collectAndNodes(apc, nodes);
        for (Pointcut element : nodes) {
            if (element instanceof NotPointcut) {
                Pointcut body = ((NotPointcut) element).getNegatedPointcut();
                // TreeSet.contains uses the comparator, not equals().
                if (nodes.contains(body)) {
                    return Pointcut.makeMatchesNothing(body.state);
                }
            }
            if (element instanceof IfPointcut) {
                if (((IfPointcut) element).alwaysFalse()) {
                    return Pointcut.makeMatchesNothing(element.state);
                }
            }
            // If one operand can match no shadow kind, neither can the whole conjunction.
            if (element.couldMatchKinds() == Shadow.NO_SHADOW_KINDS_BITS) {
                return element;
            }
        }
        if (apc.couldMatchKinds() == Shadow.NO_SHADOW_KINDS_BITS) {
            return Pointcut.makeMatchesNothing(apc.state);
        }
        Iterator<Pointcut> iter = nodes.iterator();
        Pointcut result = iter.next();
        while (iter.hasNext()) {
            Pointcut right = iter.next();
            result = new AndPointcut(result, right);
        }
        return result;
    }

    /**
     * Flattens the top-level ORs of {@code pc} and rebuilds them as a left-nested OR whose
     * operands follow {@link PointcutEvaluationExpenseComparator} order. Operands the comparator
     * treats as equal are collapsed.
     */
    private Pointcut sortOrs(Pointcut pc) {
        SortedSet<Pointcut> nodes = new TreeSet<>(new PointcutEvaluationExpenseComparator());
        collectOrNodes(pc, nodes);
        // collectOrNodes always adds at least one node, so next() is safe.
        Iterator<Pointcut> iter = nodes.iterator();
        Pointcut result = iter.next();
        while (iter.hasNext()) {
            Pointcut right = iter.next();
            result = new OrPointcut(result, right);
        }
        return result;
    }

    /**
     * Propagates {@link MatchesNothingPointcut} through ANDs and ORs:
     * <ul>
     * <li>an AND with a matches-nothing operand matches nothing;</li>
     * <li>an OR drops matches-nothing operands, and matches nothing only if both operands do.</li>
     * </ul>
     * NOT nodes and leaves are returned unchanged.
     */
    private Pointcut removeNothings(Pointcut pc) {
        if (isAnd(pc)) {
            AndPointcut apc = (AndPointcut) pc;
            Pointcut right = removeNothings(apc.getRight());
            Pointcut left = removeNothings(apc.getLeft());
            if (left instanceof MatchesNothingPointcut || right instanceof MatchesNothingPointcut) {
                return new MatchesNothingPointcut();
            }
            return new AndPointcut(left, right);
        } else if (isOr(pc)) {
            OrPointcut opc = (OrPointcut) pc;
            Pointcut right = removeNothings(opc.getRight());
            Pointcut left = removeNothings(opc.getLeft());
            boolean leftIsNothing = left instanceof MatchesNothingPointcut;
            boolean rightIsNothing = right instanceof MatchesNothingPointcut;
            if (leftIsNothing && rightIsNothing) {
                return new MatchesNothingPointcut();
            } else if (leftIsNothing) {
                return right;
            } else if (rightIsNothing) {
                return left;
            }
            return new OrPointcut(left, right);
        }
        return pc;
    }

    /** Adds every non-AND operand found beneath the AND tree {@code apc} to {@code nodesSoFar}. */
    private void collectAndNodes(AndPointcut apc, Set<Pointcut> nodesSoFar) {
        Pointcut left = apc.getLeft();
        Pointcut right = apc.getRight();
        if (isAnd(left)) {
            collectAndNodes((AndPointcut) left, nodesSoFar);
        } else {
            nodesSoFar.add(left);
        }
        if (isAnd(right)) {
            collectAndNodes((AndPointcut) right, nodesSoFar);
        } else {
            nodesSoFar.add(right);
        }
    }

    /**
     * Adds every non-OR operand reachable through ORs from {@code pc} to {@code nodesSoFar}; a
     * non-OR {@code pc} is added itself.
     */
    private void collectOrNodes(Pointcut pc, Set<Pointcut> nodesSoFar) {
        if (isOr(pc)) {
            OrPointcut opc = (OrPointcut) pc;
            collectOrNodes(opc.getLeft(), nodesSoFar);
            collectOrNodes(opc.getRight(), nodesSoFar);
        } else {
            nodesSoFar.add(pc);
        }
    }

    private boolean isNot(Pointcut pc) {
        return (pc instanceof NotPointcut);
    }

    private boolean isAnd(Pointcut pc) {
        return (pc instanceof AndPointcut);
    }

    private boolean isOr(Pointcut pc) {
        return (pc instanceof OrPointcut);
    }
}
