/*
 * Decompiled with CFR 0.152.
 */
package org.talend.sdk.component.runtime.beam.transformer;

import java.io.ByteArrayOutputStream;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.ObjectStreamException;
import java.io.OutputStream;
import java.io.Serializable;
import java.lang.instrument.ClassFileTransformer;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URLClassLoader;
import java.security.ProtectionDomain;
import java.util.Collection;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.xbean.asm7.ClassReader;
import org.apache.xbean.asm7.ClassVisitor;
import org.apache.xbean.asm7.ClassWriter;
import org.apache.xbean.asm7.Label;
import org.apache.xbean.asm7.MethodVisitor;
import org.apache.xbean.asm7.Type;
import org.apache.xbean.asm7.commons.AdviceAdapter;
import org.apache.xbean.asm7.commons.Method;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.talend.sdk.component.classloader.ConfigurableClassLoader;
import org.talend.sdk.component.runtime.serialization.ContainerFinder;

public class BeamIOTransformer
implements ClassFileTransformer {
    // Logger used by the transformer itself and by the nested SerializationWrapper.
    private static final Logger log = LoggerFactory.getLogger(BeamIOTransformer.class);
    // Read once from the "talend.component.beam.transformers.debug" system property; when true, class-loading
    // failures in transform() are logged and SerializationWrapper calls readResolve() right after serializing.
    private static final boolean DEBUG = Boolean.getBoolean("talend.component.beam.transformers.debug");
    // Serializer built by createSerializer(); SerializationWrapper.serialize() uses it to write the delegate.
    private static final BiConsumer<OutputStream, Object> BYPASS_REPLACE_SERIALIZER = BeamIOTransformer.createSerializer();
    // Fully-qualified names of the types whose subclasses/implementations get rewritten by transform().
    private final Collection<String> typesToEnhance;

    /**
     * Creates a transformer that enhances the default set of Beam base types
     * (coders, sources, readers, checkpoint marks, DoFn, PTransform, CombineFn,
     * SerializableFunction and TupleTag).
     */
    public BeamIOTransformer() {
        this(Stream.of(
                "org.apache.beam.sdk.coders.Coder",
                "org.apache.beam.sdk.io.Source",
                "org.apache.beam.sdk.io.Source$Reader",
                "org.apache.beam.sdk.io.UnboundedSource$CheckpointMark",
                "org.apache.beam.sdk.transforms.DoFn",
                "org.apache.beam.sdk.transforms.PTransform",
                "org.apache.beam.sdk.transforms.Combine$CombineFn",
                "org.apache.beam.sdk.transforms.SerializableFunction",
                "org.apache.beam.sdk.values.TupleTag")
            .collect(Collectors.toSet()));
    }

    /**
     * Rewrites the bytecode of plugin classes whose hierarchy contains one of the configured types.
     *
     * <p>Only classes loaded through a {@link ConfigurableClassLoader} are considered. The class is first
     * loaded through a temporary copy of that loader (installed as thread context class loader for the
     * duration of the call) so its hierarchy can be inspected. Classes that actually come from the
     * temporary loader's parent, or that do not match any configured type, are returned untouched.
     *
     * @param loader the defining loader of the class being transformed
     * @param className the internal (slash separated) class name, may be {@code null}
     * @param classBeingRedefined unused
     * @param protectionDomain unused
     * @param classfileBuffer the original class bytes
     * @return the rewritten bytes, or {@code classfileBuffer} itself when no rewrite applies or the
     *         class cannot be loaded
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
        if (className == null || !(loader instanceof ConfigurableClassLoader)) {
            return classfileBuffer;
        }
        ConfigurableClassLoader pluginLoader = (ConfigurableClassLoader) loader;
        URLClassLoader temporaryLoader = pluginLoader.createTemporaryCopy();
        Thread currentThread = Thread.currentThread();
        ClassLoader previousTccl = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(temporaryLoader);
        try {
            Class<?> candidate = this.loadTempClass(temporaryLoader, className);
            // classes served by the parent loader are shared and must not be enhanced per plugin
            boolean loadedByParent = candidate.getClassLoader() == temporaryLoader.getParent();
            if (loadedByParent || !this.doesHierarchyContain(candidate, this.typesToEnhance)) {
                return classfileBuffer;
            }
            return this.rewrite(pluginLoader, className, classfileBuffer, temporaryLoader, candidate);
        } catch (ClassNotFoundException | NoClassDefFoundError e) {
            // best effort: a class that cannot be inspected is left as is
            if (DEBUG) {
                log.error("Can't load: " + className, e);
            }
            return classfileBuffer;
        } finally {
            currentThread.setContextClassLoader(previousTccl);
        }
    }

    /**
     * Runs the class bytes through the visitor chain
     * {@code ComponentClassVisitor -> SerializableCoderReplacement -> ComponentClassWriter}.
     *
     * @param loader the plugin loader, its id is embedded in the generated code
     * @param className the internal (slash separated) name of the class being rewritten
     * @param classfileBuffer the original class bytes
     * @param tmpLoader loader handed to the class writer for common-superclass resolution
     * @param tmpClass the class as loaded by the temporary loader
     * @return the rewritten class bytes
     */
    private byte[] rewrite(ConfigurableClassLoader loader, String className, byte[] classfileBuffer, ClassLoader tmpLoader, Class<?> tmpClass) {
        String pluginId = loader.getId();
        ClassReader classReader = new ClassReader(classfileBuffer);
        // frames are skipped on read and recomputed on write
        ComponentClassWriter classWriter =
                new ComponentClassWriter(className.replace('/', '.'), tmpLoader, classReader, ClassWriter.COMPUTE_FRAMES);
        ClassVisitor chain =
                new ComponentClassVisitor(new SerializableCoderReplacement(classWriter, pluginId, tmpClass), pluginId);
        classReader.accept(chain, ClassReader.SKIP_FRAMES);
        this.unsupportedLog(className);
        return classWriter.toByteArray();
    }

    /**
     * Loads a class through the temporary loader.
     *
     * @param tmpLoader the loader to use
     * @param className the class name in internal (slash separated) form
     * @return the loaded class
     * @throws ClassNotFoundException if the loader cannot find the class
     */
    private Class<?> loadTempClass(ClassLoader tmpLoader, String className) throws ClassNotFoundException {
        String binaryName = className.replace('/', '.');
        return tmpLoader.loadClass(binaryName);
    }

    /**
     * Tells whether a superclass of {@code clazz}, or an interface directly declared by {@code clazz}
     * or by one of its superclasses, is named in {@code types}.
     *
     * <p>The class itself and {@link Object} are never matched, and interfaces that are only reachable
     * through another interface are not inspected.
     *
     * @param clazz the class whose hierarchy is walked
     * @param types fully-qualified type names to look for
     * @return {@code true} as soon as one match is found, {@code false} otherwise
     */
    private boolean doesHierarchyContain(Class<?> clazz, Collection<String> types) {
        Class<?> current = clazz;
        while (true) {
            for (Class<?> declaredInterface : current.getInterfaces()) {
                if (types.contains(declaredInterface.getName())) {
                    return true;
                }
            }
            Class<?> parent = current.getSuperclass();
            if (parent == null || parent == Object.class) {
                return false;
            }
            if (types.contains(parent.getName())) {
                return true;
            }
            current = parent;
        }
    }

    /**
     * Emits a debug message recording that a class was rewritten and that this support is unofficial.
     *
     * @param className the name of the rewritten class
     */
    private void unsupportedLog(String className) {
        log.debug(
                "Rewrote {} bytecode, note it is not an officially supported component type and feature, this support can be dropped anytime",
                className);
    }

    /**
     * Installs the class loader of the container identified by {@code key} as the current thread's
     * context class loader. Called from the bytecode injected by {@code TCCLAdviceAdapter}.
     *
     * @param key the container (plugin) identifier passed to the {@link ContainerFinder}
     * @return the context class loader that was active before the switch, to hand back to
     *         {@link #resetTccl(ClassLoader)}
     */
    public static ClassLoader setPluginTccl(String key) {
        Thread currentThread = Thread.currentThread();
        ClassLoader previous = currentThread.getContextClassLoader();
        ClassLoader pluginLoader = ContainerFinder.Instance.get().find(key).classloader();
        currentThread.setContextClassLoader(pluginLoader);
        return previous;
    }

    /**
     * Sets the current thread's context class loader. Called from the bytecode injected by
     * {@code TCCLAdviceAdapter} with the value previously returned by {@link #setPluginTccl(String)}.
     *
     * @param loader the loader to install, may be {@code null}
     */
    public static void resetTccl(ClassLoader loader) {
        Thread currentThread = Thread.currentThread();
        currentThread.setContextClassLoader(loader);
    }

    /**
     * Builds the serializer stored in {@code BYPASS_REPLACE_SERIALIZER}.
     *
     * <p>The returned consumer opens an {@link ObjectOutputStream} on the given stream and writes the
     * object by reflectively calling the stream's private {@code writeXxx} methods directly, dispatching
     * on the runtime kind of the object (null, known handle, class, class descriptor, string, array,
     * enum, ordinary serializable). Nothing in this path invokes the object's {@code writeReplace};
     * NOTE(review): this appears to mirror the JDK's private {@code writeObject0} minus the replacement
     * step - confirm against the targeted JDK.
     *
     * <p>The implementation depends on private members of {@link ObjectOutputStream}
     * ({@code bout}, {@code depth}, {@code subs}, {@code handles}, ...). If any of them is missing or
     * cannot be made accessible, this method fails.
     *
     * @return a consumer writing an object to an output stream; it throws {@link IllegalStateException}
     *         wrapping any failure, including {@link NotSerializableException} for unsupported objects
     * @throws IllegalStateException if the reflective lookups fail
     */
    private static BiConsumer<OutputStream, Object> createSerializer() {
        final java.lang.reflect.Method writeOrdinaryObject;
        final java.lang.reflect.Method writeEnum;
        final java.lang.reflect.Method writeArray;
        final java.lang.reflect.Method writeString;
        final java.lang.reflect.Method writeHandle;
        final java.lang.reflect.Method writeClassDesc;
        final java.lang.reflect.Method writeClass;
        final java.lang.reflect.Method writeNull;
        final java.lang.reflect.Method setBlockDataMode;
        final java.lang.reflect.Method handlesLookup;
        final java.lang.reflect.Method subsLookup;
        final Field handles;
        final Field subs;
        final Field depth;
        final Field bout;
        try {
            writeOrdinaryObject = ObjectOutputStream.class.getDeclaredMethod("writeOrdinaryObject", Object.class, ObjectStreamClass.class, Boolean.TYPE);
            bout = ObjectOutputStream.class.getDeclaredField("bout");
            depth = ObjectOutputStream.class.getDeclaredField("depth");
            subs = ObjectOutputStream.class.getDeclaredField("subs");
            handles = ObjectOutputStream.class.getDeclaredField("handles");
            subsLookup = subs.getType().getDeclaredMethod("lookup", Object.class);
            handlesLookup = handles.getType().getDeclaredMethod("lookup", Object.class);
            setBlockDataMode = bout.getType().getDeclaredMethod("setBlockDataMode", Boolean.TYPE);
            writeNull = ObjectOutputStream.class.getDeclaredMethod("writeNull");
            writeClass = ObjectOutputStream.class.getDeclaredMethod("writeClass", Class.class, Boolean.TYPE);
            writeClassDesc = ObjectOutputStream.class.getDeclaredMethod("writeClassDesc", ObjectStreamClass.class, Boolean.TYPE);
            writeHandle = ObjectOutputStream.class.getDeclaredMethod("writeHandle", Integer.TYPE);
            writeString = ObjectOutputStream.class.getDeclaredMethod("writeString", String.class, Boolean.TYPE);
            writeArray = ObjectOutputStream.class.getDeclaredMethod("writeArray", Object.class, ObjectStreamClass.class, Boolean.TYPE);
            writeEnum = ObjectOutputStream.class.getDeclaredMethod("writeEnum", Enum.class, ObjectStreamClass.class, Boolean.TYPE);
            Stream.of(writeOrdinaryObject, bout, depth, setBlockDataMode, subs, subsLookup, handles, handlesLookup, writeNull, writeClass, writeClassDesc, writeHandle, writeString, writeArray, writeEnum).forEach(accessible -> {
                if (!accessible.isAccessible()) {
                    accessible.setAccessible(true);
                }
            });
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
        return (out, obj) -> {
            try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
                // leave block-data mode and bump the nesting depth while the object is written
                final boolean oldMode = (Boolean) setBlockDataMode.invoke(bout.get(oos), false);
                depth.set(oos, ((Integer) depth.get(oos)) + 1);
                try {
                    // apply the stream's substitution table, then dispatch on the resulting object
                    final Object target = subsLookup.invoke(subs.get(oos), obj);
                    if (target == null) {
                        writeNull.invoke(oos);
                        return;
                    }
                    final int handle = (Integer) handlesLookup.invoke(handles.get(oos), target);
                    if (handle != -1) {
                        writeHandle.invoke(oos, handle);
                    } else if (target instanceof Class) {
                        writeClass.invoke(oos, target, false);
                    } else if (target instanceof ObjectStreamClass) {
                        writeClassDesc.invoke(oos, target, false);
                    } else if (target instanceof String) {
                        // strings must go through writeString: handing them to writeClassDesc
                        // fails reflectively with "argument type mismatch"
                        writeString.invoke(oos, target, false);
                    } else if (target.getClass().isArray()) {
                        writeArray.invoke(oos, target, ObjectStreamClass.lookup(target.getClass()), false);
                    } else if (target instanceof Enum) {
                        writeEnum.invoke(oos, target, ObjectStreamClass.lookup(target.getClass()), false);
                    } else if (target instanceof Serializable) {
                        writeOrdinaryObject.invoke(oos, target, ObjectStreamClass.lookup(target.getClass()), false);
                    } else {
                        throw new NotSerializableException(String.valueOf(target));
                    }
                } finally {
                    // always restore the stream state, on success and on failure
                    depth.set(oos, ((Integer) depth.get(oos)) - 1);
                    setBlockDataMode.invoke(bout.get(oos), oldMode);
                }
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        };
    }

    /**
     * Creates a transformer enhancing classes whose hierarchy contains one of the given types.
     *
     * @param typesToEnhance fully-qualified type names matched by {@code doesHierarchyContain};
     *        the collection is stored as is, not copied
     */
    public BeamIOTransformer(Collection<String> typesToEnhance) {
        this.typesToEnhance = typesToEnhance;
    }

    private static class ComponentClassWriter
    extends ClassWriter {
        private final String currentClass;
        private final ClassLoader tmpLoader;

        private ComponentClassWriter(String name, ClassLoader loader, ClassReader reader, int flags) {
            super(reader, flags);
            this.tmpLoader = loader;
            this.currentClass = name;
        }

        protected String getCommonSuperClass(String type1, String type2) {
            Class<?> d;
            Class<?> c;
            try {
                c = this.findClass(type1.replace('/', '.'));
                d = this.findClass(type2.replace('/', '.'));
            }
            catch (Exception e) {
                throw new RuntimeException(e.toString());
            }
            catch (ClassCircularityError e) {
                return "java/lang/Object";
            }
            if (c.isAssignableFrom(d)) {
                return type1;
            }
            if (d.isAssignableFrom(c)) {
                return type2;
            }
            if (c.isInterface() || d.isInterface()) {
                return "java/lang/Object";
            }
            while (!(c = c.getSuperclass()).isAssignableFrom(d)) {
            }
            return c.getName().replace('.', '/');
        }

        private Class<?> findClass(String className) throws ClassNotFoundException {
            try {
                return this.currentClass.equals(className) ? Object.class : Class.forName(className, false, this.tmpLoader);
            }
            catch (ClassNotFoundException e) {
                return Class.forName(className, false, ((Object)((Object)this)).getClass().getClassLoader());
            }
        }
    }

    private static class ComponentClassVisitor
    extends ClassVisitor {
        private static final String[] OBJECT_STREAM_EXCEPTION = new String[]{Type.getType(ObjectStreamException.class).getInternalName()};
        private final ClassVisitor writer;
        private final String plugin;
        private boolean hasWriteReplace;

        private ComponentClassVisitor(ClassVisitor cv, String plugin) {
            super(458752, cv);
            this.plugin = plugin;
            this.writer = cv;
        }

        public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
            if ("writeReplace".equals(name)) {
                this.hasWriteReplace = true;
            }
            if (Modifier.isPublic(access) && !Modifier.isStatic(access)) {
                return new TCCLAdviceAdapter(mv, access, name, desc, this.plugin);
            }
            return mv;
        }

        public void visitEnd() {
            this.createSerialisation(this.writer, this.plugin);
            super.visitEnd();
        }

        private void createSerialisation(ClassVisitor cw, String pluginId) {
            if (this.hasWriteReplace) {
                return;
            }
            MethodVisitor mv = cw.visitMethod(1, "writeReplace", "()Ljava/lang/Object;", null, OBJECT_STREAM_EXCEPTION);
            mv.visitCode();
            String wrapperType = SerializationWrapper.class.getName().replace('.', '/');
            mv.visitTypeInsn(187, wrapperType);
            mv.visitInsn(89);
            mv.visitVarInsn(25, 0);
            mv.visitLdcInsn((Object)pluginId);
            mv.visitMethodInsn(184, wrapperType, "replace", "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false);
            mv.visitInsn(176);
            mv.visitMaxs(-1, -1);
            mv.visitEnd();
        }
    }

    /**
     * Method adapter surrounding the visited method body with calls to
     * {@code BeamIOTransformer.setPluginTccl(plugin)} on entry and
     * {@code BeamIOTransformer.resetTccl(previous)} on every exit, including a catch-all
     * {@link Throwable} handler that resets the loader and rethrows.
     */
    private static class TCCLAdviceAdapter
    extends AdviceAdapter {
        private static final Type THROWABLE_TYPE = Type.getType(Throwable.class);
        // owner of the two static helper methods invoked by the injected code
        private static final Type TCCL_HELPER = Type.getType(BeamIOTransformer.class);
        private static final Type STRING_TYPE = Type.getType(String.class);
        private static final Type CLASSLOADER_TYPE = Type.getType(ClassLoader.class);
        private static final Type[] SET_TCCL_ARGS = new Type[]{STRING_TYPE};
        private static final Type[] RESET_TCCL_ARGS = new Type[]{CLASSLOADER_TYPE};
        // ClassLoader setPluginTccl(String)
        private static final Method SET_METHOD = new Method("setPluginTccl", CLASSLOADER_TYPE, SET_TCCL_ARGS);
        // void resetTccl(ClassLoader)
        private static final Method RESET_METHOD = new Method("resetTccl", Type.VOID_TYPE, RESET_TCCL_ARGS);
        // plugin identifier pushed as argument of setPluginTccl
        private final String plugin;
        // descriptor of the visited method, used to find its return type
        private final String desc;
        // start of the range covered by the catch-all handler
        private final Label tryStart = new Label();
        // end of the range covered by the catch-all handler
        private final Label endLabel = new Label();
        // index of the local variable holding the loader returned by setPluginTccl
        private int ctxLocal;

        private TCCLAdviceAdapter(MethodVisitor mv, int access, String name, String desc, String plugin) {
            // 458752 == 0x70000, the ASM 7 API level
            super(458752, mv, access, name, desc);
            this.plugin = plugin;
            this.desc = desc;
        }

        /**
         * Emits {@code ctxLocal = BeamIOTransformer.setPluginTccl(plugin)} and marks the start of the
         * protected range.
         */
        public void onMethodEnter() {
            this.push(this.plugin);
            this.ctxLocal = this.newLocal(CLASSLOADER_TYPE);
            this.invokeStatic(TCCL_HELPER, SET_METHOD);
            this.storeLocal(this.ctxLocal);
            this.visitLabel(this.tryStart);
        }

        /**
         * Emits {@code BeamIOTransformer.resetTccl(ctxLocal)} before an exit instruction while keeping
         * the value on top of the stack (return value or throwable) intact.
         *
         * @param opCode the exit opcode, or {@code Integer.MIN_VALUE} when called from
         *        {@link #visitMaxs(int, int)} for the generated exception handler
         */
        public void onMethodExit(int opCode) {
            // 191 is the athrow opcode: explicit throws are left to the catch-all handler
            if (opCode == 191) {
                return;
            }
            int stateLocal = -1;
            if (opCode != Integer.MIN_VALUE) {
                // regular return: park the return value (if any) in a fresh local
                Type returnType = Type.getReturnType((String)this.desc);
                boolean isVoid = Type.VOID_TYPE.equals((Object)returnType);
                if (!isVoid) {
                    stateLocal = this.newLocal(returnType);
                    this.storeLocal(stateLocal);
                }
            } else {
                // exception handler: park the caught throwable in a fresh local
                stateLocal = this.newLocal(THROWABLE_TYPE);
                this.storeLocal(stateLocal);
            }
            this.loadLocal(this.ctxLocal);
            this.invokeStatic(TCCL_HELPER, RESET_METHOD);
            // put the parked value back on the stack for the return/throw that follows
            if (stateLocal != -1) {
                this.loadLocal(stateLocal);
            }
        }

        /**
         * Closes the protected range, registers a handler for {@link Throwable} over it, and emits the
         * handler body: reset the context class loader, then rethrow.
         */
        public void visitMaxs(int maxStack, int maxLocals) {
            this.visitLabel(this.endLabel);
            this.catchException(this.tryStart, this.endLabel, THROWABLE_TYPE);
            this.onMethodExit(Integer.MIN_VALUE);
            this.throwException();
            // the received sizes are ignored; NOTE(review): presumably recomputed by the class writer - confirm
            super.visitMaxs(0, 0);
        }
    }

    /**
     * Serializable holder created by the generated {@code writeReplace} methods: it keeps the plugin
     * identifier and the bytes of the delegate produced by {@code BYPASS_REPLACE_SERIALIZER}.
     */
    public static class SerializationWrapper
    implements Serializable {
        // plugin identifier given at construction
        private final String plugin;
        // serialized form of the delegate, computed eagerly in the constructor
        private final byte[] delegateBytes;

        /**
         * Serializes {@code delegate} immediately. When the DEBUG flag is on, {@link #readResolve()}
         * is invoked right away and an {@link ObjectStreamException} is only logged.
         *
         * @param delegate the object to wrap
         * @param plugin the plugin identifier to keep with the bytes
         */
        public SerializationWrapper(Object delegate, String plugin) {
            this.plugin = plugin;
            this.delegateBytes = this.serialize(delegate);
            if (DEBUG) {
                try {
                    this.readResolve();
                }
                catch (ObjectStreamException e) {
                    log.debug("Serialization BUG: " + e.getMessage(), (Throwable)e);
                }
            }
        }

        /**
         * Writes {@code delegate} into an in-memory buffer through {@code BYPASS_REPLACE_SERIALIZER}.
         *
         * @return the serialized bytes
         */
        private byte[] serialize(Object delegate) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            if (DEBUG) {
                log.debug("serializing {}", delegate);
            }
            BYPASS_REPLACE_SERIALIZER.accept(baos, delegate);
            return baos.toByteArray();
        }

        /**
         * NOTE(review): the real body of this method was NOT recovered by the decompiler (CFR 0.152
         * reported "ConfusedCFRException: Started 2 blocks at once"). The statement below is only a
         * placeholder that always throws {@link IllegalStateException}; as written, the DEBUG branch
         * of the constructor would propagate it because only {@link ObjectStreamException} is caught
         * there. Presumably the original deserializes {@code delegateBytes} in the context of
         * {@code plugin} - restore it from the original sources before relying on this class.
         */
        Object readResolve() throws ObjectStreamException {
            // placeholder emitted by the decompiler, not the original logic
            throw new IllegalStateException("Decompilation failed");
        }

        /**
         * Null-safe factory used by the generated {@code writeReplace} methods.
         *
         * @return {@code null} when {@code delegate} is {@code null}, otherwise a new wrapper
         */
        public static Object replace(Object delegate, String plugin) {
            return delegate == null ? null : new SerializationWrapper(delegate, plugin);
        }
    }

    private static class SerializableCoderReplacement
    extends ClassVisitor {
        private final String plugin;
        private final Type accumulatorType;

        private SerializableCoderReplacement(ClassVisitor delegate, String plugin, Class<?> clazz) {
            super(458752, delegate);
            this.plugin = plugin;
            Type accumulatorType = null;
            if (Combine.CombineFn.class.isAssignableFrom(clazz)) {
                try {
                    if (clazz.getMethod("getAccumulatorCoder", CoderRegistry.class, Coder.class).getDeclaringClass() != clazz) {
                        accumulatorType = Type.getType(clazz.getMethod("createAccumulator", new Class[0]).getReturnType());
                    }
                }
                catch (NoSuchMethodException noSuchMethodException) {
                    // empty catch block
                }
            }
            this.accumulatorType = accumulatorType;
        }

        public void visitEnd() {
            if (this.accumulatorType != null) {
                MethodVisitor getAccumulatorCoder = super.visitMethod(1, "getAccumulatorCoder", "(Lorg/apache/beam/sdk/coders/CoderRegistry;Lorg/apache/beam/sdk/coders/Coder;)Lorg/apache/beam/sdk/coders/Coder;", null, null);
                getAccumulatorCoder.visitLdcInsn((Object)this.accumulatorType);
                getAccumulatorCoder.visitLdcInsn((Object)this.plugin);
                getAccumulatorCoder.visitMethodInsn(184, "org/talend/sdk/component/runtime/beam/coder/ContextualSerializableCoder", "of", "(Ljava/lang/Class;Ljava/lang/String;)Lorg/apache/beam/sdk/coders/SerializableCoder;", false);
                getAccumulatorCoder.visitInsn(176);
                getAccumulatorCoder.visitMaxs(-1, -1);
                getAccumulatorCoder.visitEnd();
            }
            super.visitEnd();
        }

        public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
            MethodVisitor delegate = super.visitMethod(access, name, descriptor, signature, exceptions);
            return new MethodVisitor(458752, delegate){

                public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
                    if ("org/apache/beam/sdk/coders/SerializableCoder".equals(owner) && "of".equals(name) && "(Ljava/lang/Class;)Lorg/apache/beam/sdk/coders/SerializableCoder;".equals(descriptor)) {
                        super.visitLdcInsn((Object)plugin);
                        super.visitMethodInsn(opcode, "org/talend/sdk/component/runtime/beam/coder/ContextualSerializableCoder", "of", "(Ljava/lang/Class;Ljava/lang/String;)Lorg/apache/beam/sdk/coders/SerializableCoder;", false);
                    } else {
                        super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
                    }
                }

                public void visitMaxs(int maxStack, int maxLocals) {
                    super.visitMaxs(-1, -1);
                }
            };
        }
    }
}

