package sila_java.library.server_base.metadata;

import io.grpc.ClientCall;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import lombok.NonNull;
import sila_java.library.core.sila.errors.SiLAErrorException;
import sila_java.library.core.sila.errors.SiLAErrors;
import sila_java.library.core.sila.utils.FullyQualifiedIdentifierUtils;

import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Base class for server interceptors that handle SiLA Client Metadata.
 *
 * <p>Only calls whose gRPC method name matches one of the configured fully qualified feature, command, or property
 * identifiers are handled via {@link #intercept(ServerCall, ServerMetadataContainer)}; all other calls are forwarded
 * unchanged.
 */
public abstract class MetadataInterceptor implements ServerInterceptor {
    /**
     * Listener which ignores every callback. Returned for calls this interceptor has already closed.
     */
    private static final ServerCall.Listener<?> NOOP_LISTENER = new ServerCall.Listener<Object>() {};

    /**
     * Regex patterns that are matched (in full) against incoming gRPC method names. A call is handled by this
     * interceptor if any pattern matches.
     */
    private final List<Pattern> affectedCallTargetPatterns;

    /**
     * Create an interceptor for the gRPC calls belonging to the given identifiers.
     *
     * @param affectedCalls fully qualified feature, command, or property identifiers
     * @throws IllegalArgumentException if an entry is not a fully qualified feature, command, or property identifier
     */
    public MetadataInterceptor(@NonNull final List<String> affectedCalls) {
        this.affectedCallTargetPatterns = getAffectedCallTargetPatterns(affectedCalls);
    }

    /**
     * Create an interceptor for the gRPC calls belonging to a single identifier.
     *
     * @param affectedCall fully qualified feature, command, or property identifier
     * @throws IllegalArgumentException if the argument is not a fully qualified feature, command, or property
     *                                  identifier
     */
    public MetadataInterceptor(@NonNull final String affectedCall) {
        this.affectedCallTargetPatterns = getAffectedCallTargetPatterns(Collections.singletonList(affectedCall));
    }

    /**
     * Handle received SiLA Client Metadata.
     *
     * <p>Any exception thrown here is converted into a SiLA Error and the call is closed with it.
     *
     * @param call incoming call
     * @param metadata received SiLA Client Metadata
     * @param <ReqT> request type
     * @param <RespT> response type
     *
     * @return context for the next call (usually {@link Context#current})
     */
    public abstract <ReqT, RespT> Context intercept(final ServerCall<ReqT, RespT> call, ServerMetadataContainer metadata);

    /**
     * The actual interception logic: calls {@link MetadataInterceptor#intercept(ServerCall, ServerMetadataContainer)}
     * and forwards the returned context to the next call handler. On error, the call is closed with the appropriate
     * SiLA Error.
     *
     * @param call object to receive response messages
     * @param headers which can contain extra call metadata from {@link ClientCall#start}, e.g. authentication
     *         credentials.
     * @param next next processor in the interceptor chain
     * @param <ReqT> request type
     * @param <RespT> response type
     *
     * @return next call listener
     */
    @Override
    public final <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
            final ServerCall<ReqT, RespT> call,
            final Metadata headers,
            final ServerCallHandler<ReqT, RespT> next
    ) {
        final String callTarget = call.getMethodDescriptor().getFullMethodName();

        // calls not targeted by this interceptor are forwarded unchanged
        if (!isAffected(callTarget)) {
            return Contexts.interceptCall(Context.current(), call, headers, next);
        }

        // reuse metadata already attached to the current context, if any; otherwise parse it from the headers
        final ServerMetadataContainer contextMetadata = ServerMetadataContainer.current();
        final ServerMetadataContainer metadata =
                (contextMetadata != null) ? contextMetadata : ServerMetadataContainer.fromHeaders(headers);

        try {
            final Context newContext = intercept(call, metadata);
            return Contexts.interceptCall(newContext, call, headers, next);
        } catch (final Throwable e) {
            // Report the failure as a SiLA Error. Use fresh, empty trailers: passing the request headers here would
            // echo the client's own metadata (e.g. credentials) back to it.
            call.close(new SiLAErrorException(SiLAErrors.throwableToSiLAError(e)).getStatus(), new Metadata());
            return noopListener();
        }
    }

    /**
     * Check whether a gRPC call is handled by this interceptor.
     *
     * @param callTarget full gRPC method name of the call
     * @return true if any affected-call pattern matches the whole method name
     */
    private boolean isAffected(final String callTarget) {
        return this.affectedCallTargetPatterns.stream().anyMatch(pattern -> pattern.matcher(callTarget).matches());
    }

    /**
     * Typed view of {@link #NOOP_LISTENER}.
     *
     * @param <ReqT> request type
     * @return listener which ignores every callback
     */
    @SuppressWarnings("unchecked")
    private static <ReqT> ServerCall.Listener<ReqT> noopListener() {
        // Safe: the listener overrides nothing, so its type argument is never used.
        return (ServerCall.Listener<ReqT>) NOOP_LISTENER;
    }

    /**
     * Convert fully qualified feature, command, and property identifiers to regex patterns for matching incoming
     * gRPC call method names.
     *
     * @param affectedCalls fully qualified feature, command, or property identifiers
     * @return regex patterns for matching gRPC call method names, one per identifier
     * @throws IllegalArgumentException if an entry is not a fully qualified feature, command, or property identifier
     */
    private static List<Pattern> getAffectedCallTargetPatterns(@NonNull final List<String> affectedCalls) {
        return affectedCalls
                .stream()
                .map(MetadataInterceptor::getAffectedCallTargetPattern)
                .collect(Collectors.toList());
    }

    /**
     * Convert a fully qualified feature, command, or property identifier to a regex pattern for matching incoming
     * gRPC call method names.
     *
     * <ul>
     *     <li>feature: matches every method of the feature's gRPC service whose name is {@code [A-Z][a-zA-Z0-9]*},
     *     optionally prefixed with {@code Get_} or {@code Subscribe_}</li>
     *     <li>command: matches only the method named exactly like the command (suffixed names such as
     *     {@code X_Info} are not matched)</li>
     *     <li>property: matches {@code Get_<Property>} and {@code Subscribe_<Property>}</li>
     * </ul>
     *
     * @param affectedCall fully qualified feature, command, or property identifier
     * @return regex pattern for matching gRPC call method names
     * @throws IllegalArgumentException if the argument is not a fully qualified feature, command, or property
     *                                  identifier
     */
    private static Pattern getAffectedCallTargetPattern(final String affectedCall) {
        final boolean isCommand = FullyQualifiedIdentifierUtils.FullyQualifiedCommandIdentifierPattern
                .matcher(affectedCall)
                .matches();
        final boolean isProperty = FullyQualifiedIdentifierUtils.FullyQualifiedPropertyIdentifierPattern
                .matcher(affectedCall)
                .matches();
        final boolean isFeature = FullyQualifiedIdentifierUtils.FullyQualifiedFeatureIdentifierPattern
                .matcher(affectedCall)
                .matches();

        if (!(isFeature || isCommand || isProperty)) {
            throw new IllegalArgumentException(
                    String.format(
                            "Given argument is no fully qualified feature, command, or property identifier: '%s'",
                            affectedCall
                    )
            );
        }

        // segments: originator / category / FeatureIdentifier / version [/ Command|Property / Identifier], e.g.
        // org.silastandard/core/SiLAService/v1/Command/GetFeatureDefinition
        final String[] segments = affectedCall.split("/");
        final String featureIdentifier = segments[2];

        // Only the first four segments form the gRPC package, e.g. sila2.org.silastandard.core.silaservice.v1;
        // the service name is the feature identifier. Locale.ROOT avoids locale-dependent lowercasing.
        final String grpcPackage = String.join(
                ".", "sila2", segments[0], segments[1], featureIdentifier, segments[3]
        ).toLowerCase(Locale.ROOT);
        // quote the literal part so dots in it do not act as regex wildcards
        final String servicePrefix = Pattern.quote(grpcPackage + "." + featureIdentifier + "/");

        final String methodPattern;
        if (isCommand) {
            // e.g. org.silastandard/core/SiLAService/v1/Command/GetFeatureDefinition
            methodPattern = Pattern.quote(segments[5]);
        } else if (isProperty) {
            // e.g. org.silastandard/core/SiLAService/v1/Property/ImplementedFeatures
            methodPattern = "(Get_|Subscribe_)" + Pattern.quote(segments[5]);
        } else {
            // feature identifier, e.g. org.silastandard/core/SiLAService/v1
            methodPattern = "(Get_|Subscribe_)?" + "[A-Z][a-zA-Z0-9]*";
        }
        return Pattern.compile(servicePrefix + methodPattern);
    }
}
