/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.hibernate;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.jspecify.annotations.Nullable;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.FindImplementations;
import org.openrewrite.java.search.FindMethodDeclaration;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JRightPadded;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.TypeTree;
import org.openrewrite.java.tree.TypeUtils;
import org.openrewrite.marker.Markers;

public class MigrateUserType
extends Recipe {
    private static final String USER_TYPE = "org.hibernate.usertype.UserType";
    private static final MethodMatcher ASSEMBLE = new MethodMatcher("* assemble(java.io.Serializable, java.lang.Object)");
    private static final MethodMatcher DEEP_COPY = new MethodMatcher("* deepCopy(java.lang.Object)");
    private static final MethodMatcher DISASSEMBLE = new MethodMatcher("* disassemble(java.lang.Object)");
    private static final MethodMatcher EQUALS = new MethodMatcher("* equals(java.lang.Object, java.lang.Object)");
    private static final MethodMatcher HASHCODE = new MethodMatcher("* hashCode(java.lang.Object)");
    private static final MethodMatcher NULL_SAFE_GET_STRING_ARRAY = new MethodMatcher("* nullSafeGet(java.sql.ResultSet, java.lang.String[], org.hibernate.engine.spi.SharedSessionContractImplementor, java.lang.Object)");
    private static final MethodMatcher NULL_SAFE_SET = new MethodMatcher("* nullSafeSet(java.sql.PreparedStatement, java.lang.Object, int, org.hibernate.engine.spi.SharedSessionContractImplementor)");
    private static final MethodMatcher NULL_SAFE_GET_INT = new MethodMatcher("* nullSafeGet(java.sql.ResultSet, int, org.hibernate.engine.spi.SharedSessionContractImplementor, java.lang.Object)");
    private static final MethodMatcher REPLACE = new MethodMatcher("* replace(java.lang.Object, java.lang.Object, java.lang.Object)");
    private static final MethodMatcher RESULT_SET_STRING_PARAM = new MethodMatcher("java.sql.ResultSet *(java.lang.String)");
    private static final MethodMatcher RETURNED_CLASS = new MethodMatcher("* returnedClass()");
    private static final MethodMatcher SQL_TYPES = new MethodMatcher("* sqlTypes()");
    final String displayName = "Migrate `UserType` to Hibernate 6";
    final String description = "With Hibernate 6 the `UserType` interface received a type parameter making it more strictly typed. This recipe applies the changes required to adhere to this change.";

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check((TreeVisitor)Preconditions.and((TreeVisitor[])new TreeVisitor[]{new FindImplementations(USER_TYPE).getVisitor(), Preconditions.not((TreeVisitor)new FindMethodDeclaration("* getSqlType()", Boolean.valueOf(true)).getVisitor())}), (TreeVisitor)new JavaVisitor<ExecutionContext>(){

            public J visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
                J.ClassDeclaration cd = classDecl;
                J.FieldAccess parameterizedType = this.getReturnedClass(cd);
                cd = cd.withImplements(ListUtils.map((List)cd.getImplements(), impl -> {
                    if (TypeUtils.isAssignableTo((String)MigrateUserType.USER_TYPE, (JavaType)impl.getType()) && parameterizedType != null) {
                        return (TypeTree)TypeTree.build((String)("UserType<" + parameterizedType.getTarget() + ">")).withType(JavaType.buildType((String)MigrateUserType.USER_TYPE)).withPrefix(Space.SINGLE_SPACE);
                    }
                    return impl;
                }));
                if (parameterizedType != null) {
                    this.getCursor().putMessage("parameterizedType", (Object)parameterizedType);
                }
                return super.visitClassDeclaration(cd, (Object)ctx);
            }

            private // Could not load outer class - annotation placement on inner may be incorrect
            @Nullable J.FieldAccess getReturnedClass(final J.ClassDeclaration cd) {
                AtomicReference reference = new AtomicReference();
                new JavaIsoVisitor<AtomicReference<J.FieldAccess>>(){

                    public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, AtomicReference<J.FieldAccess> ref) {
                        return RETURNED_CLASS.matches(method, cd) ? super.visitMethodDeclaration(method, ref) : method;
                    }

                    public J.Return visitReturn(J.Return _return, AtomicReference<J.FieldAccess> ref) {
                        Expression expression = _return.getExpression();
                        if (expression instanceof J.FieldAccess && "class".equals(((J.FieldAccess)expression).getSimpleName())) {
                            ref.set((J.FieldAccess)expression);
                        }
                        return _return;
                    }
                }.visitNonNull((Tree)cd, reference);
                return (J.FieldAccess)reference.get();
            }

            public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
                J.MethodDeclaration md = method;
                J.ClassDeclaration cd = (J.ClassDeclaration)this.getCursor().firstEnclosing(J.ClassDeclaration.class);
                J.FieldAccess parameterizedType = (J.FieldAccess)this.getCursor().getNearestMessage("parameterizedType");
                if (cd == null || parameterizedType == null) {
                    return md;
                }
                if (SQL_TYPES.matches(md, cd)) {
                    if (md.getBody() != null) {
                        J.NewArray newArray;
                        Optional<J.Return> ret = md.getBody().getStatements().stream().filter(J.Return.class::isInstance).map(J.Return.class::cast).findFirst();
                        if (ret.isPresent() && ret.get().getExpression() instanceof J.NewArray && (newArray = (J.NewArray)ret.get().getExpression()).getInitializer() != null) {
                            String template = "@Override\npublic int getSqlType() {\n    return #{any()};\n}";
                            md = (J.MethodDeclaration)JavaTemplate.builder((String)template).javaParser(JavaParser.fromJavaVersion()).build().apply(this.getCursor(), md.getCoordinates().replace(), new Object[]{newArray.getInitializer().get(0)}).withId(md.getId());
                        }
                    }
                } else if (RETURNED_CLASS.matches(md, cd)) {
                    if ((md = md.withReturnTypeExpression(TypeTree.build((String)("Class<" + parameterizedType.getTarget() + ">")))).getReturnTypeExpression() != null) {
                        md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
                    }
                } else if (EQUALS.matches(md, cd)) {
                    md = this.changeParameterTypes(md, Arrays.asList(0, 1), parameterizedType);
                } else if (HASHCODE.matches(md, cd)) {
                    md = this.changeParameterTypes(md, Collections.singletonList(0), parameterizedType);
                } else if (NULL_SAFE_GET_STRING_ARRAY.matches(md, cd)) {
                    String template = "@Override\npublic BigDecimal nullSafeGet(ResultSet rs, int position, SharedSessionContractImplementor session, Object owner) throws SQLException {\n}";
                    J.MethodDeclaration updatedParam = (J.MethodDeclaration)JavaTemplate.builder((String)template).javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"hibernate-core"})).imports(new String[]{"java.math.BigDecimal", "java.sql.ResultSet", "java.sql.SQLException", "org.hibernate.engine.spi.SharedSessionContractImplementor"}).build().apply(this.getCursor(), md.getCoordinates().replace(), new Object[0]);
                    md = updatedParam.withId(md.getId()).withBody(md.getBody());
                } else if (NULL_SAFE_SET.matches(md, cd)) {
                    md = this.changeParameterTypes(md, Collections.singletonList(1), parameterizedType);
                } else if (DEEP_COPY.matches(md, cd)) {
                    if ((md = md.withReturnTypeExpression((TypeTree)parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE))).getReturnTypeExpression() != null) {
                        md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
                    }
                    md = this.changeParameterTypes(md, Collections.singletonList(0), parameterizedType);
                } else if (DISASSEMBLE.matches(md, cd)) {
                    if ((md = this.changeParameterTypes(md, Collections.singletonList(0), parameterizedType)).getBody() != null) {
                        md = md.withBody(md.getBody().withStatements(ListUtils.map((List)md.getBody().getStatements(), stmt -> {
                            J.Return r;
                            if (stmt instanceof J.Return && (r = (J.Return)stmt).getExpression() instanceof J.TypeCast) {
                                J.TypeCast tc = (J.TypeCast)r.getExpression();
                                if (TypeUtils.isOfType((JavaType)parameterizedType.getTarget().getType(), (JavaType)tc.getClazz().getType())) {
                                    return r.withExpression(tc.getExpression());
                                }
                            }
                            return stmt;
                        })));
                    }
                } else if (ASSEMBLE.matches(md, cd)) {
                    if ((md = md.withReturnTypeExpression((TypeTree)parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE))).getReturnTypeExpression() != null) {
                        md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
                    }
                    if (md.getBody() != null) {
                        md = md.withBody(md.getBody().withStatements(ListUtils.map((List)md.getBody().getStatements(), stmt -> {
                            J.Return r;
                            if (stmt instanceof J.Return && (r = (J.Return)stmt).getExpression() != null && !TypeUtils.isOfType((JavaType)parameterizedType.getTarget().getType(), (JavaType)r.getExpression().getType())) {
                                return r.withExpression((Expression)new J.TypeCast(Tree.randomId(), Space.EMPTY, Markers.EMPTY, new J.ControlParentheses(Tree.randomId(), Space.EMPTY, Markers.EMPTY, new JRightPadded((Object)((TypeTree)TypeTree.build((String)"BigDecimal").withType(parameterizedType.getTarget().getType())), Space.EMPTY, Markers.EMPTY)), r.getExpression()));
                            }
                            return stmt;
                        })));
                    }
                } else if (REPLACE.matches(md, cd)) {
                    if ((md = md.withReturnTypeExpression((TypeTree)parameterizedType.getTarget().withPrefix(Space.SINGLE_SPACE))).getReturnTypeExpression() != null) {
                        md = md.withPrefix(md.getReturnTypeExpression().getPrefix());
                    }
                    md = this.changeParameterTypes(md, Arrays.asList(0, 1), parameterizedType);
                }
                this.updateCursor((Tree)md);
                md = (J.MethodDeclaration)super.visitMethodDeclaration(md, (Object)ctx);
                return this.maybeAutoFormat((J)method, (J)md, ctx);
            }

            private J.MethodDeclaration changeParameterTypes(J.MethodDeclaration md, List<Integer> paramIndexes, J.FieldAccess parameterizedType) {
                if (md.getMethodType() != null) {
                    JavaType.Method met = md.getMethodType().withParameterTypes(ListUtils.map((List)md.getMethodType().getParameterTypes(), (index, type) -> {
                        if (paramIndexes.contains(index)) {
                            type = TypeUtils.isOfType((JavaType)JavaType.buildType((String)"java.lang.Object"), (JavaType)type) ? parameterizedType.getTarget().getType() : type;
                        }
                        return type;
                    }));
                    return md.withParameters(ListUtils.map((List)md.getParameters(), (index, param) -> {
                        if (param instanceof J.VariableDeclarations && paramIndexes.contains(index)) {
                            param = ((J.VariableDeclarations)param).withType(parameterizedType.getTarget().getType()).withTypeExpression((TypeTree)parameterizedType.getTarget()).withVariables(ListUtils.map((List)((J.VariableDeclarations)param).getVariables(), var -> {
                                if ((var = var.withType(parameterizedType.getTarget().getType())).getVariableType() != null && parameterizedType.getTarget().getType() != null) {
                                    var = var.withVariableType(var.getVariableType().withType(parameterizedType.getTarget().getType()).withOwner((JavaType)met));
                                }
                                return var;
                            }));
                        }
                        return param;
                    }));
                }
                return md;
            }

            public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                J.MethodInvocation mi = (J.MethodInvocation)super.visitMethodInvocation(method, (Object)ctx);
                if (RESULT_SET_STRING_PARAM.matches((MethodCall)mi)) {
                    J.MethodDeclaration md = (J.MethodDeclaration)this.getCursor().firstEnclosing(J.MethodDeclaration.class);
                    J.ClassDeclaration cd = (J.ClassDeclaration)this.getCursor().firstEnclosing(J.ClassDeclaration.class);
                    if (md != null && cd != null && NULL_SAFE_GET_INT.matches(md, cd) && (mi = mi.withArguments(Collections.singletonList(((J.VariableDeclarations.NamedVariable)((J.VariableDeclarations)md.getParameters().get(1)).getVariables().get(0)).getName()))).getMethodType() != null) {
                        mi = mi.withMethodType(mi.getMethodType().withParameterTypes(Collections.singletonList(JavaType.buildType((String)"int"))));
                    }
                }
                return mi;
            }
        });
    }

    @Generated
    public String getDisplayName() {
        return this.displayName;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }
}

