/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.staticanalysis;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import lombok.Generated;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Option;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.Validated;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.AnnotationMatcher;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.RemoveAnnotationVisitor;
import org.openrewrite.java.ShortenFullyQualifiedTypeReferences;
import org.openrewrite.java.search.FindAnnotations;
import org.openrewrite.java.search.SemanticallyEqual;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.staticanalysis.SimplifyBooleanExpression;
import org.openrewrite.staticanalysis.SimplifyConstantIfBranchExecution;
import org.openrewrite.staticanalysis.java.MoveFieldAnnotationToType;

public final class AnnotateRequiredParameters
extends Recipe {
    private static final String DEFAULT_NONNULL_ANN_CLASS = "org.jspecify.annotations.NonNull";
    @Option(displayName="`@NonNull` annotation class", description="The fully qualified name of the @NonNull annotation. The annotation should be meta annotated with `@Target(TYPE_USE)`. Defaults to `org.jspecify.annotations.NonNull`", example="org.jspecify.annotations.NonNull", required=false)
    private final @Nullable String nonNullAnnotationClass;
    private final String displayName = "Annotate required method parameters with `@NonNull`";
    private final String description = "Add `@NonNull` to parameters of public methods that are explicitly checked for `null` and throw an exception if null. By default `org.jspecify.annotations.NonNull` is used, but through the `nonNullAnnotationClass` option a custom annotation can be provided. When providing a custom `nonNullAnnotationClass` that annotation should be meta annotated with `@Target(TYPE_USE)`. This recipe scans for methods that do not already have parameters annotated with `@NonNull` annotation and checks for null validation patterns that throw exceptions, such as `if (param == null) throw new IllegalArgumentException()`.";

    public Validated<Object> validate() {
        return super.validate().and(Validated.test((String)"nonNullAnnotationClass", (String)"Property `nonNullAnnotationClass` must be a fully qualified classname.", (Object)this.nonNullAnnotationClass, it -> it == null || it.contains(".")));
    }

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        final String fullyQualifiedName = this.nonNullAnnotationClass != null ? this.nonNullAnnotationClass : DEFAULT_NONNULL_ANN_CLASS;
        final String fullyQualifiedPackage = fullyQualifiedName.substring(0, fullyQualifiedName.lastIndexOf(46));
        final String simpleName = fullyQualifiedName.substring(fullyQualifiedName.lastIndexOf(46) + 1);
        return new JavaIsoVisitor<ExecutionContext>(){

            public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) {
                J.MethodDeclaration md = super.visitMethodDeclaration(methodDeclaration, (Object)ctx);
                if (!md.hasModifier(J.Modifier.Type.Public) || md.getBody() == null || md.getParameters().isEmpty() || md.getParameters().get(0) instanceof J.Empty || md.getMethodType() == null || md.getMethodType().isOverride()) {
                    return md;
                }
                RequiredParameterAnalysis analysis = (RequiredParameterAnalysis)new RequiredParameterVisitor(AnnotateRequiredParameters.this.getAllParameters(md)).reduce((Tree)md.getBody(), new RequiredParameterAnalysis());
                if (analysis.requiredIdentifiers.isEmpty()) {
                    return md;
                }
                md = this.addAnnotationsToParameters(md, analysis, ctx);
                md = md.withBody((J.Block)new ReplaceNullChecksWithFalse(analysis.requiredIdentifiers).visit((Tree)md.getBody(), ctx));
                this.doAfterVisit(new SimplifyBooleanExpression().getVisitor());
                this.doAfterVisit(new SimplifyConstantIfBranchExecution().getVisitor());
                return md;
            }

            private J.MethodDeclaration addAnnotationsToParameters(J.MethodDeclaration md, RequiredParameterAnalysis analysis, ExecutionContext ctx) {
                this.maybeAddImport(fullyQualifiedName);
                String nullableFqn = fullyQualifiedPackage + ".Nullable";
                return md.withParameters(ListUtils.map((List)md.getParameters(), stm -> {
                    J.VariableDeclarations vd;
                    J.Identifier identifier;
                    if (stm instanceof J.VariableDeclarations && AnnotateRequiredParameters.containsIdentifierByName(analysis.requiredIdentifiers, identifier = ((J.VariableDeclarations.NamedVariable)(vd = (J.VariableDeclarations)stm).getVariables().get(0)).getName()) && FindAnnotations.find((J)vd, (String)("@" + fullyQualifiedName)).isEmpty()) {
                        this.maybeRemoveImport(nullableFqn);
                        vd = (J.VariableDeclarations)new RemoveAnnotationVisitor(new AnnotationMatcher(nullableFqn)).visit((Tree)vd, (Object)ctx, this.getCursor());
                        J.VariableDeclarations annotated = (J.VariableDeclarations)JavaTemplate.builder((String)("@" + fullyQualifiedName)).javaParser(JavaParser.fromJavaVersion().dependsOn(new String[]{String.format("package %s;public @interface %s {}", fullyQualifiedPackage, simpleName)})).build().apply(new Cursor(this.getCursor(), (Object)vd), vd.getCoordinates().addAnnotation(Comparator.comparing(J.Annotation::getSimpleName)), new Object[0]);
                        this.doAfterVisit((TreeVisitor)ShortenFullyQualifiedTypeReferences.modifyOnly((J)annotated));
                        this.doAfterVisit(new MoveFieldAnnotationToType(fullyQualifiedName).getVisitor());
                        return annotated.withModifiers(ListUtils.mapFirst((List)annotated.getModifiers(), first -> first.withPrefix(Space.SINGLE_SPACE)));
                    }
                    return stm;
                }));
            }
        };
    }

    private static boolean containsIdentifierByName(Collection<J.Identifier> identifiers, // Could not load outer class - annotation placement on inner may be incorrect
    @Nullable J.Identifier target) {
        if (target == null) {
            return false;
        }
        for (J.Identifier identifier : identifiers) {
            if (!SemanticallyEqual.areEqual((J)identifier, (J)target)) continue;
            return true;
        }
        return false;
    }

    private Set<J.Identifier> getAllParameters(J.MethodDeclaration md) {
        LinkedHashSet<J.Identifier> allParams = new LinkedHashSet<J.Identifier>();
        for (Statement parameter : md.getParameters()) {
            if (!(parameter instanceof J.VariableDeclarations)) continue;
            allParams.add(((J.VariableDeclarations.NamedVariable)((J.VariableDeclarations)parameter).getVariables().get(0)).getName());
        }
        return allParams;
    }

    @Generated
    public AnnotateRequiredParameters(@Nullable String nonNullAnnotationClass) {
        this.nonNullAnnotationClass = nonNullAnnotationClass;
    }

    @Generated
    public @Nullable String getNonNullAnnotationClass() {
        return this.nonNullAnnotationClass;
    }

    @Generated
    public String getDisplayName() {
        return this.displayName;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public String toString() {
        return "AnnotateRequiredParameters(nonNullAnnotationClass=" + this.getNonNullAnnotationClass() + ", displayName=" + this.getDisplayName() + ", description=" + this.getDescription() + ")";
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AnnotateRequiredParameters)) {
            return false;
        }
        AnnotateRequiredParameters other = (AnnotateRequiredParameters)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        String this$nonNullAnnotationClass = this.getNonNullAnnotationClass();
        String other$nonNullAnnotationClass = other.getNonNullAnnotationClass();
        if (this$nonNullAnnotationClass == null ? other$nonNullAnnotationClass != null : !this$nonNullAnnotationClass.equals(other$nonNullAnnotationClass)) {
            return false;
        }
        String this$displayName = this.getDisplayName();
        String other$displayName = other.getDisplayName();
        if (this$displayName == null ? other$displayName != null : !this$displayName.equals(other$displayName)) {
            return false;
        }
        String this$description = this.getDescription();
        String other$description = other.getDescription();
        return !(this$description == null ? other$description != null : !this$description.equals(other$description));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof AnnotateRequiredParameters;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        String $nonNullAnnotationClass = this.getNonNullAnnotationClass();
        result = result * 59 + ($nonNullAnnotationClass == null ? 43 : $nonNullAnnotationClass.hashCode());
        String $displayName = this.getDisplayName();
        result = result * 59 + ($displayName == null ? 43 : $displayName.hashCode());
        String $description = this.getDescription();
        result = result * 59 + ($description == null ? 43 : $description.hashCode());
        return result;
    }

    private static class ReplaceNullChecksWithFalse
    extends JavaVisitor<ExecutionContext> {
        private static final MethodMatcher REQUIRE_NON_NULL = new MethodMatcher("java.util.Objects requireNonNull(..)");
        private final Set<J.Identifier> requiredIdentifiers;

        public J visitBinary(J.Binary binary, ExecutionContext ctx) {
            J.Binary b = (J.Binary)super.visitBinary(binary, (Object)ctx);
            if (b.getOperator() == J.Binary.Type.Equal) {
                J.Identifier paramIdentifier = null;
                if (J.Literal.isLiteralValue((Expression)b.getLeft(), null) && b.getRight() instanceof J.Identifier) {
                    paramIdentifier = (J.Identifier)b.getRight();
                } else if (J.Literal.isLiteralValue((Expression)b.getRight(), null) && b.getLeft() instanceof J.Identifier) {
                    paramIdentifier = (J.Identifier)b.getLeft();
                }
                if (AnnotateRequiredParameters.containsIdentifierByName(this.requiredIdentifiers, paramIdentifier)) {
                    return new J.Literal(Tree.randomId(), b.getPrefix(), b.getMarkers(), (Object)false, "false", null, JavaType.Primitive.Boolean);
                }
            }
            return b;
        }

        public @Nullable J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
            J.MethodInvocation m = (J.MethodInvocation)super.visitMethodInvocation(method, (Object)ctx);
            if (REQUIRE_NON_NULL.matches((MethodCall)m) && m.getArguments().get(0) instanceof J.Identifier && AnnotateRequiredParameters.containsIdentifierByName(this.requiredIdentifiers, (J.Identifier)m.getArguments().get(0))) {
                Cursor parent = this.getCursor().getParentTreeCursor();
                if (parent.getValue() instanceof J.Block) {
                    return null;
                }
                return ((Expression)m.getArguments().get(0)).withPrefix(m.getPrefix());
            }
            return m;
        }

        @Generated
        public ReplaceNullChecksWithFalse(Set<J.Identifier> requiredIdentifiers) {
            this.requiredIdentifiers = requiredIdentifiers;
        }
    }

    private static class RequiredParameterVisitor
    extends JavaIsoVisitor<RequiredParameterAnalysis> {
        private static final MethodMatcher REQUIRE_NON_NULL = new MethodMatcher("java.util.Objects requireNonNull(..)");
        private final Collection<J.Identifier> parameterIdentifiers;

        public J.If visitIf(J.If iff, RequiredParameterAnalysis analysis) {
            Expression condition = (Expression)iff.getIfCondition().getTree();
            List<J.Identifier> nullCheckedParams = this.extractNullCheckedParameters(condition);
            if (!nullCheckedParams.isEmpty() && this.bodyThrowsException(iff.getThenPart())) {
                for (J.Identifier param : nullCheckedParams) {
                    if (!AnnotateRequiredParameters.containsIdentifierByName(this.parameterIdentifiers, param)) continue;
                    analysis.requiredIdentifiers.add(param);
                }
            }
            return iff;
        }

        public Statement visitStatement(Statement statement, RequiredParameterAnalysis analysis) {
            J.Identifier firstArgument;
            J.MethodInvocation method;
            if (statement instanceof J.MethodInvocation && REQUIRE_NON_NULL.matches((MethodCall)(method = (J.MethodInvocation)statement)) && !method.getArguments().isEmpty() && method.getArguments().get(0) instanceof J.Identifier && AnnotateRequiredParameters.containsIdentifierByName(this.parameterIdentifiers, firstArgument = (J.Identifier)method.getArguments().get(0))) {
                analysis.requiredIdentifiers.add(firstArgument);
            }
            return super.visitStatement(statement, (Object)analysis);
        }

        private List<J.Identifier> extractNullCheckedParameters(Expression condition) {
            ArrayList<J.Identifier> params = new ArrayList<J.Identifier>();
            this.extractNullCheckedParametersRecursive(condition, params);
            return params;
        }

        private void extractNullCheckedParametersRecursive(Expression condition, List<J.Identifier> params) {
            if (condition instanceof J.Binary) {
                J.Binary binary = (J.Binary)condition;
                J.Binary.Type operator = binary.getOperator();
                if (operator == J.Binary.Type.Or) {
                    this.extractNullCheckedParametersRecursive(binary.getLeft(), params);
                    this.extractNullCheckedParametersRecursive(binary.getRight(), params);
                } else if (operator == J.Binary.Type.Equal) {
                    if (J.Literal.isLiteralValue((Expression)binary.getLeft(), null) && binary.getRight() instanceof J.Identifier) {
                        params.add((J.Identifier)binary.getRight());
                    } else if (J.Literal.isLiteralValue((Expression)binary.getRight(), null) && binary.getLeft() instanceof J.Identifier) {
                        params.add((J.Identifier)binary.getLeft());
                    }
                }
            }
        }

        private boolean bodyThrowsException(Statement body) {
            if (body instanceof J.Throw) {
                return true;
            }
            if (body instanceof J.Block) {
                J.Block block = (J.Block)body;
                for (Statement statement : block.getStatements()) {
                    if (!(statement instanceof J.Throw)) continue;
                    return true;
                }
            }
            return false;
        }

        @Generated
        public RequiredParameterVisitor(Collection<J.Identifier> parameterIdentifiers) {
            this.parameterIdentifiers = parameterIdentifiers;
        }
    }

    private static class RequiredParameterAnalysis {
        final Set<J.Identifier> requiredIdentifiers = new HashSet<J.Identifier>();

        private RequiredParameterAnalysis() {
        }
    }
}

