/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.logging.slf4j;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.Generated;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.StringUtils;
import org.openrewrite.internal.lang.NonNull;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JRightPadded;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;
import org.openrewrite.marker.Markers;

public class WrapExpensiveLogStatementsInConditionals
extends Recipe {
    // Matchers for the SLF4J logging calls (any argument list) that this recipe considers wrapping.
    private static final MethodMatcher infoMatcher = new MethodMatcher("org.slf4j.Logger info(..)");
    private static final MethodMatcher debugMatcher = new MethodMatcher("org.slf4j.Logger debug(..)");
    private static final MethodMatcher traceMatcher = new MethodMatcher("org.slf4j.Logger trace(..)");
    // Matchers for the no-argument level checks that are recognised as guards for the calls above.
    private static final MethodMatcher isInfoEnabledMatcher = new MethodMatcher("org.slf4j.Logger isInfoEnabled()");
    private static final MethodMatcher isDebugEnabledMatcher = new MethodMatcher("org.slf4j.Logger isDebugEnabled()");
    private static final MethodMatcher isTraceEnabledMatcher = new MethodMatcher("org.slf4j.Logger isTraceEnabled()");

    /**
     * Supplies the short, human-readable title of this recipe.
     *
     * @return the recipe's display name
     */
    public String getDisplayName() {
        final String displayName = "Wrap expensive log statements in conditionals";
        return displayName;
    }

    /**
     * Supplies the long-form explanation of what this recipe changes and why.
     *
     * @return the recipe's description (contains Markdown-style backticks)
     */
    public String getDescription() {
        // The pieces below are compile-time constants; together they form one unchanged sentence run.
        final String description =
                "When trace, debug and info log statements use methods for constructing log messages, " +
                "those methods are called regardless of whether the log level is enabled. " +
                "This recipe encapsulates those log statements in an `if` statement that checks the log level " +
                "before calling the log method. " +
                "It then bundles surrounding log statements with the same log level into the `if` statement " +
                "to improve readability of the resulting code.";
        return description;
    }

    /**
     * Builds the visitor that performs the transformation.
     * <p>
     * The {@link AddIfEnabledVisitor} is wrapped in a precondition that is satisfied when any of
     * the {@code info}, {@code debug} or {@code trace} matchers is used by the source.
     *
     * @return the precondition-guarded visitor
     */
    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check(
                Preconditions.or(
                        new UsesMethod<>(infoMatcher),
                        new UsesMethod<>(debugMatcher),
                        new UsesMethod<>(traceMatcher)),
                new AddIfEnabledVisitor());
    }

    /**
     * Wraps SLF4J {@code info}/{@code debug}/{@code trace} calls that have at least one
     * "expensive" argument in an {@code if (logger.isXxxEnabled())} statement.
     * <p>
     * The ids of the blocks that directly contain a wrapped call are collected in
     * {@link #visitedBlocks}; after the compilation unit has been visited, a
     * {@link MergeLogStatementsInCheck} pass is scheduled for those blocks.
     */
    private static class AddIfEnabledVisitor extends JavaVisitor<ExecutionContext> {
        /** Ids of the blocks in which at least one log call was wrapped by this visitor. */
        final Set<UUID> visitedBlocks = new HashSet<>();

        private AddIfEnabledVisitor() {
        }

        /**
         * Replaces a qualifying log call by an {@code if} statement guarding that call.
         * <p>
         * A call qualifies when it has a select (the logger expression), matches one of the
         * info/debug/trace matchers, is not already considered guarded, has an expensive
         * argument and its parent tree element is a {@link J.Block}.
         *
         * @param method the method invocation being visited
         * @param ctx    the execution context
         * @return the generated {@code if} statement, or the (visited) invocation when it does not qualify
         */
        public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
            J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);

            boolean candidate = m.getSelect() != null
                    && (infoMatcher.matches(m) || debugMatcher.matches(m) || traceMatcher.matches(m))
                    && !isInIfStatementWithLogLevelCheck(getCursor(), m)
                    && isAnyArgumentExpensive(m);
            if (!candidate) {
                return m;
            }

            // Only statements that sit directly in a block are wrapped.
            J container = getCursor().getParentTreeCursor().getValue();
            if (!(container instanceof J.Block)) {
                return m;
            }

            // The level name is derived from the call itself: info -> isInfoEnabled, etc.
            J.If template = JavaTemplate
                    .builder("if(#{logger:any(org.slf4j.Logger)}.is#{}Enabled()) {}")
                    .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "slf4j-api-2.+"))
                    .build()
                    .apply(getCursor(), m.getCoordinates().replace(), m.getSelect(), StringUtils.capitalize(m.getSimpleName()));

            // The call becomes the then-part: it keeps its comments, and its whitespace is
            // reduced to a single leading newline followed by the original non-newline whitespace.
            Space callPrefix = m.getPrefix();
            String whitespaceWithoutNewlines = callPrefix.getWhitespace().replace("\n", "");
            Statement guardedCall = m.withPrefix(callPrefix.withWhitespace("\n" + whitespaceWithoutNewlines));

            // The `if` takes over the call's original whitespace, but without its comments.
            J.If if_ = template
                    .withThenPart(guardedCall)
                    .withPrefix(callPrefix.withComments(Collections.emptyList()));

            visitedBlocks.add(container.getId());
            return if_;
        }

        /**
         * Visits the compilation unit and, when something changed and at least one block was
         * recorded, schedules the merge pass over the recorded blocks.
         *
         * @param cu  the compilation unit
         * @param ctx the execution context
         * @return the visited compilation unit
         */
        public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
            J visited = super.visitCompilationUnit(cu, ctx);
            if (visited != cu && !visitedBlocks.isEmpty()) {
                doAfterVisit(new MergeLogStatementsInCheck(visitedBlocks));
            }
            return visited;
        }

        /**
         * Tells whether the nearest enclosing {@code if} (if any) is treated as a level check
         * for the given log call, i.e. every side effect of its condition is a call to the
         * {@code isXxxEnabled()} method that corresponds to the call's level.
         *
         * @param cursor cursor positioned at the log call
         * @param m      the log call
         * @return {@code false} when there is no enclosing {@code if}; otherwise the result of the check above
         */
        private boolean isInIfStatementWithLogLevelCheck(Cursor cursor, J.MethodInvocation m) {
            J.If enclosingIf = cursor.firstEnclosing(J.If.class);
            if (enclosingIf == null) {
                return false;
            }
            List<J> sideEffects = enclosingIf.getIfCondition().getSideEffects();
            // NOTE(review): allMatch is true for an empty list, so a condition without any side
            // effects is also treated as a level check here. Confirm this is intended before
            // tightening it.
            return (infoMatcher.matches(m) && allAreCallsTo(sideEffects, isInfoEnabledMatcher))
                    || (debugMatcher.matches(m) && allAreCallsTo(sideEffects, isDebugEnabledMatcher))
                    || (traceMatcher.matches(m) && allAreCallsTo(sideEffects, isTraceEnabledMatcher));
        }

        /** Returns whether every element is a method invocation matched by {@code levelCheck}. */
        private static boolean allAreCallsTo(List<J> sideEffects, MethodMatcher levelCheck) {
            return sideEffects.stream()
                    .allMatch(e -> e instanceof J.MethodInvocation && levelCheck.matches((J.MethodInvocation) e));
        }

        /**
         * Returns whether at least one argument of the call is not considered cheap.
         *
         * @param m the log call
         * @return {@code true} if any argument fails {@link #isCheapArgument(Expression)}
         */
        private boolean isAnyArgumentExpensive(J.MethodInvocation m) {
            return m.getArguments().stream().anyMatch(arg -> !isCheapArgument(arg));
        }

        /**
         * An argument is cheap when it is a simple getter call, a literal, an identifier,
         * a field access, or a binary expression accepted by {@link #isOnlyLiterals(J.Binary)}.
         */
        private static boolean isCheapArgument(Expression arg) {
            if (arg instanceof J.MethodInvocation) {
                return isSimpleGetter((J.MethodInvocation) arg);
            }
            if (arg instanceof J.Binary) {
                return isOnlyLiterals((J.Binary) arg);
            }
            return arg instanceof J.Literal || arg instanceof J.Identifier || arg instanceof J.FieldAccess;
        }

        /**
         * Recognises calls named {@code getX...}/{@code isX...} with known method type, no
         * parameters, no select or an identifier select, and without the static flag.
         */
        private static boolean isSimpleGetter(J.MethodInvocation mi) {
            String name = mi.getSimpleName();
            boolean getterLikeName = (name.startsWith("get") && name.length() > 3)
                    || (name.startsWith("is") && name.length() > 2);
            if (!getterLikeName) {
                return false;
            }
            JavaType.Method methodType = mi.getMethodType();
            if (methodType == null || !methodType.getParameterNames().isEmpty()) {
                return false;
            }
            boolean simpleSelect = mi.getSelect() == null || mi.getSelect() instanceof J.Identifier;
            return simpleSelect && !methodType.hasFlags(Flag.Static);
        }

        /** Returns whether both operands of the binary satisfy {@link #isLiteralOrBinary(J)}. */
        private static boolean isOnlyLiterals(J.Binary binary) {
            return isLiteralOrBinary(binary.getLeft()) && isLiteralOrBinary(binary.getRight());
        }

        /**
         * Accepts literals, simple boolean getters, boolean identifiers, and (recursively)
         * binaries whose operands are themselves accepted.
         */
        private static boolean isLiteralOrBinary(J expression) {
            if (expression instanceof J.Literal) {
                return true;
            }
            if (isSimpleBooleanGetter(expression) || isBooleanIdentifier(expression)) {
                return true;
            }
            return expression instanceof J.Binary && isOnlyLiterals((J.Binary) expression);
        }

        /** Returns whether the expression is a simple getter call whose return type is boolean. */
        private static boolean isSimpleBooleanGetter(J expression) {
            if (!(expression instanceof J.MethodInvocation)) {
                return false;
            }
            J.MethodInvocation mi = (J.MethodInvocation) expression;
            return isSimpleGetter(mi)
                    && mi.getMethodType() != null
                    && isTypeBoolean(mi.getMethodType().getReturnType());
        }

        /** Returns whether the expression is an identifier whose type is boolean. */
        private static boolean isBooleanIdentifier(J expression) {
            return expression instanceof J.Identifier && isTypeBoolean(((J.Identifier) expression).getType());
        }

        /** Returns whether the type is the primitive boolean or assignable to {@code java.lang.Boolean}. */
        private static boolean isTypeBoolean(@Nullable JavaType type) {
            return type == JavaType.Primitive.Boolean || TypeUtils.isAssignableTo("java.lang.Boolean", type);
        }
    }

    private static class StatementAccumulator {
        private final Function<J, J> formatter;
        private final List<Statement> statements = new ArrayList<Statement>();
        private final List<Statement> logStatementsCache = new ArrayList<Statement>();
        private AccumulatorKind accumulatorKind = AccumulatorKind.NONE;
        private // Could not load outer class - annotation placement on inner may be incorrect
        @Nullable J.If ifCache = null;

        public StatementAccumulator(Function<J, J> formatter) {
            this.formatter = formatter;
        }

        public void push(Statement statement) {
            AccumulatorKind newKind = this.getKind(statement);
            if (newKind != this.accumulatorKind && this.accumulatorKind != AccumulatorKind.NONE) {
                this.handleLogStatements();
            }
            this.accumulatorKind = newKind;
            if (statement instanceof J.If) {
                J.If if_ = (J.If)statement;
                if (if_.getThenPart() instanceof J.MethodInvocation && this.isInIfStatementWithOnlyLogLevelCheck(if_, (J.MethodInvocation)if_.getThenPart())) {
                    if (newKind != AccumulatorKind.NONE) {
                        if (this.ifCache == null) {
                            this.ifCache = if_;
                            this.logStatementsCache.add(if_.getThenPart());
                        } else {
                            this.logStatementsCache.add((Statement)if_.getThenPart().withPrefix(if_.getThenPart().getPrefix().withWhitespace(if_.getPrefix().getWhitespace())));
                        }
                    } else {
                        this.statements.add(if_.getThenPart());
                    }
                    return;
                }
                if (if_.getThenPart() instanceof J.Block && !((J.Block)if_.getThenPart()).getStatements().isEmpty() && ((J.Block)if_.getThenPart()).getStatements().stream().allMatch(s -> s instanceof J.MethodInvocation && this.isInIfStatementWithOnlyLogLevelCheck(if_, (J.MethodInvocation)s))) {
                    if (newKind != AccumulatorKind.NONE) {
                        this.ifCache = if_;
                        this.logStatementsCache.addAll(((J.Block)if_.getThenPart()).getStatements());
                    } else {
                        this.statements.addAll(((J.Block)if_.getThenPart()).getStatements());
                    }
                    return;
                }
            } else if (statement instanceof J.MethodInvocation && newKind != AccumulatorKind.NONE) {
                this.logStatementsCache.add(statement);
                return;
            }
            this.statements.add(statement);
        }

        public List<Statement> pull() {
            if (!this.logStatementsCache.isEmpty()) {
                this.handleLogStatements();
            }
            return this.statements;
        }

        private AccumulatorKind getKind(Statement statement) {
            if (statement instanceof J.If) {
                J.If if_ = (J.If)statement;
                if (if_.getThenPart() instanceof J.MethodInvocation && this.isInIfStatementWithOnlyLogLevelCheck(if_, (J.MethodInvocation)if_.getThenPart())) {
                    J.MethodInvocation mi = (J.MethodInvocation)if_.getThenPart();
                    return AccumulatorKind.fromMethodInvocation(mi);
                }
                if (if_.getThenPart() instanceof J.Block && !((J.Block)if_.getThenPart()).getStatements().isEmpty() && ((J.Block)if_.getThenPart()).getStatements().stream().allMatch(s -> s instanceof J.MethodInvocation && this.isInIfStatementWithOnlyLogLevelCheck(if_, (J.MethodInvocation)s))) {
                    return AccumulatorKind.fromMethodInvocation((J.MethodInvocation)((J.Block)if_.getThenPart()).getStatements().get(0));
                }
            } else if (statement instanceof J.MethodInvocation) {
                J.MethodInvocation mi = (J.MethodInvocation)statement;
                return AccumulatorKind.fromMethodInvocation(mi);
            }
            return AccumulatorKind.NONE;
        }

        private void handleLogStatements() {
            if (this.ifCache == null) {
                this.statements.addAll(this.logStatementsCache);
            } else {
                J.If anIf = this.ifCache.withThenPart((Statement)new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build((Object)false), this.logStatementsCache.stream().map(JRightPadded::build).collect(Collectors.toList()), Space.EMPTY));
                this.statements.add((Statement)this.formatter.apply((J)anIf));
            }
            this.logStatementsCache.clear();
            this.ifCache = null;
        }

        private boolean isInIfStatementWithOnlyLogLevelCheck(J.If if_, J.MethodInvocation m) {
            J.ControlParentheses ifCondition = if_.getIfCondition();
            return ifCondition.getTree() instanceof J.MethodInvocation && (infoMatcher.matches((MethodCall)m) && isInfoEnabledMatcher.matches((Expression)ifCondition.getTree()) || debugMatcher.matches((MethodCall)m) && isDebugEnabledMatcher.matches((Expression)ifCondition.getTree()) || traceMatcher.matches((MethodCall)m) && isTraceEnabledMatcher.matches((Expression)ifCondition.getTree()));
        }

        private static enum AccumulatorKind {
            NONE,
            INFO,
            DEBUG,
            TRACE;


            public static AccumulatorKind fromMethodInvocation(J.MethodInvocation mi) {
                if (infoMatcher.matches((MethodCall)mi)) {
                    return INFO;
                }
                if (debugMatcher.matches((MethodCall)mi)) {
                    return DEBUG;
                }
                if (traceMatcher.matches((MethodCall)mi)) {
                    return TRACE;
                }
                return NONE;
            }
        }
    }

    /**
     * Follow-up visitor that, for each block whose id is in {@link #blockIds}, runs the block's
     * statements through a {@link StatementAccumulator} and replaces them with the result.
     * <p>
     * The accessor, {@code equals}, {@code hashCode} and {@code toString} members carry
     * {@code @Generated} and are value-based on {@link #blockIds}.
     */
    private static final class MergeLogStatementsInCheck extends JavaIsoVisitor<ExecutionContext> {
        /** Ids of the blocks to process; the set passed to the constructor is stored by reference. */
        private final Set<UUID> blockIds;

        /**
         * Rewrites the statements of a selected block; other blocks are returned as visited.
         *
         * @param block the block being visited
         * @param ctx   the execution context
         * @return the block with merged log statements, or the visited block when it is not selected
         */
        public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
            J.Block b = super.visitBlock(block, ctx);
            if (!blockIds.contains(b.getId())) {
                return b;
            }
            // Each merged `if` is auto-formatted relative to the current cursor.
            StatementAccumulator acc = new StatementAccumulator(j -> autoFormat(j, ctx, getCursor()));
            for (Statement statement : b.getStatements()) {
                acc.push(statement);
            }
            return b.withStatements(acc.pull());
        }

        /**
         * @param blockIds ids of the blocks whose statements should be merged
         */
        @ConstructorProperties(value={"blockIds"})
        @Generated
        public MergeLogStatementsInCheck(Set<UUID> blockIds) {
            this.blockIds = blockIds;
        }

        /** @return the ids of the blocks this visitor processes */
        @Generated
        public Set<UUID> getBlockIds() {
            return this.blockIds;
        }

        @NonNull
        @Generated
        public String toString() {
            return "WrapExpensiveLogStatementsInConditionals.MergeLogStatementsInCheck(blockIds=" + this.getBlockIds() + ")";
        }

        /** Two instances are equal when their {@code blockIds} are equal (or both {@code null}). */
        @Generated
        public boolean equals(@org.openrewrite.internal.lang.Nullable Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MergeLogStatementsInCheck)) {
                return false;
            }
            MergeLogStatementsInCheck other = (MergeLogStatementsInCheck) o;
            if (!other.canEqual(this)) {
                return false;
            }
            Set<UUID> mine = this.getBlockIds();
            Set<UUID> theirs = other.getBlockIds();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Generated
        protected boolean canEqual(@org.openrewrite.internal.lang.Nullable Object other) {
            return other instanceof MergeLogStatementsInCheck;
        }

        /** Hash derived from {@code blockIds} only; 43 stands in for a {@code null} set. */
        @Generated
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            Set<UUID> ids = this.getBlockIds();
            result = result * PRIME + (ids == null ? 43 : ids.hashCode());
            return result;
        }
    }
}

