package org.tkit.rhpam.quarkus.tracing;

import io.jaegertracing.internal.JaegerSpanContext;
import io.opentracing.Scope;
import io.opentracing.Span;
import io.opentracing.SpanContext;
import io.opentracing.Tracer;
import io.opentracing.tag.Tags;
import io.smallrye.reactive.messaging.amqp.AmqpMessage;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.tkit.quarkus.log.cdi.context.CorrelationScope;
import org.tkit.quarkus.log.cdi.context.TkitLogContext;
import org.tkit.rhpam.quarkus.messaging.common.MessageUtil;

import javax.annotation.Priority;
import javax.inject.Inject;
import javax.interceptor.AroundInvoke;
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletionStage;


/**
 * CDI interceptor, bound via {@link TraceFromMessage}, that traces AMQP message handler methods.
 * <p>
 * A method is traced when its first parameter is an {@link AmqpMessage} and its declared return type is a
 * {@link CompletionStage} (or a subtype). For such methods a child span is built from the message's
 * application properties. It is tagged with the AMQP address, the declaring class and the component name,
 * and it is activated while the method runs. Jaeger trace identifiers are copied into the JBoss logging MDC
 * for the synchronous part of the call. All other methods are only logged with a warning and invoked
 * untraced.
 */
@TraceFromMessage
@Interceptor
@Priority(Interceptor.Priority.LIBRARY_BEFORE + 1)
@Slf4j
public class TraceFromMessageInterceptor {

    private static final String TRACE_ID = "traceId";
    private static final String SPAN_ID = "spanId";
    private static final String PARENT_ID = "parentId";
    private static final String SAMPLED = "sampled";

    /** When true, the message's correlation id property is pushed into the {@link TkitLogContext}. */
    @ConfigProperty(name = "tkit.rhpam.propagate.correlationId", defaultValue = "true")
    boolean propagateCorrelationId;

    @Inject
    Tracer tracer;

    /**
     * Wraps an AMQP message handler invocation in a tracing span.
     * <p>
     * The span is finished explicitly rather than on scope close. When the handler returns a
     * {@link CompletionStage}, the span is finished when that stage completes and is tagged as an error if
     * the stage fails. Otherwise (null result or thrown exception) it is finished when the method returns.
     *
     * @param ctx the interception context of the invoked method
     * @return whatever the intercepted method returns, unchanged
     * @throws Exception anything thrown by the intercepted method; the span is tagged as an error first
     */
    @AroundInvoke
    public Object aroundTraceIdFromMessage(InvocationContext ctx) throws Exception {
        Object[] parameters = ctx.getParameters();
        if (!isTracedMessageHandler(ctx, parameters)) {
            // Arrays.toString: handing the Object[] straight to SLF4J binds it to the varargs overload,
            // which would print only the first parameter.
            log.warn("### AMQP trace interceptor called on method without amq message param {}", Arrays.toString(parameters));
            return ctx.proceed();
        }

        @SuppressWarnings("unchecked") // guarded by the instanceof check in isTracedMessageHandler
        AmqpMessage<String> message = (AmqpMessage<String>) parameters[0];

        Span span = TracingUtils.buildChildSpan(ctx.getMethod().getName(), message.getApplicationProperties(), tracer);
        span.setTag("amqp.address", message.getAddress());
        span.setTag("class", ctx.getMethod().getDeclaringClass().getName());
        span.setTag("component", "tkit-rhpam-client");

        if (propagateCorrelationId) {
            String correlationId = message.getApplicationProperties().getString(MessageUtil.PROP_X_CORRELATION_ID);
            if (correlationId != null) {
                // NOTE(review): this interceptor never resets the log context afterwards; confirm whether
                // TkitLogContext must be cleared to avoid leaking the correlation id to the next task on this thread.
                TkitLogContext.set(new CorrelationScope(correlationId));
            }
        }

        // finishSpanOnClose=false: the span's end is decided below, possibly after async completion.
        Scope scope = tracer.scopeManager().activate(span, false);
        boolean finishOnCompletion = false;
        try {
            SpanContext spanContext = span.context();
            // Only Jaeger contexts expose the ids needed for the MDC; other tracers (e.g. no-op) are skipped
            // instead of failing the message handler with a ClassCastException.
            if (spanContext instanceof JaegerSpanContext) {
                fillMDCWithTraceData((JaegerSpanContext) spanContext);
            }
            Object returnValue = ctx.proceed();
            if (returnValue instanceof CompletionStage) {
                ((CompletionStage<?>) returnValue).whenComplete((result, failure) -> {
                    if (failure != null) {
                        Tags.ERROR.set(span, true);
                    }
                    span.finish();
                });
                finishOnCompletion = true;
            }
            return returnValue;
        } catch (Exception e) {
            Tags.ERROR.set(span, true);
            throw e;
        } finally {
            scope.close();
            clearMDCTraceData();
            if (!finishOnCompletion) {
                span.finish();
            }
        }
    }

    /**
     * Tells whether the intercepted method is an AMQP handler this interceptor traces.
     * The first argument must be an {@link AmqpMessage}, and the return type must be a {@link CompletionStage}
     * or a subtype of it.
     */
    private static boolean isTracedMessageHandler(InvocationContext ctx, Object[] parameters) {
        return parameters.length > 0
                && parameters[0] instanceof AmqpMessage
                && CompletionStage.class.isAssignableFrom(ctx.getMethod().getReturnType());
    }

    /**
     * Puts the trace id, span id, parent id (span and parent ids hex-encoded) and sampled flag of the given
     * Jaeger span context into the JBoss logging MDC.
     *
     * @param spanContext the Jaeger span context of the active span
     */
    protected void fillMDCWithTraceData(JaegerSpanContext spanContext) {
        org.jboss.logging.MDC.put(TRACE_ID, spanContext.getTraceId());
        org.jboss.logging.MDC.put(SPAN_ID, Long.toHexString(spanContext.getSpanId()));
        org.jboss.logging.MDC.put(PARENT_ID, Long.toHexString(spanContext.getParentId()));
        org.jboss.logging.MDC.put(SAMPLED, Boolean.toString(spanContext.isSampled()));
    }

    /** Removes the MDC keys written by {@link #fillMDCWithTraceData(JaegerSpanContext)}. */
    protected void clearMDCTraceData() {
        org.jboss.logging.MDC.remove(TRACE_ID);
        org.jboss.logging.MDC.remove(SPAN_ID);
        org.jboss.logging.MDC.remove(PARENT_ID);
        org.jboss.logging.MDC.remove(SAMPLED);
    }
}
