/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.testing.mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.AnnotationMatcher;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.TypeMatcher;
import org.openrewrite.java.VariableNameUtils;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Rewrites {@code Mockito.when(SomeType.staticMethod()).thenReturn(value)} stubbings of real
 * (non-mocked) static methods into explicit {@code MockedStatic} usage.
 *
 * <p>Inside a method carrying an {@code org..Before*} annotation the stubbing becomes an assignment
 * to a new {@code MockedStatic} field, and {@code close()} on that field is appended to the matching
 * {@code @After*} method of the same class (that method is created next to the set-up method when it
 * does not exist yet). Everywhere else the stubbing and all statements following it in the same block
 * are wrapped in a try-with-resources that opens the {@code MockedStatic}. A {@code MockedStatic} of
 * the same type that is opened by an enclosing try, declared in an enclosing block, or opened earlier
 * in the same block is reused instead of opening a second one.
 */
public class MockitoWhenOnStaticToMockStatic extends Recipe {
    private static final AnnotationMatcher JUNIT_4_ANNOTATION = new AnnotationMatcher("org.junit.*");
    private static final AnnotationMatcher JUNIT_5_ANNOTATION = new AnnotationMatcher("org.junit.jupiter.api.*");
    private static final AnnotationMatcher TESTNG_ANNOTATION = new AnnotationMatcher("org.testng.annotations.*");
    private static final AnnotationMatcher BEFORE = new AnnotationMatcher("org..Before*");
    private static final AnnotationMatcher BEFORE_CLASS = new AnnotationMatcher("org..BeforeClass");
    private static final AnnotationMatcher BEFORE_ALL = new AnnotationMatcher("org..BeforeAll");
    private static final AnnotationMatcher BEFORE_PARAM_CLASS_INV = new AnnotationMatcher("org..BeforeParameterizedClassInvocation");
    private static final MethodMatcher MOCKITO_WHEN = new MethodMatcher("org.mockito.Mockito when(..)");
    private static final TypeMatcher MOCKED_STATIC = new TypeMatcher("org.mockito.MockedStatic");
    /** Base name of a generated {@code @After*} method when the class has none. */
    private static final String DEFAULT_AFTER_METHOD = "tearDown";
    final String displayName = "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic";
    final String description = "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic as Mockito4 no longer allows this. For JUnit 4/5 & TestNG: When `@Before*` is used, a `close` call is added to the corresponding `@After*` method. This change moves away from implicit bytecode manipulation for static method stubbing, making mocking behavior more explicit and scoped to avoid unintended side effects.";

    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor<ExecutionContext>() {
            /**
             * Suffix counter for generated variable names. It lives on the visitor so that the recipe
             * instance itself carries no mutable state between runs.
             */
            private int varCounter;

            /**
             * Converts the stubbings of this block. Blocks inside an {@code org..Before*} method get
             * field-based mocks; all other blocks get try-with-resources.
             */
            @Override
            public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
                J.MethodDeclaration containingMethod = getCursor().firstEnclosing(J.MethodDeclaration.class);
                List<Statement> newStatements = isMethodDeclarationWithAnnotation(containingMethod, BEFORE) ?
                        maybeStatementsToMockedStatic(block, block.getStatements(), ctx) :
                        maybeWrapStatementsInTryWithResourcesMockedStatic(block, block.getStatements(), ctx);
                J.Block b = super.visitBlock(block.withStatements(newStatements), ctx);
                return maybeAutoFormat(block, b, ctx);
            }

            /**
             * Set-up strategy: replaces each convertible stubbing with an assignment to a new
             * {@code MockedStatic} field plus a {@code .when(..).thenReturn(..)} call on it. A second
             * stubbing of an already mocked class in the same block reuses that field, because opening
             * a second static mock of the same class would be rejected by Mockito.
             */
            private List<Statement> maybeStatementsToMockedStatic(J.Block m, List<Statement> statements, ExecutionContext ctx) {
                List<Statement> list = new ArrayList<>();
                // fully qualified name of the mocked type -> name of the field opened for it in this block
                Map<String, String> openedMocks = new HashMap<>();
                for (Statement statement : statements) {
                    J.MethodInvocation whenArg = getWhenArg(statement);
                    JavaType.@Nullable Class invokedType = whenArg == null ? null : getTypeFromInvocation(whenArg);
                    if (whenArg == null || invokedType == null) {
                        // Not a convertible stubbing, or its target type is unknown: keep it unchanged.
                        list.add(statement);
                        continue;
                    }
                    J.MethodInvocation stubbing = (J.MethodInvocation) statement;
                    String opened = openedMocks.get(invokedType.getFullyQualifiedName());
                    if (opened != null) {
                        list.add(reuseMockedStatic(m, stubbing, opened, whenArg, ctx));
                        continue;
                    }
                    String variableName = newMockVariableName(m, invokedType.getClassName());
                    openedMocks.put(invokedType.getFullyQualifiedName(), variableName);
                    list.addAll(mockedStatic(m, stubbing, invokedType.getClassName(), variableName, whenArg, ctx));
                }
                return list;
            }

            /**
             * Test-body strategy: the first stubbing whose {@code MockedStatic} cannot be reused opens a
             * try-with-resources that absorbs that stubbing and every statement after it.
             */
            private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Block block, List<Statement> statements, ExecutionContext ctx) {
                return maybeWrapStatementsInTryWithResourcesMockedStatic(block, statements, new HashMap<>(), ctx);
            }

            /**
             * @param openedMocks fully qualified type name -> resource name of the {@code MockedStatic}
             *                    opened by tries this recipe already generated around {@code statements}
             *                    (those tries are not yet part of the cursor path)
             */
            private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Block block, List<Statement> statements,
                                                                                       Map<String, String> openedMocks, ExecutionContext ctx) {
                AtomicBoolean restInTry = new AtomicBoolean(false);
                return ListUtils.map(statements, (index, statement) -> {
                    if (restInTry.get()) {
                        // Already moved into the body of the try created for an earlier statement.
                        return null;
                    }
                    J.MethodInvocation whenArg = getWhenArg(statement);
                    JavaType.@Nullable Class invokedType = whenArg == null ? null : getTypeFromInvocation(whenArg);
                    if (whenArg == null || invokedType == null) {
                        return statement;
                    }
                    J.MethodInvocation stubbing = (J.MethodInvocation) statement;
                    String opened = openedMocks.get(invokedType.getFullyQualifiedName());
                    if (opened != null) {
                        return reuseMockedStatic(block, stubbing, opened, whenArg, ctx);
                    }
                    Optional<String> wrappingResource = tryGetMatchedWrappingResourceName(getCursor(), invokedType);
                    if (wrappingResource.isPresent()) {
                        return reuseMockedStatic(block, stubbing, wrappingResource.get(), whenArg, ctx);
                    }
                    J.Identifier visibleVariable = findMockedStaticVariable(getCursor(), invokedType);
                    if (visibleVariable != null) {
                        return reuseMockedStatic(block, stubbing, visibleVariable, whenArg, ctx);
                    }
                    restInTry.set(true);
                    return tryWithMockedStatic(block, statements, index, stubbing, invokedType, whenArg, openedMocks, ctx);
                });
            }

            /**
             * Returns the static invocation passed to {@code Mockito.when(..)} when {@code statement} has
             * the shape {@code when(Type.staticCall()).thenReturn(value)}, otherwise {@code null}.
             * Only a single-value {@code thenReturn} qualifies: the generated code always re-emits
             * {@code .thenReturn(<first argument>)}, so {@code thenThrow}/{@code thenAnswer} stubbings or
             * additional return values would be silently changed.
             */
            private J.@Nullable MethodInvocation getWhenArg(Statement statement) {
                if (!(statement instanceof J.MethodInvocation)) {
                    return null;
                }
                J.MethodInvocation stubbing = (J.MethodInvocation) statement;
                if (!"thenReturn".equals(stubbing.getSimpleName()) ||
                        stubbing.getArguments().size() != 1 ||
                        stubbing.getArguments().get(0) instanceof J.Empty) {
                    return null;
                }
                if (!(stubbing.getSelect() instanceof J.MethodInvocation) || !MOCKITO_WHEN.matches(stubbing.getSelect())) {
                    return null;
                }
                J.MethodInvocation when = (J.MethodInvocation) stubbing.getSelect();
                if (when.getArguments().isEmpty() || !(when.getArguments().get(0) instanceof J.MethodInvocation)) {
                    return null;
                }
                J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0);
                JavaType.Method methodType = whenArg.getMethodType();
                return methodType != null && methodType.hasFlags(Flag.Static) ? whenArg : null;
            }

            /**
             * Resolves the class named in the select of the static call ({@code Type.call()} or the
             * target of {@code X.Type.call()}), or {@code null} when it is absent or not a class type.
             */
            private JavaType.@Nullable Class getTypeFromInvocation(J.MethodInvocation whenArg) {
                Expression select = whenArg.getSelect();
                J.Identifier clazz = null;
                if (select instanceof J.Identifier && ((J.Identifier) select).getFieldType() == null) {
                    // an identifier without a field type names a type, not a variable
                    clazz = (J.Identifier) select;
                } else if (select instanceof J.FieldAccess && ((J.FieldAccess) select).getTarget() instanceof J.Identifier) {
                    clazz = (J.Identifier) ((J.FieldAccess) select).getTarget();
                }
                return clazz == null ? null : TypeUtils.asClass(clazz.getType());
            }

            /**
             * Generates a fresh variable name for a {@code MockedStatic} of {@code className}, made
             * unique within {@code scope}. Dots (present in nested class names) are stripped so the
             * result is a valid identifier.
             */
            private String newMockVariableName(J.Block scope, String className) {
                String baseName = "mock" + className.replace(".", "") + ++varCounter;
                return VariableNameUtils.generateVariableName(baseName, updateCursor(scope),
                        VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
            }

            /**
             * Builds {@code try (MockedStatic<T> mock = mockStatic(T.class)) { mock.when(..).thenReturn(..); rest }}
             * where {@code rest} is every statement after {@code index}, itself processed recursively so
             * later stubbings of {@code T} reuse the new resource.
             */
            private J.Try tryWithMockedStatic(J.Block block, List<Statement> statements, int index, J.MethodInvocation statement,
                                              JavaType.Class invokedType, J.MethodInvocation whenArg,
                                              Map<String, String> openedMocks, ExecutionContext ctx) {
                String className = invokedType.getClassName();
                String variableName = newMockVariableName(block, className);
                Expression thenReturnArg = statement.getArguments().get(0);
                J.Block applied = javaTemplateMockStatic(String.format(
                        "try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n    %2$s.when(() -> #{any()}).thenReturn(#{any()});\n}",
                        className, variableName), ctx)
                        .apply(getCursor(), block.getCoordinates().firstStatement(), whenArg, thenReturnArg);
                J.Try try_ = (J.Try) applied.getStatements().get(0);
                List<Statement> handledStatements = ListUtils.concat(statements.subList(0, index), try_);
                List<Statement> remainingStatements = statements.subList(index + 1, statements.size());
                // Everything below lives inside this try, so its resource is in scope for reuse there.
                Map<String, String> openedInBody = new HashMap<>(openedMocks);
                openedInBody.put(invokedType.getFullyQualifiedName(), variableName);
                List<Statement> body = ListUtils.concatAll(try_.getBody().getStatements(),
                        maybeWrapStatementsInTryWithResourcesMockedStatic(block.withStatements(handledStatements), remainingStatements, openedInBody, ctx));
                return try_.withBody(try_.getBody().withStatements(body)).withPrefix(statement.getPrefix());
            }

            /**
             * Emits {@code <variable>.when(() -> whenArg).thenReturn(value);} against an existing
             * {@code MockedStatic}. A {@link J} (identifier) is substituted as a typed expression; any
             * other value (a name) is spliced into the template verbatim.
             */
            private Statement reuseMockedStatic(J.Block block, J.MethodInvocation statement, Object variable, J.MethodInvocation whenArg, ExecutionContext ctx) {
                String mockedStaticVariableTemplate = variable instanceof J ? "#{any()}" : "#{}";
                J.Block applied = javaTemplateMockStatic(mockedStaticVariableTemplate + ".when(() -> #{any()}).thenReturn(#{any()});", ctx)
                        .apply(getCursor(), block.getCoordinates().firstStatement(), variable, whenArg, statement.getArguments().get(0));
                return applied.getStatements().get(0);
            }

            /**
             * Returns {@code variableName = mockStatic(T.class); variableName.when(..).thenReturn(..);}
             * and schedules a follow-up visit that declares the field and closes it in the
             * corresponding {@code @After*} method of the class owning the set-up method.
             */
            private List<Statement> mockedStatic(J.Block block, J.MethodInvocation statement, String className, String variableName,
                                                 J.MethodInvocation whenArg, ExecutionContext ctx) {
                J.MethodDeclaration containingMethod = getCursor().firstEnclosing(J.MethodDeclaration.class);
                J.@Nullable ClassDeclaration targetClass = getCursor().firstEnclosing(J.ClassDeclaration.class);
                boolean staticSetup = isMethodDeclarationWithAnnotation(containingMethod, BEFORE_CLASS, BEFORE_ALL, BEFORE_PARAM_CLASS_INV);
                String matchedAnnotation = Objects.requireNonNull(tryGetMatchedAnnotationOnMethodDeclaration(containingMethod, BEFORE));
                String correspondingAfterFqn = matchedAnnotation.replace(".Before", ".After");
                Expression thenReturnArg = statement.getArguments().get(0);
                J.Block applied = javaTemplateMockStatic(String.format(
                        "%2$s = mockStatic(%1$s.class);\n%2$s.when(() -> #{any()}).thenReturn(#{any()});",
                        className, variableName), ctx)
                        .apply(getCursor(), block.getCoordinates().firstStatement(), whenArg, thenReturnArg);
                List<Statement> statements = applied.getStatements().subList(0, 2);
                doAfterVisit(new JavaIsoVisitor<ExecutionContext>() {
                    private final AnnotationMatcher afterMatcher = new AnnotationMatcher(correspondingAfterFqn);

                    /**
                     * Only the class declaring the set-up method gets the field and the close() call;
                     * nested or sibling classes are left alone. Falls back to every class when the
                     * owning class could not be determined.
                     */
                    private boolean isTargetClass(J.@Nullable ClassDeclaration classDecl) {
                        return targetClass == null || (classDecl != null && targetClass.getId().equals(classDecl.getId()));
                    }

                    @Override
                    public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
                        if (!isTargetClass(classDecl)) {
                            return super.visitClassDeclaration(classDecl, ctx);
                        }
                        J.ClassDeclaration after = JavaTemplate.builder(String.format("private%s MockedStatic<%s> %s;", staticSetup ? " static" : "", className, variableName))
                                .contextSensitive()
                                .build()
                                .apply(updateCursor(classDecl), classDecl.getBody().getCoordinates().firstStatement());
                        List<Statement> afterStatements = after.getBody().getStatements();
                        AnnotationMatcher specificBeforeMatcher = new AnnotationMatcher(matchedAnnotation);
                        if (classDecl.getBody().getStatements().stream().noneMatch(it -> isMethodDeclarationWithAnnotation(it, afterMatcher))) {
                            // No @After* method yet: create one right after the matching @Before* method.
                            String safeAfterMethodName = getSafeAfterMethodName(DEFAULT_AFTER_METHOD, afterStatements);
                            Optional<Statement> beforeMethodJunit4 = afterStatements.stream().filter(it -> isMethodDeclarationWithAllAnnotations(it, JUNIT_4_ANNOTATION, specificBeforeMatcher)).findFirst();
                            Optional<Statement> beforeMethodJunit5 = afterStatements.stream().filter(it -> isMethodDeclarationWithAllAnnotations(it, JUNIT_5_ANNOTATION, specificBeforeMatcher)).findFirst();
                            Optional<Statement> beforeMethodTestng = afterStatements.stream().filter(it -> isMethodDeclarationWithAllAnnotations(it, TESTNG_ANNOTATION, specificBeforeMatcher)).findFirst();
                            String afterAnnotationName = correspondingAfterFqn.substring(correspondingAfterFqn.lastIndexOf('.') + 1);
                            String template = String.format("@%1$s public%2$s void %3$s() {}", afterAnnotationName, staticSetup ? " static" : "", safeAfterMethodName);
                            if (beforeMethodJunit4.isPresent()) {
                                after = writeAfterMethod(after, beforeMethodJunit4.get(), ctx, template, correspondingAfterFqn, "junit-4");
                            } else if (beforeMethodJunit5.isPresent()) {
                                after = writeAfterMethod(after, beforeMethodJunit5.get(), ctx, template, correspondingAfterFqn, "junit-jupiter-api-5");
                            } else if (beforeMethodTestng.isPresent()) {
                                after = writeAfterMethod(after, beforeMethodTestng.get(), ctx, template, correspondingAfterFqn, "testng");
                            }
                        }
                        J.ClassDeclaration cd = super.visitClassDeclaration(after, ctx);
                        return maybeAutoFormat(classDecl, cd, ctx);
                    }

                    /** Inserts {@code template} (an empty @After* method) directly after {@code beforeMethod}. */
                    private J.ClassDeclaration writeAfterMethod(J.ClassDeclaration after, Statement beforeMethod, ExecutionContext ctx,
                                                                String template, String importClass, String... classpaths) {
                        maybeAddImport(importClass);
                        return JavaTemplate.builder(template)
                                .imports(importClass)
                                .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, classpaths))
                                .build()
                                .apply(updateCursor(after), beforeMethod.getCoordinates().after());
                    }

                    /** Appends {@code variableName.close();} to the @After* method(s) of the target class. */
                    @Override
                    public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl, ExecutionContext ctx) {
                        J.MethodDeclaration md = super.visitMethodDeclaration(methodDecl, ctx);
                        J.ClassDeclaration owner = getCursor().firstEnclosing(J.ClassDeclaration.class);
                        if (md.getBody() == null || !isTargetClass(owner) || !isMethodDeclarationWithAnnotation(md, afterMatcher)) {
                            return md;
                        }
                        return JavaTemplate.builder(variableName + ".close();")
                                .contextSensitive()
                                .build()
                                .apply(getCursor(), md.getBody().getCoordinates().lastStatement());
                    }
                });
                return statements;
            }

            /** Template builder that also schedules the MockedStatic / mockStatic imports. */
            private JavaTemplate javaTemplateMockStatic(String code, ExecutionContext ctx) {
                maybeAddImport("org.mockito.MockedStatic", false);
                maybeAddImport("org.mockito.Mockito", "mockStatic");
                return JavaTemplate.builder(code)
                        .contextSensitive()
                        .imports("org.mockito.MockedStatic")
                        .staticImports("org.mockito.Mockito.mockStatic")
                        .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-5"))
                        .build();
            }
        });
    }

    /**
     * Returns the try resources declaring a {@code MockedStatic} of {@code className}. Resources that
     * are not variable declarations (e.g. a Java 9 {@code try (existing)} resource) are skipped
     * instead of failing a cast.
     */
    private static List<J.Try.Resource> getMatchingFilteredResources(@Nullable List<J.Try.Resource> resources, JavaType className) {
        if (resources == null) {
            return Collections.emptyList();
        }
        return ListUtils.filter(resources, res -> res.getVariableDeclarations() instanceof J.VariableDeclarations &&
                isMockedStaticOfType(className, ((J.VariableDeclarations) res.getVariableDeclarations()).getTypeAsFullyQualified()));
    }

    /** True when {@code comparisonType} is {@code MockedStatic<X>} with {@code X} assignable to {@code mockedType}. */
    private static boolean isMockedStaticOfType(JavaType mockedType, @Nullable JavaType comparisonType) {
        if (comparisonType == null || !MOCKED_STATIC.matches(comparisonType)) {
            return false;
        }
        JavaType.Parameterized parameterizedType = TypeUtils.asParameterized(comparisonType);
        return parameterizedType != null &&
                parameterizedType.getTypeParameters().size() == 1 &&
                TypeUtils.isAssignableTo(mockedType, parameterizedType.getTypeParameters().get(0));
    }

    /**
     * Finds the nearest enclosing try (strictly above {@code cursor}) that declares a
     * {@code MockedStatic} of {@code className} and returns that resource's variable name.
     */
    private static Optional<String> tryGetMatchedWrappingResourceName(Cursor cursor, JavaType className) {
        Cursor parent = cursor.getParent();
        if (parent == null) {
            return Optional.empty();
        }
        return parent.getPathAsStream()
                .filter(J.Try.class::isInstance)
                .map(J.Try.class::cast)
                .flatMap(t -> getMatchingFilteredResources(t.getResources(), className).stream())
                .findFirst()
                .map(res -> ((J.VariableDeclarations) res.getVariableDeclarations()).getVariables().get(0).getSimpleName());
    }

    /** True when {@code statement} is a method declaration with an annotation matched by any of {@code matchers}. */
    private static boolean isMethodDeclarationWithAnnotation(@Nullable Statement statement, AnnotationMatcher... matchers) {
        if (!(statement instanceof J.MethodDeclaration)) {
            return false;
        }
        return ((J.MethodDeclaration) statement).getLeadingAnnotations().stream()
                .anyMatch(annotation -> Arrays.stream(matchers).anyMatch(m -> m.matches(annotation)));
    }

    /** True when {@code statement} is a method declaration with one annotation matched by all {@code matchers}. */
    private static boolean isMethodDeclarationWithAllAnnotations(@Nullable Statement statement, AnnotationMatcher... matchers) {
        if (!(statement instanceof J.MethodDeclaration)) {
            return false;
        }
        return ((J.MethodDeclaration) statement).getLeadingAnnotations().stream()
                .anyMatch(annotation -> Arrays.stream(matchers).allMatch(m -> m.matches(annotation)));
    }

    /** Returns the type (as a string) of the first annotation on {@code methodDecl} matched by any of {@code matchers}. */
    private static @Nullable String tryGetMatchedAnnotationOnMethodDeclaration(J.@Nullable MethodDeclaration methodDecl, AnnotationMatcher... matchers) {
        if (methodDecl == null) {
            return null;
        }
        return methodDecl.getLeadingAnnotations().stream()
                .filter(annotation -> Arrays.stream(matchers).anyMatch(m -> m.matches(annotation)))
                .findFirst()
                .map(J.Annotation::getType)
                .map(Object::toString)
                .orElse(null);
    }

    /**
     * Returns {@code baseName} if no method named {@code baseName} or {@code baseName<digits>} exists;
     * otherwise {@code baseName} followed by one more than the highest numeric suffix in use
     * ({@code baseName} alone counts as suffix 0).
     */
    private static String getSafeAfterMethodName(String baseName, List<Statement> existingStatements) {
        long highestSuffix = -1;
        for (Statement statement : existingStatements) {
            if (!(statement instanceof J.MethodDeclaration)) {
                continue;
            }
            String name = ((J.MethodDeclaration) statement).getSimpleName();
            if (!name.startsWith(baseName)) {
                continue;
            }
            String digits = name.substring(baseName.length());
            if (digits.isEmpty()) {
                highestSuffix = Math.max(highestSuffix, 0);
            } else if (digits.length() <= 18 && digits.chars().allMatch(c -> c >= '0' && c <= '9')) {
                // bounded length so parsing (and the +1 below) cannot overflow a long
                highestSuffix = Math.max(highestSuffix, Long.parseLong(digits));
            }
        }
        return highestSuffix < 0 ? baseName : baseName + (highestSuffix + 1);
    }

    /**
     * Searches the blocks on {@code scope}'s path for a variable typed {@code MockedStatic} of
     * {@code className}; returns the last one visited, or {@code null}.
     */
    private static J.@Nullable Identifier findMockedStaticVariable(final Cursor scope, final JavaType className) {
        JavaSourceFile compilationUnit = scope.firstEnclosing(JavaSourceFile.class);
        if (compilationUnit == null) {
            return null;
        }
        return new JavaIsoVisitor<AtomicReference<J.Identifier>>() {
            @Override
            public J.Block visitBlock(J.Block block, AtomicReference<J.Identifier> mockedStaticVar) {
                // only descend into blocks that enclose the current position
                return scope.isScopeInPath(block) ? super.visitBlock(block, mockedStaticVar) : block;
            }

            @Override
            public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, AtomicReference<J.Identifier> mockedStaticVar) {
                J.Identifier identifier = variable.getName();
                if (isMockedStaticOfType(className, identifier.getType())) {
                    mockedStaticVar.set(identifier);
                }
                return super.visitVariable(variable, mockedStaticVar);
            }
        }.reduce(compilationUnit, new AtomicReference<>()).get();
    }

    @Override
    public String getDisplayName() {
        return this.displayName;
    }

    @Override
    public String getDescription() {
        return this.description;
    }
}

