/*
 * Decompiled with CFR 0.152.
 */
package ai.databand.agent;

import ai.databand.config.DbndAgentConfig;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.WeakHashMap;
import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtField;
import javassist.CtMethod;
import javassist.LoaderClassPath;
import javassist.NotFoundException;
import javassist.bytecode.AnnotationsAttribute;
import javassist.bytecode.DuplicateMemberException;
import javassist.bytecode.MethodInfo;
import javassist.expr.ExprEditor;
import javassist.expr.MethodCall;
import javassist.expr.NewExpr;

/**
 * Class-file transformer that instruments methods annotated with
 * {@code ai.databand.annotations.Task}.
 *
 * <p>For every annotated method the transformer injects calls to
 * {@code $dbnd.beforeTask}, {@code $dbnd.afterTask} and {@code $dbnd.errorTask}
 * on a static {@code ai.databand.DbndWrapper} field added to the class. When enabled
 * in the agent configuration, it also rewrites
 * {@code SparkSession.Builder.getOrCreate} calls inside those methods so that the
 * dbnd Spark listeners are registered on the returned session.
 */
public class DbndTrackingTransformer implements ClassFileTransformer {
    private static final String TASK_ANNOTATION = "ai.databand.annotations.Task";
    private static final String SPARK_LISTENER_INJECT_CODE = "{ $_ = $proceed($$);$_.sparkContext().addSparkListener(new ai.databand.spark.DbndSparkListener($dbnd));}";
    private static final String SPARK_QUERY_LISTENER_INJECT_CODE = "{ $_ = $proceed($$);$_.sparkContext().addSparkListener(new ai.databand.spark.DbndSparkListener($dbnd));$_.listenerManager().register(new ai.databand.spark.DbndSparkQueryExecutionListener($dbnd));}";

    /**
     * Class loaders whose class path has already been appended to the default
     * {@link ClassPool}. The pool is a process-wide singleton, so this registry is
     * static as well. Keys are weakly held so that the registry does not keep
     * class loaders alive.
     */
    private static final Set<ClassLoader> REGISTERED_LOADERS =
            Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<ClassLoader, Boolean>()));

    private final DbndAgentConfig config;

    /**
     * @param config agent configuration consulted for verbosity and Spark listener injection flags
     */
    public DbndTrackingTransformer(DbndAgentConfig config) {
        this.config = config;
    }

    /**
     * Instruments the supplied class when it declares at least one task-annotated
     * method and passes the {@link #classInScope} filter.
     *
     * @param loader              defining loader of the class, may be {@code null}
     * @param className           name of the class as supplied by the instrumentation API
     * @param classBeingRedefined unused
     * @param protectionDomain    unused
     * @param classfileBuffer     original class file bytes
     * @return the instrumented bytecode, or {@code null} when the class is out of
     *         scope or instrumentation failed (meaning "no transformation")
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) {
        ClassPool cp = ClassPool.getDefault();
        // Register each loader only once: appending a new LoaderClassPath on every
        // loaded class makes the pool's search path grow without bound.
        if (loader != null && REGISTERED_LOADERS.add(loader)) {
            cp.appendClassPath(new LoaderClassPath(loader));
        }
        try {
            // Kept inside the guarded region: ClassPool.makeClass (used by classInScope)
            // throws a RuntimeException when a frozen class with the same name exists.
            Optional<CtClass> ctOpt = this.classInScope(cp, className, classfileBuffer);
            if (!ctOpt.isPresent()) {
                return null;
            }
            CtClass ct = ctOpt.get();
            System.out.printf("Instrumenting class %s with dbnd wrapper%n", className);
            CtMethod[] declaredMethods = ct.getDeclaredMethods();
            try {
                ct.addField(CtField.make("static ai.databand.DbndWrapper $dbnd = ai.databand.DbndWrapper.instance();", ct));
            } catch (DuplicateMemberException ignored) {
                // The class already declares a $dbnd field; reuse it.
            }
            CtClass throwableType = cp.get("java.lang.Throwable");
            for (CtMethod method : declaredMethods) {
                if (!hasTaskAnnotation(method)) {
                    continue;
                }
                if (this.config.isVerbose()) {
                    System.out.printf("Instrumenting method %s%n", method.getMethodInfo().getName());
                }
                // Spark listener injection rewrites calls in the original body, so it
                // runs before the tracking wrappers are added around that body.
                this.injectSparkListener(method);
                String longName = method.getLongName();
                method.insertBefore("{ $dbnd.beforeTask(\"" + ct.getName() + "\", \"" + longName + "\", $args); }");
                method.insertAfter("{ $dbnd.afterTask(\"" + longName + "\", (Object) ($w) $_); }");
                method.addCatch("{ $dbnd.errorTask(\"" + longName + "\", $e); throw $e; }", throwableType);
            }
            return ct.toBytecode();
        } catch (Throwable e) {
            if (isFrozenClassError(e)) {
                // The class was already processed and frozen in the pool; leave it untouched.
                return null;
            }
            System.err.println("Class instrumentation failed");
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Rewrites {@code SparkSession.Builder.getOrCreate} calls inside {@code method}
     * so that dbnd listeners are attached to the created session. Does nothing when
     * listener injection is disabled, and skips the rewrite when the method body
     * already instantiates {@code ai.databand.spark.DbndSparkListener}.
     *
     * @param method task method to rewrite
     * @throws CannotCompileException if Javassist fails to compile the injected code
     */
    protected void injectSparkListener(CtMethod method) throws CannotCompileException {
        if (!this.config.sparkListenerInjectEnabled()) {
            return;
        }
        // Single-element array so the anonymous editor can record its finding.
        final boolean[] userListenerFound = {false};
        method.instrument(new ExprEditor() {

            @Override
            public void edit(NewExpr c) {
                if ("ai.databand.spark.DbndSparkListener".equalsIgnoreCase(c.getClassName())) {
                    userListenerFound[0] = true;
                }
            }
        });
        if (userListenerFound[0]) {
            if (this.config.isVerbose()) {
                System.out.println("Spark listener was already added by user, skipped injection");
            }
            return;
        }
        method.instrument(new ExprEditor() {

            @Override
            public void edit(MethodCall m) throws CannotCompileException {
                boolean isGetOrCreate = m.getMethodName().contains("getOrCreate")
                        && m.getClassName().equalsIgnoreCase("org.apache.spark.sql.SparkSession$Builder");
                if (!isGetOrCreate) {
                    return;
                }
                if (config.sparkQueryListenerInjectEnabled()) {
                    m.replace(SPARK_QUERY_LISTENER_INJECT_CODE);
                    if (config.isVerbose()) {
                        System.out.println("Spark listener and query listener are injected");
                    }
                } else {
                    m.replace(SPARK_LISTENER_INJECT_CODE);
                    if (config.isVerbose()) {
                        System.out.println("Spark listener is injected");
                    }
                }
            }
        });
    }

    /**
     * Parses the class file and decides whether the class should be instrumented.
     *
     * @param cp              class pool used to build the {@link CtClass}
     * @param className       name of the class being loaded
     * @param classfileBuffer original class file bytes
     * @return the parsed class when it declares at least one task-annotated method
     *         and {@link #isScalaObject} accepts it; otherwise an empty optional
     *         (also when the class file cannot be read)
     */
    protected Optional<CtClass> classInScope(ClassPool cp, String className, byte[] classfileBuffer) {
        try (ByteArrayInputStream is = new ByteArrayInputStream(classfileBuffer)) {
            CtClass ct = cp.makeClass(is);
            for (CtMethod method : ct.getDeclaredMethods()) {
                if (hasTaskAnnotation(method)) {
                    // The first annotated method decides: the name filter is per class, not per method.
                    return this.isScalaObject(cp, className) ? Optional.of(ct) : Optional.<CtClass>empty();
                }
            }
            return Optional.empty();
        } catch (IOException e) {
            return Optional.empty();
        }
    }

    /**
     * Name-based filter applied to classes that contain task methods.
     *
     * <p>Returns {@code true} when {@code className} contains {@code '$'}, or when no
     * class named {@code className + "$"} can be found in the pool. Returns
     * {@code false} only when such a {@code $}-suffixed sibling class exists.
     * NOTE(review): presumably this skips the static-forwarder class the Scala
     * compiler emits next to an {@code object}; despite the method name, plain
     * classes without a {@code $} sibling are accepted too. Confirm the intent.
     *
     * @param cp        class pool used to look up the {@code $}-suffixed sibling
     * @param className name of the class being loaded
     * @return whether the class should be instrumented
     */
    protected boolean isScalaObject(ClassPool cp, String className) {
        if (className.contains("$")) {
            return true;
        }
        try {
            cp.get(className + '$');
            return false;
        } catch (NotFoundException e) {
            return true;
        }
    }

    /** Tells whether the method carries a runtime-visible {@code @Task} annotation. */
    private static boolean hasTaskAnnotation(CtMethod method) {
        MethodInfo methodInfo = method.getMethodInfo();
        AnnotationsAttribute attInfo = (AnnotationsAttribute) methodInfo.getAttribute(AnnotationsAttribute.visibleTag);
        return attInfo != null && attInfo.getAnnotation(TASK_ANNOTATION) != null;
    }

    /** Tells whether the failure is a runtime exception whose message mentions a frozen class. */
    private static boolean isFrozenClassError(Throwable e) {
        return e instanceof RuntimeException
                && e.getMessage() != null
                && e.getMessage().contains("frozen");
    }
}

