package org.thryft.waf.server.controllers.oauth;

import static com.google.common.base.Preconditions.checkNotNull;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;

import javax.annotation.Nullable;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.DisabledAccountException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.thryft.native_.GenericUri;
import org.thryft.native_.Url;
import org.thryft.waf.api.models.ModelEntry;
import org.thryft.waf.lib.logging.LoggingUtils;

import com.github.scribejava.core.builder.ServiceBuilder;
import com.github.scribejava.core.exceptions.OAuthException;
import com.github.scribejava.core.model.OAuth2AccessToken;
import com.github.scribejava.core.oauth.OAuth20Service;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableMap;

/**
 * Base servlet that drives an OAuth 2 login flow against one of several configured providers.
 *
 * <p>
 * The provider is selected by the request's path info (lower-cased, leading slashes stripped). A single
 * {@code GET} endpoint serves three kinds of request:
 * <ul>
 * <li>no {@code code} and no {@code error} parameter: the user is redirected to the provider's authorization
 * URL, with the incoming {@code state} parameter (if any) appended;</li>
 * <li>a {@code code} parameter: the code is exchanged for an access token, the user profile is fetched, the
 * user is looked up (or created) and logged in, and the browser is redirected to a success URL;</li>
 * <li>an {@code error} parameter (and no {@code code}): the browser is redirected to the failed-login URL.</li>
 * </ul>
 *
 * @param <UserEntryT>
 *            type of the user entry that subclasses look up, create and log in
 */
@SuppressWarnings("serial")
public abstract class AbstractOauthLoginController<UserEntryT extends ModelEntry<?, ?>> extends HttpServlet {
    /**
     * @param oauthServiceProviders
     *            supported providers; keys are matched against the lower-cased provider id taken from the
     *            request path
     */
    protected AbstractOauthLoginController(final ImmutableMap<String, Oauth2ServiceProvider> oauthServiceProviders) {
        this.oauthServiceProviders = checkNotNull(oauthServiceProviders);
    }

    /**
     * Returns the URL the browser is redirected to after a failed login.
     *
     * @param errorUrlDecoded
     *            the OAuth {@code error} request parameter (with the request body appended when one was sent),
     *            or {@code "inactive"} when {@link #_login} rejected a disabled account
     * @param stateParameterUrlDecoded
     *            the {@code state} request parameter; {@code null} when the request carried none
     */
    protected abstract String _getFailedLoginUrl(final String errorUrlDecoded, final String stateParameterUrlDecoded);

    /**
     * Returns the URL the browser is redirected to after the first successful login of a user that was just
     * created through {@link #_postUser}.
     *
     * @param stateParameterUrlDecoded
     *            the {@code state} request parameter; {@code null} when the request carried none
     */
    protected abstract String _getNewLoginUrl(final String stateParameterUrlDecoded);

    /**
     * Returns the path that is concatenated verbatim between the {@code Host} header and the provider id when
     * the OAuth callback URL is built (so it presumably has to start and end with {@code '/'}).
     */
    protected abstract String _getOauthCallbackUrlPathPrefix();

    /**
     * Returns the scheme to use for the OAuth callback URL. The default is absent, in which case the scheme of
     * the incoming request URL is used.
     */
    protected Optional<String> _getOauthCallbackUrlScheme() {
        return Optional.absent();
    }

    /**
     * Returns the URL the browser is redirected to after a successful login of a user that already existed.
     *
     * @param stateParameterUrlDecoded
     *            the {@code state} request parameter; {@code null} when the request carried none
     */
    protected abstract String _getSuccessfulLoginUrl(final String stateParameterUrlDecoded);

    /**
     * Looks up the user that corresponds to an OAuth user profile.
     *
     * @return the existing user, or absent if {@link #_postUser} should be called to create one
     * @throws IOException
     *             on lookup failure; the request is then answered with a 500
     */
    protected abstract Optional<UserEntryT> _getUser(String providerId, final OauthUserProfile userProfile)
            throws IOException;

    /**
     * Logs the given user in. A thrown {@link DisabledAccountException} redirects the browser to the
     * failed-login URL with error {@code "inactive"}; any other {@link AuthenticationException} is answered
     * with a 500.
     */
    protected abstract void _login(final UserEntryT userEntry);

    /**
     * Creates a user for an OAuth user profile for which {@link #_getUser} returned absent.
     *
     * @throws IOException
     *             on creation failure; the request is then answered with a 500
     */
    protected abstract UserEntryT _postUser(String providerId, final OauthUserProfile userProfile) throws IOException;

    @Override
    protected final void doGet(final HttpServletRequest httpServletRequest,
            final HttpServletResponse httpServletResponse) throws IOException {
        final String pathInfo = httpServletRequest.getPathInfo();
        if (pathInfo == null || pathInfo.length() <= 1) {
            logger.debug(LOG_MARKER, "ignoring request with invalid path '{}'", pathInfo != null ? pathInfo : "");
            httpServletResponse.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        // Locale.ROOT keeps the lower-casing independent of the JVM's default locale (e.g. the Turkish
        // dotless i), so provider ids always match the keys of the provider map.
        final String providerId = StringUtils.stripStart(pathInfo, "/").toLowerCase(java.util.Locale.ROOT);

        final String hostHeader = httpServletRequest.getHeader("Host");
        if (hostHeader == null) {
            logger.debug(LOG_MARKER, "ignoring request with no Host header");
            httpServletResponse.sendError(HttpServletResponse.SC_BAD_REQUEST);
            return;
        }

        @Nullable
        final String stateParameterUrlDecoded = httpServletRequest.getParameter("state");

        final Oauth2ServiceProvider serviceConfiguration = oauthServiceProviders.get(providerId);
        if (serviceConfiguration == null) {
            logger.warn(LOG_MARKER, "unsupported OAuth provider '{}'", providerId);
            httpServletResponse.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        // NOTE(review): the callback URL is derived from the client-supplied Host header; confirm that the
        // deployment (or the provider's registered redirect URIs) constrains which hosts are accepted.
        final Url httpServletRequestUrl = Url.parse(httpServletRequest.getRequestURL().toString());
        final String callbackUrl = Url.parse(_getOauthCallbackUrlScheme().or(httpServletRequestUrl.getScheme())
                + "://" + hostHeader + _getOauthCallbackUrlPathPrefix() + providerId).toString();
        logger.debug(LOG_MARKER, "HTTP request URL: {}, OAuth callback URL: {}", httpServletRequestUrl,
                callbackUrl);

        final ServiceBuilder serviceBuilder = new ServiceBuilder().apiKey(serviceConfiguration.getApiKey())
                .apiSecret(serviceConfiguration.getApiSecret()).callback(callbackUrl);
        if (serviceConfiguration.getScope().isPresent()) {
            serviceBuilder.scope(serviceConfiguration.getScope().get());
        }
        final OAuth20Service service = serviceBuilder.build(serviceConfiguration);

        final String code = httpServletRequest.getParameter("code");
        String error = httpServletRequest.getParameter("error");

        OauthUserProfile userProfile;
        if (code != null) {
            // The request is a normal OAuth 2 callback.
            final OAuth2AccessToken accessToken;
            try {
                accessToken = checkNotNull(service.getAccessToken(code));
            } catch (final OAuthException e) {
                logger.error(LOG_MARKER, "error getting OAuth access token: ", e);
                httpServletResponse.sendError(HttpServletResponse.SC_BAD_REQUEST);
                return;
            }

            logger.debug(LOG_MARKER, "getting user information from OAuth2 service provider '{}'", providerId);
            try {
                userProfile = serviceConfiguration.getUserProfile(service, accessToken);
            } catch (final IOException e) {
                logger.warn(LOG_MARKER, "error getting OAuth user profile: ", e);
                httpServletResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
                return;
            }
        } else if (error != null) {
            // This is an OAuth 2 callback with an error.
            if (httpServletRequest.getContentLength() > 0) {
                try {
                    error += ":\n" + readRequestBody(httpServletRequest);
                } catch (final IOException e) {
                    // Reading the body is best-effort: the error parameter alone is still reported.
                    logger.warn(LOG_MARKER, "error reading OAuth2 error callback request body: ", e);
                }
            }
            logger.error(LOG_MARKER, "OAuth2 error: {}", error);
            final String failedLoginUrl = _getFailedLoginUrl(error, stateParameterUrlDecoded);
            logger.debug(LOG_MARKER, "redirecting user to {} after failed login", failedLoginUrl);
            httpServletResponse.sendRedirect(failedLoginUrl);
            return;
        } else {
            // The request is to log in. Redirect the user to the appropriate OAuth authorization URL.
            final String authorizationUrlString = service.getAuthorizationUrl(null);
            final Url authorizationUrl;
            try {
                authorizationUrl = Url.parse(authorizationUrlString);
            } catch (final IllegalArgumentException e) {
                logger.error(LOG_MARKER, "error parsing authorization URL '{}': ", authorizationUrlString, e);
                throw new IllegalStateException("error parsing authorization URL '" + authorizationUrlString + "'",
                        e);
            }
            logger.debug(LOG_MARKER, "redirecting user to authorization URL");

            httpServletResponse
                    .sendRedirect(withStateParameter(authorizationUrl, stateParameterUrlDecoded).toString());
            return;
        }

        Optional<UserEntryT> newUserEntry = Optional.absent();
        final Optional<UserEntryT> existingUserEntry;
        try {
            existingUserEntry = _getUser(providerId, userProfile);
            if (existingUserEntry.isPresent()) {
                logger.debug(LOG_MARKER, "logging in existing user {}", existingUserEntry.get().getModel());
            } else {
                newUserEntry = Optional.of(_postUser(providerId, userProfile));
                logger.debug(LOG_MARKER, "logging in new user {}", newUserEntry.get().getModel());
            }
        } catch (final IOException e) {
            logger.error(LOG_MARKER, "I/O exception logging in user {}: ", userProfile, e);
            httpServletResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                    ExceptionUtils.getRootCauseMessage(e));
            return;
        }
        // Exactly one of the two is present at this point.
        final UserEntryT userEntry = existingUserEntry.or(newUserEntry).get();

        try {
            _login(userEntry);
        } catch (final DisabledAccountException e) {
            final String failedLoginUrl = _getFailedLoginUrl("inactive", stateParameterUrlDecoded);
            logger.warn(LOG_MARKER, "redirecting inactive user to {} ", failedLoginUrl);
            httpServletResponse.sendRedirect(failedLoginUrl);
            return;
        } catch (final AuthenticationException e) {
            logger.error(LOG_MARKER, "error logging in {}: ", userEntry.getModel(), e);
            httpServletResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                    ExceptionUtils.getRootCauseMessage(e));
            return;
        }

        // Pre-existing users and freshly created users land on different pages.
        final String successfulLoginUrl = existingUserEntry.isPresent()
                ? _getSuccessfulLoginUrl(stateParameterUrlDecoded) : _getNewLoginUrl(stateParameterUrlDecoded);

        logger.debug(LOG_MARKER, "redirecting {} to {} after successful login", userEntry.getModel(),
                successfulLoginUrl);

        httpServletResponse.sendRedirect(successfulLoginUrl);
    }

    /** Reads the complete request body as text through the request's reader. */
    private static String readRequestBody(final HttpServletRequest httpServletRequest) throws IOException {
        final BufferedReader requestReader = httpServletRequest.getReader();
        final StringBuilder requestBodyStringBuilder = new StringBuilder();
        final char[] requestBodyBuffer = new char[128];
        int requestBodyCharsRead;
        while ((requestBodyCharsRead = requestReader.read(requestBodyBuffer)) > 0) {
            requestBodyStringBuilder.append(requestBodyBuffer, 0, requestBodyCharsRead);
        }
        return requestBodyStringBuilder.toString();
    }

    /**
     * Returns the authorization URL with a URL-encoded {@code state} query parameter added, or the URL
     * unchanged when the state is {@code null} or empty.
     */
    private static Url withStateParameter(final Url authorizationUrl,
            @Nullable final String stateParameterUrlDecoded) {
        if (stateParameterUrlDecoded == null || stateParameterUrlDecoded.isEmpty()) {
            return authorizationUrl;
        }

        final String stateParameterUrlEncoded;
        try {
            // UTF-8 rather than ASCII: an ASCII encoder replaces every non-ASCII character with '?'.
            stateParameterUrlEncoded = URLEncoder.encode(stateParameterUrlDecoded, "UTF-8");
        } catch (final UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }

        if (authorizationUrl.getQuery().isPresent()) {
            return (Url) GenericUri.builder(authorizationUrl)
                    .setQuery(authorizationUrl.getQuery().get() + "&state=" + stateParameterUrlEncoded).build();
        }
        // NOTE(review): the branch above passes a query without a leading '?', this one with; confirm which
        // form GenericUri.Builder.setQuery expects.
        return (Url) GenericUri.builder(authorizationUrl).setQuery("?state=" + stateParameterUrlEncoded).build();
    }

    private final ImmutableMap<String, Oauth2ServiceProvider> oauthServiceProviders;
    private final static Logger logger = LoggerFactory.getLogger(AbstractOauthLoginController.class);
    final static Marker LOG_MARKER = LoggingUtils.getMarker(AbstractOauthLoginController.class);
}
