package org.tkit.rhpam.quarkus.tracing;

import io.jaegertracing.internal.JaegerSpanContext;
import io.opentracing.Scope;
import io.opentracing.Span;
import io.opentracing.SpanContext;
import io.opentracing.Tracer;
import io.opentracing.tag.Tags;
import io.smallrye.reactive.messaging.amqp.AmqpMessage;
import lombok.extern.slf4j.Slf4j;

import javax.annotation.Priority;
import javax.inject.Inject;
import javax.interceptor.AroundInvoke;
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;


/**
 * CDI interceptor for methods annotated with {@link TraceFromMessage}.
 * <p>
 * It traces invocations whose first parameter is an {@link AmqpMessage} and whose declared return type
 * is a {@link CompletionStage}. For those invocations it:
 * <ul>
 *     <li>builds a child span (via {@code TracingUtils.buildChildSpan}) from the message's application
 *     properties;</li>
 *     <li>activates that span for the synchronous part of the call;</li>
 *     <li>finishes the span once the returned stage completes.</li>
 * </ul>
 * While the intercepted method runs synchronously, the span identifiers are also copied into the JBoss
 * logging MDC under {@code traceId}, {@code spanId}, {@code parentId} and {@code sampled}.
 * Other invocations are logged at WARN level and passed through untraced.
 */
@TraceFromMessage
@Interceptor
@Priority(Interceptor.Priority.LIBRARY_BEFORE + 1)
@Slf4j
public class TraceFromMessageInterceptor {

    private static final String TRACE_ID = "traceId";
    private static final String SPAN_ID = "spanId";
    private static final String PARENT_ID = "parentId";
    private static final String SAMPLED = "sampled";

    @Inject
    Tracer tracer;

    /**
     * Wraps an AMQP handler invocation in a child span of the trace carried by the message.
     * <p>
     * The scope is activated only around {@link InvocationContext#proceed()} and is closed on the calling
     * thread. The span itself stays open until the returned {@link CompletionStage} completes, normally
     * or exceptionally. Failures, whether thrown synchronously or through the stage, are recorded on the
     * span as error logs.
     *
     * @param ctx the interceptor invocation context
     * @return the intercepted method's return value, unchanged
     * @throws Exception whatever the intercepted method throws (rethrown as-is)
     */
    @AroundInvoke
    public Object aroundTraceIdFromMessage(InvocationContext ctx) throws Exception {
        if (!isTraceable(ctx)) {
            log.warn("### AMQP trace interceptor called on method without amq message param {}",
                    Arrays.toString(ctx.getParameters()));
            return ctx.proceed();
        }
        AmqpMessage<?> message = (AmqpMessage<?>) ctx.getParameters()[0];
        Span span = TracingUtils.buildChildSpan(ctx.getMethod().getName(), message.getApplicationProperties(), tracer);
        span.setTag("amqp.address", message.getAddress());
        span.setTag("class", ctx.getMethod().getDeclaringClass().getName());
        span.setTag("component", "tkit-rhpam-client");
        putContextIfSupported(span.context());

        Object returnValue;
        // try-with-resources closes the scope on this thread before catch/finally run; the span itself is
        // finished separately (activate(..., false)) so it can outlive the synchronous call.
        try (Scope ignored = tracer.scopeManager().activate(span, false)) {
            returnValue = ctx.proceed();
        } catch (Exception e) {
            logException(span, e);
            span.finish();
            throw e;
        } finally {
            // The MDC must not leak into unrelated work on this thread, even if proceed() threw.
            clearMDCScope();
        }
        finishSpanOnCompletion(span, returnValue);
        return returnValue;
    }

    /**
     * Alternative interceptor body that the container does not call.
     * <p>
     * It is not annotated with {@code @AroundInvoke}. It starts a consumer span, as a child of the span
     * context extracted from the message when one is present, and exposes the span in the MDC for the
     * synchronous part of the call. It finishes the span when the returned stage completes. Unlike
     * {@link #aroundTraceIdFromMessage(InvocationContext)}, it does not activate a scope.
     *
     * @param ctx the interceptor invocation context
     * @return the intercepted method's return value, unchanged
     * @throws Exception whatever the intercepted method throws (rethrown as-is)
     */
    public Object setTraceIdFromMessage(InvocationContext ctx) throws Exception {
        if (!isTraceable(ctx)) {
            log.warn("### AMQP trace interceptor called on method without amq message param {}",
                    Arrays.toString(ctx.getParameters()));
            return ctx.proceed();
        }
        AmqpMessage<?> message = (AmqpMessage<?>) ctx.getParameters()[0];
        SpanContext parentContext = TracingUtils.extract(message.getApplicationProperties(), tracer);

        Tracer.SpanBuilder spanBuilder = tracer.buildSpan(ctx.getMethod().getName())
                .withTag(Tags.SPAN_KIND.getKey(), Tags.SPAN_KIND_CONSUMER)
                .withTag("class", ctx.getMethod().getDeclaringClass().getName())
                .withTag("component", "tkit-rhpam-client")
                .withTag("amqp.subject", message.getSubject())
                .withTag("amqp.address", message.getAddress());
        if (parentContext != null) {
            spanBuilder = spanBuilder.asChildOf(parentContext);
        }
        Span span = spanBuilder.start();
        putContextIfSupported(span.context());

        log.info("## INTERCEPTOR: parent ctx: {}  method: {}", parentContext, ctx.getMethod().getReturnType());
        Object returnValue;
        try {
            returnValue = ctx.proceed();
        } catch (Exception e) {
            logException(span, e);
            span.finish();
            throw e;
        } finally {
            clearMDCScope();
        }
        log.info("## Interceptor: finished");
        finishSpanOnCompletion(span, returnValue);
        return returnValue;
    }

    /**
     * Returns whether the invocation should be traced.
     * <p>
     * It should be traced when its first argument is an {@link AmqpMessage} and the method's declared
     * return type is a {@link CompletionStage} (or a subtype such as {@code CompletableFuture}).
     */
    private static boolean isTraceable(InvocationContext ctx) {
        Object[] parameters = ctx.getParameters();
        return parameters != null
                && parameters.length > 0
                && parameters[0] instanceof AmqpMessage
                && CompletionStage.class.isAssignableFrom(ctx.getMethod().getReturnType());
    }

    /**
     * Finishes {@code span} once the stage returned by the intercepted method completes.
     * <p>
     * If the stage completes exceptionally, the failure is first recorded on the span. A {@code null}
     * return value finishes the span immediately. The callback does not touch the scope manager, because
     * it may run on a different thread from the one that activated the scope.
     */
    private void finishSpanOnCompletion(Span span, Object returnValue) {
        if (returnValue == null) {
            span.finish();
            return;
        }
        ((CompletionStage<?>) returnValue).whenComplete((result, failure) -> {
            if (failure != null) {
                logException(span, unwrap(failure));
            }
            span.finish();
        });
    }

    /**
     * Strips the {@link CompletionException} wrapper that dependent stages add around the real cause.
     */
    private static Throwable unwrap(Throwable failure) {
        if (failure instanceof CompletionException && failure.getCause() != null) {
            return failure.getCause();
        }
        return failure;
    }

    /**
     * Records {@code e} on {@code span} as an error event and sets the standard {@code error} tag.
     */
    private void logException(Span span, Throwable e) {
        Map<String, Object> errorLogs = new HashMap<>(4);
        errorLogs.put("event", Tags.ERROR.getKey());
        errorLogs.put("error.message", e.getMessage());
        errorLogs.put("error.object", e);
        span.log(errorLogs);
        Tags.ERROR.set(span, true);
    }

    /**
     * Copies the span identifiers into the MDC when the span context is a {@link JaegerSpanContext}.
     * <p>
     * Other tracer implementations may return a different {@link SpanContext} type. An unchecked cast would
     * then throw and abort message handling, so the MDC is left untouched instead.
     */
    private void putContextIfSupported(SpanContext spanContext) {
        if (spanContext instanceof JaegerSpanContext) {
            putContext((JaegerSpanContext) spanContext);
        } else {
            log.debug("Span context {} is not a Jaeger context; MDC trace ids not set", spanContext);
        }
    }

    /**
     * Writes the Jaeger trace id, span id, parent id (hex) and sampling flag into the JBoss logging MDC.
     *
     * @param spanContext the Jaeger span context to expose
     */
    protected void putContext(JaegerSpanContext spanContext) {
        org.jboss.logging.MDC.put(TRACE_ID, spanContext.getTraceId());
        org.jboss.logging.MDC.put(SPAN_ID, Long.toHexString(spanContext.getSpanId()));
        org.jboss.logging.MDC.put(PARENT_ID, Long.toHexString(spanContext.getParentId()));
        org.jboss.logging.MDC.put(SAMPLED, Boolean.toString(spanContext.isSampled()));
    }

    /**
     * Removes the MDC entries written by {@link #putContext(JaegerSpanContext)}.
     */
    protected void clearMDCScope() {
        org.jboss.logging.MDC.remove(TRACE_ID);
        org.jboss.logging.MDC.remove(SPAN_ID);
        org.jboss.logging.MDC.remove(PARENT_ID);
        org.jboss.logging.MDC.remove(SAMPLED);
    }
}
