/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.testing.mockito;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.AnnotationMatcher;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.RemoveAnnotationVisitor;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;
import org.openrewrite.template.SourceTemplate;

public class PowerMockitoMockStaticToMockito
extends Recipe {
    public String getDisplayName() {
        return "Replace `PowerMock.mockStatic()` with `Mockito.mockStatic()`";
    }

    public String getDescription() {
        return "Replaces `PowerMockito.mockStatic()` by `Mockito.mockStatic()`. Removes the `@PrepareForTest` annotation.";
    }

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return new PowerMockitoToMockitoVisitor();
    }

    private static class PowerMockitoToMockitoVisitor
    extends JavaVisitor<ExecutionContext> {
        private static final TestFrameworkInfo TESTNG_FRAMEWORK_INFO = new TestFrameworkInfo("BeforeMethod", "AfterMethod", "org.testng.annotations", "testng-7.7.1");
        private static final TestFrameworkInfo JUNIT_FRAMEWORK_INFO = new TestFrameworkInfo("BeforeEach", "AfterEach", "org.junit.jupiter.api", "junit-jupiter-api-5.9.2");
        private static final String MOCKED_STATIC = "org.mockito.MockedStatic";
        private static final String POWER_MOCK_RUNNER = "org.powermock.modules.junit4.PowerMockRunner";
        private static final MethodMatcher MOCKED_STATIC_MATCHER = new MethodMatcher("org.mockito.Mockito mockStatic(..)");
        private static final MethodMatcher MOCKED_STATIC_CLOSE_MATCHER = new MethodMatcher("org.mockito.ScopedMock close(..)", true);
        private static final MethodMatcher MOCKITO_VERIFY_MATCHER = new MethodMatcher("org.mockito.Mockito verify(..)");
        private static final AnnotationMatcher PREPARE_FOR_TEST_MATCHER = new AnnotationMatcher("@org.powermock.core.classloader.annotations.PrepareForTest");
        private static final AnnotationMatcher RUN_WITH_POWER_MOCK_RUNNER_MATCHER = new AnnotationMatcher("@org.junit.runner.RunWith(org.powermock.modules.junit4.PowerMockRunner.class)");
        private static final MethodMatcher MOCKITO_WHEN_MATCHER = new MethodMatcher("org.mockito.Mockito when(..)");
        private static final MethodMatcher MOCKITO_STATIC_METHOD_MATCHER = new MethodMatcher("org.mockito.Mockito *(..)");
        public static final String MOCKED_TYPES_FIELDS = "mockedTypesFields";
        private TestFrameworkInfo testFrameworkInfo;
        private List<J.Identifier> mockedTypesFields;

        private PowerMockitoToMockitoVisitor() {
        }

        private void initTestFrameworkInfo(boolean useTestNg) {
            this.testFrameworkInfo = useTestNg ? TESTNG_FRAMEWORK_INFO : JUNIT_FRAMEWORK_INFO;
        }

        private List<J.Identifier> getMockedTypesFields() {
            if (this.mockedTypesFields == null) {
                this.mockedTypesFields = new ArrayList<J.Identifier>();
            }
            return this.mockedTypesFields;
        }

        public J visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
            List<Expression> mockedTypes;
            ArrayList<J.Annotation> prepareForTestAnnotations = new ArrayList<J.Annotation>();
            for (J.Annotation annotation : classDecl.getAllAnnotations()) {
                if (PREPARE_FOR_TEST_MATCHER.matches(annotation)) {
                    prepareForTestAnnotations.add(annotation);
                    this.doAfterVisit((TreeVisitor)new RemoveAnnotationVisitor(PREPARE_FOR_TEST_MATCHER));
                    continue;
                }
                if (!RUN_WITH_POWER_MOCK_RUNNER_MATCHER.matches(annotation)) continue;
                this.doAfterVisit((TreeVisitor)new RemoveAnnotationVisitor(RUN_WITH_POWER_MOCK_RUNNER_MATCHER));
                this.maybeRemoveImport(POWER_MOCK_RUNNER);
            }
            boolean useTestNg = PowerMockitoToMockitoVisitor.containsTestNgTestMethods(classDecl.getBody().getStatements().stream().filter(statement -> statement instanceof J.MethodDeclaration).map(J.MethodDeclaration.class::cast).collect(Collectors.toList()));
            this.initTestFrameworkInfo(useTestNg);
            if (!prepareForTestAnnotations.isEmpty() && !(mockedTypes = PowerMockitoToMockitoVisitor.getMockedTypesFromPrepareForTestAnnotation(prepareForTestAnnotations)).isEmpty()) {
                classDecl = this.maybeAddSetUpMethodBody(classDecl, ctx);
                classDecl = this.maybeAddTearDownMethodBody(classDecl, ctx);
                classDecl = this.addFieldDeclarationForMockedTypes(classDecl, ctx, mockedTypes);
            }
            return super.visitClassDeclaration(classDecl, (Object)ctx);
        }

        @NotNull
        private J.ClassDeclaration addFieldDeclarationForMockedTypes(J.ClassDeclaration classDecl, ExecutionContext ctx, List<Expression> mockedTypes) {
            ArrayList<J.Identifier> mockedTypesIdentifiers = new ArrayList<J.Identifier>(mockedTypes.size());
            for (Expression mockedType : mockedTypes) {
                JavaType.FullyQualified fullyQualifiedMockedType;
                JavaType.Parameterized classType = TypeUtils.asParameterized((JavaType)mockedType.getType());
                if (classType == null || (fullyQualifiedMockedType = TypeUtils.asFullyQualified((JavaType)((JavaType)classType.getTypeParameters().get(0)))) == null) continue;
                String classlessTypeName = fullyQualifiedMockedType.getClassName();
                String mockedTypedFieldName = "mocked" + classlessTypeName;
                if (PowerMockitoToMockitoVisitor.isFieldAlreadyDefined(classDecl.getBody(), mockedTypedFieldName)) continue;
                classDecl = classDecl.withBody((J.Block)classDecl.getBody().withTemplate((SourceTemplate)JavaTemplate.builder(() -> this.getCursor().getParentTreeCursor(), (String)"private MockedStatic<#{}> mocked#{};").javaParser(() -> JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"mockito-core-3.12.4"}).build()).staticImports(new String[]{"org.mockito.Mockito.mockStatic"}).imports(new String[]{MOCKED_STATIC}).build(), classDecl.getBody().getCoordinates().firstStatement(), new Object[]{classlessTypeName, classlessTypeName}));
                J.VariableDeclarations mockField = (J.VariableDeclarations)classDecl.getBody().getStatements().get(0);
                mockedTypesIdentifiers.add(((J.VariableDeclarations.NamedVariable)mockField.getVariables().get(0)).getName());
            }
            this.getCursor().putMessage(MOCKED_TYPES_FIELDS, mockedTypesIdentifiers);
            this.maybeAutoFormat((J)classDecl, (J)classDecl.withPrefix(classDecl.getPrefix().withWhitespace("")), (J)classDecl.getName(), ctx, this.getCursor());
            this.maybeAddImport(MOCKED_STATIC);
            this.maybeAddImport("org.mockito.Mockito", "mockStatic");
            return classDecl;
        }

        private static boolean isFieldAlreadyDefined(J.Block classBody, String fieldName) {
            for (Statement statement : classBody.getStatements()) {
                if (!(statement instanceof J.VariableDeclarations)) continue;
                for (J.VariableDeclarations.NamedVariable namedVariable : ((J.VariableDeclarations)statement).getVariables()) {
                    if (!namedVariable.getSimpleName().equals(fieldName)) continue;
                    return true;
                }
            }
            return false;
        }

        @NotNull
        private J.ClassDeclaration maybeAddSetUpMethodBody(J.ClassDeclaration classDecl, ExecutionContext ctx) {
            return this.maybeAddMethodWithAnnotation(classDecl, ctx, "setUp", this.testFrameworkInfo.setUpMethodAnnotationSignature, this.testFrameworkInfo.setUpMethodAnnotation, this.testFrameworkInfo.additionalClasspathResource, this.testFrameworkInfo.setUpImportToAdd);
        }

        @NotNull
        private J.ClassDeclaration maybeAddMethodWithAnnotation(J.ClassDeclaration classDecl, ExecutionContext ctx, String methodName, String methodAnnotationSignature, String methodAnnotationToAdd, String additionalClasspathResource, String importToAdd) {
            if (PowerMockitoToMockitoVisitor.hasMethodWithAnnotation(classDecl, new AnnotationMatcher(methodAnnotationSignature))) {
                return classDecl;
            }
            J.MethodDeclaration firstTestMethod = PowerMockitoToMockitoVisitor.getFirstTestMethod(classDecl.getBody().getStatements().stream().filter(statement -> statement instanceof J.MethodDeclaration).map(J.MethodDeclaration.class::cast).collect(Collectors.toList()));
            JavaCoordinates tearDownCoordinates = firstTestMethod != null ? firstTestMethod.getCoordinates().before() : classDecl.getBody().getCoordinates().lastStatement();
            classDecl = classDecl.withBody((J.Block)classDecl.getBody().withTemplate((SourceTemplate)JavaTemplate.builder(() -> this.getCursor().getParentTreeCursor(), (String)(methodAnnotationToAdd + " void " + methodName + "() {}")).javaParser(() -> JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{additionalClasspathResource}).build()).imports(new String[]{importToAdd}).build(), tearDownCoordinates, new Object[0]));
            this.maybeAddImport(importToAdd);
            return classDecl;
        }

        @NotNull
        private J.ClassDeclaration maybeAddTearDownMethodBody(J.ClassDeclaration classDecl, ExecutionContext ctx) {
            return this.maybeAddMethodWithAnnotation(classDecl, ctx, "tearDown", this.testFrameworkInfo.tearDownMethodAnnotationSignature, this.testFrameworkInfo.tearDownMethodAnnotation, this.testFrameworkInfo.additionalClasspathResource, this.testFrameworkInfo.tearDownImportToAdd);
        }

        @Nullable
        private static J.MethodDeclaration getFirstTestMethod(List<J.MethodDeclaration> methods) {
            for (J.MethodDeclaration methodDeclaration : methods) {
                for (J.Annotation annotation : methodDeclaration.getLeadingAnnotations()) {
                    if (!annotation.getSimpleName().equals("Test")) continue;
                    return methodDeclaration;
                }
            }
            return null;
        }

        private static boolean containsTestNgTestMethods(List<J.MethodDeclaration> methods) {
            for (J.MethodDeclaration methodDeclaration : methods) {
                for (J.Annotation annotation : methodDeclaration.getAllAnnotations()) {
                    JavaType annotationType = annotation.getAnnotationType().getType();
                    if (!(annotationType instanceof JavaType.Class) || !((JavaType.Class)annotationType).getFullyQualifiedName().equals("org.testng.annotations.Test")) continue;
                    return true;
                }
            }
            return false;
        }

        private static boolean hasMethodWithAnnotation(J.ClassDeclaration classDecl, AnnotationMatcher annotationMatcher) {
            for (Statement statement : classDecl.getBody().getStatements()) {
                if (!(statement instanceof J.MethodDeclaration)) continue;
                J.MethodDeclaration methodDeclaration = (J.MethodDeclaration)statement;
                if (!methodDeclaration.getAllAnnotations().stream().anyMatch(arg_0 -> ((AnnotationMatcher)annotationMatcher).matches(arg_0))) continue;
                return true;
            }
            return false;
        }

        private static List<Expression> getMockedTypesFromPrepareForTestAnnotation(List<J.Annotation> prepareForTestAnnotations) {
            ArrayList<Expression> mockedTypes = new ArrayList<Expression>();
            for (J.Annotation prepareForTest : prepareForTestAnnotations) {
                if (prepareForTest == null || prepareForTest.getArguments() == null) continue;
                mockedTypes.addAll(ListUtils.flatMap((List)prepareForTest.getArguments(), a -> {
                    if (a instanceof J.NewArray && ((J.NewArray)a).getInitializer() != null) {
                        return ((J.NewArray)a).getInitializer();
                    }
                    if (a instanceof J.Assignment && ((J.NewArray)((J.Assignment)a).getAssignment()).getInitializer() != null) {
                        return ((J.NewArray)((J.Assignment)a).getAssignment()).getInitializer();
                    }
                    return null;
                }));
            }
            return mockedTypes;
        }

        public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
            J.MethodDeclaration m = (J.MethodDeclaration)super.visitMethodDeclaration(method, (Object)ctx);
            AnnotationMatcher tearDownAnnotationMatcher = new AnnotationMatcher(this.testFrameworkInfo.tearDownMethodAnnotationSignature);
            if (m.getAllAnnotations().stream().anyMatch(arg_0 -> ((AnnotationMatcher)tearDownAnnotationMatcher).matches(arg_0))) {
                List<J.Identifier> mockedTypesIdentifiers = (List<J.Identifier>)this.getCursor().pollNearestMessage(MOCKED_TYPES_FIELDS);
                if (mockedTypesIdentifiers == null) {
                    mockedTypesIdentifiers = this.getMockedTypesFields();
                }
                if (mockedTypesIdentifiers != null) {
                    for (J.Identifier mockedTypesField : mockedTypesIdentifiers) {
                        J.Block methodBody = m.getBody();
                        if (methodBody == null || PowerMockitoToMockitoVisitor.isStaticMockAlreadyClosed(mockedTypesField, methodBody)) continue;
                        m = m.withBody((J.Block)methodBody.withTemplate((SourceTemplate)JavaTemplate.builder(() -> this.getCursor().getParentTreeCursor(), (String)"#{any(org.mockito.MockedStatic)}.close();").javaParser(() -> JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"mockito-core-3.*"}).build()).build(), methodBody.getCoordinates().lastStatement(), new Object[]{mockedTypesField}));
                    }
                }
                this.setMockedTypesFields(mockedTypesIdentifiers);
                return m;
            }
            AnnotationMatcher setUpAnnotationMatcher = new AnnotationMatcher(this.testFrameworkInfo.setUpMethodAnnotationSignature);
            if (m.getAllAnnotations().stream().anyMatch(arg_0 -> ((AnnotationMatcher)setUpAnnotationMatcher).matches(arg_0))) {
                List<J.Identifier> mockedTypesIdentifiers = (List<J.Identifier>)this.getCursor().pollNearestMessage(MOCKED_TYPES_FIELDS);
                if (mockedTypesIdentifiers == null) {
                    mockedTypesIdentifiers = this.getMockedTypesFields();
                }
                if (mockedTypesIdentifiers != null) {
                    for (J.Identifier mockedTypesField : mockedTypesIdentifiers) {
                        J.Block methodBody = m.getBody();
                        if (methodBody == null || PowerMockitoToMockitoVisitor.isStaticMockAlreadyOpened(mockedTypesField, methodBody)) continue;
                        String className = ((JavaType.Class)((JavaType.Parameterized)mockedTypesField.getType()).getTypeParameters().get(0)).getClassName();
                        m = m.withBody((J.Block)methodBody.withTemplate((SourceTemplate)JavaTemplate.builder(() -> this.getCursor().getParentTreeCursor(), (String)"mocked#{any(org.mockito.MockedStatic)} = mockStatic(#{}.class);").javaParser(() -> JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"mockito-core-3.*"}).build()).build(), methodBody.getCoordinates().lastStatement(), new Object[]{mockedTypesField, className}));
                    }
                }
                this.setMockedTypesFields(mockedTypesIdentifiers);
                return m;
            }
            return m;
        }

        private void setMockedTypesFields(List<J.Identifier> mockedTypesFields) {
            this.mockedTypesFields = mockedTypesFields;
        }

        private static boolean isStaticMockAlreadyClosed(J.Identifier staticMock, J.Block methodBody) {
            return methodBody.getStatements().stream().filter(statement -> statement instanceof J.MethodInvocation).map(J.MethodInvocation.class::cast).filter(arg_0 -> ((MethodMatcher)MOCKED_STATIC_CLOSE_MATCHER).matches(arg_0)).filter(methodInvocation -> methodInvocation.getSelect() instanceof J.Identifier).anyMatch(methodInvocation -> ((J.Identifier)methodInvocation.getSelect()).getSimpleName().equals(staticMock.getSimpleName()));
        }

        private static boolean isStaticMockAlreadyOpened(J.Identifier staticMock, J.Block methodBody) {
            return methodBody.getStatements().stream().filter(statement -> statement instanceof J.MethodInvocation).map(J.MethodInvocation.class::cast).filter(arg_0 -> ((MethodMatcher)MOCKED_STATIC_MATCHER).matches(arg_0)).filter(methodInvocation -> methodInvocation.getSelect() instanceof J.Identifier).anyMatch(methodInvocation -> ((J.Identifier)methodInvocation.getSelect()).getSimpleName().equals(staticMock.getSimpleName()));
        }

        public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
            if (MOCKITO_WHEN_MATCHER.matches(method) || MOCKITO_VERIFY_MATCHER.matches(method)) {
                method = this.modifyWhenMethodInvocation(method);
            } else if (MOCKED_STATIC_MATCHER.matches(method)) {
                J.Assignment assignment = (J.Assignment)this.getCursor().firstEnclosing(J.Assignment.class);
                if (assignment != null) {
                    return super.visitMethodInvocation(method, (Object)ctx);
                }
                return null;
            }
            return super.visitMethodInvocation(method, (Object)ctx);
        }

        @NotNull
        private J.MethodInvocation modifyWhenMethodInvocation(J.MethodInvocation whenMethod) {
            List methodArguments = whenMethod.getArguments();
            List staticMethodInvocationsInArguments = methodArguments.stream().filter(expression -> expression instanceof J.MethodInvocation).map(J.MethodInvocation.class::cast).filter(methodInvocation -> !MOCKITO_STATIC_METHOD_MATCHER.matches(methodInvocation)).filter(methodInvocation -> methodInvocation.getMethodType() != null).filter(methodInvocation -> methodInvocation.getMethodType().hasFlags(new Flag[]{Flag.Static})).collect(Collectors.toList());
            if (staticMethodInvocationsInArguments.size() == 1) {
                JavaType.Method methodType;
                J.MethodInvocation staticMI = (J.MethodInvocation)staticMethodInvocationsInArguments.get(0);
                String declaringClassName = this.getDeclaringClassName(staticMI);
                Object lambdaInvocation = staticMI.getArguments().stream().map(Expression::getType).noneMatch(Objects::nonNull) ? (Expression)staticMI.withTemplate((SourceTemplate)JavaTemplate.builder(() -> ((PowerMockitoToMockitoVisitor)this).getCursor(), (String)(declaringClassName + "::" + staticMI.getSimpleName())).build(), staticMI.getCoordinates().replace(), new Object[0]) : ((methodType = staticMI.getMethodType()) != null ? (Expression)staticMI.withTemplate((SourceTemplate)JavaTemplate.builder(() -> ((PowerMockitoToMockitoVisitor)this).getCursor(), (String)("() -> #{any(" + methodType.getReturnType() + ")}")).build(), staticMI.getCoordinates().replace(), new Object[]{staticMI}) : staticMI);
                if (Collections.replaceAll(methodArguments, staticMI, lambdaInvocation)) {
                    J.Identifier mockedField = this.getFieldIdentifier("mocked" + declaringClassName);
                    whenMethod = whenMethod.withSelect((Expression)mockedField);
                    whenMethod = whenMethod.withArguments(methodArguments);
                }
            }
            return whenMethod;
        }

        @Nullable
        private String getDeclaringClassName(J.MethodInvocation mi) {
            JavaType.Method methodType = mi.getMethodType();
            if (methodType != null) {
                JavaType.FullyQualified declaringType = methodType.getDeclaringType();
                return declaringType.getClassName();
            }
            return null;
        }

        @Nullable
        private J.Identifier getFieldIdentifier(String name) {
            Optional<J.Identifier> optionalFieldIdentifier = this.getMockedTypesFields().stream().filter(identifier -> identifier.getSimpleName().equals(name)).findFirst();
            if (optionalFieldIdentifier.isPresent()) {
                return optionalFieldIdentifier.get();
            }
            J.ClassDeclaration cd = (J.ClassDeclaration)this.getCursor().dropParentUntil(it -> it instanceof J.ClassDeclaration).getValue();
            List collect = cd.getBody().getStatements().stream().filter(statement -> statement instanceof J.VariableDeclarations).map(variableDeclarations -> ((J.VariableDeclarations)variableDeclarations).getVariables()).collect(Collectors.toList());
            for (List namedVariables : collect) {
                for (J.VariableDeclarations.NamedVariable namedVariable : namedVariables) {
                    if (!namedVariable.getSimpleName().equals(name)) continue;
                    return namedVariable.getName();
                }
            }
            return null;
        }

        private static class TestFrameworkInfo {
            private final String setUpMethodAnnotationSignature;
            private final String setUpMethodAnnotation;
            private final String tearDownMethodAnnotationSignature;
            private final String tearDownMethodAnnotation;
            private final String additionalClasspathResource;
            private final String setUpImportToAdd;
            private final String tearDownImportToAdd;

            public TestFrameworkInfo(String setUpMethodAnnotationName, String tearDownMethodAnnotationName, String annotationPackage, String additionalClasspathResource) {
                this.setUpMethodAnnotation = "@" + setUpMethodAnnotationName;
                this.tearDownMethodAnnotation = "@" + tearDownMethodAnnotationName;
                this.setUpMethodAnnotationSignature = "@" + annotationPackage + "." + setUpMethodAnnotationName;
                this.tearDownMethodAnnotationSignature = "@" + annotationPackage + "." + tearDownMethodAnnotationName;
                this.setUpImportToAdd = annotationPackage + "." + setUpMethodAnnotationName;
                this.tearDownImportToAdd = annotationPackage + "." + tearDownMethodAnnotationName;
                this.additionalClasspathResource = additionalClasspathResource;
            }
        }
    }
}

