package cn.bestwu.security.oauth2.social.provider;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.provider.*;
import org.springframework.security.oauth2.provider.token.AuthorizationServerTokenServices;
import org.springframework.util.StringUtils;

import javax.validation.ConstraintViolation;
import javax.validation.ConstraintViolationException;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author Peter Wu
 */
/**
 * {@link TokenGranter} that exchanges a third-party ("social") token for an OAuth2 access token.
 * <p>
 * The lower-cased grant type is used as the social provider id ({@code pid}). The request parameters must
 * carry {@code puid} and {@code social_token}. Users without an existing social binding are either signed up
 * automatically or rejected, depending on {@link SocialAdapter#isAutoLogon()}.
 *
 * @param <T> the application's user type, as handled by the {@link SocialAdapter}
 * @author Peter Wu
 */
public class SocialTokenGranter<T> implements TokenGranter {

	/**
	 * Duplicate-key message format (e.g. MySQL's "Duplicate entry 'x' for key ...");
	 * group 1 captures the duplicated value.
	 */
	private static final Pattern DUPLICATE_ENTRY = Pattern.compile("^Duplicate entry '(.*?)'.*");

	/** Prefix of foreign-key violation messages (e.g. MySQL's "Cannot delete or update a parent row ..."). */
	private static final String CONSTRAINT_PREFIX = "Cannot delete or update a parent row";

	/** Fallback user-facing message for validation failures without a usable detail. */
	private static final String VALIDATION_FAILED = "数据验证失败";

	protected final Log logger = LogFactory.getLog(getClass());

	private final AuthorizationServerTokenServices tokenServices;

	private final ClientDetailsService clientDetailsService;

	private final SocialAdapter<T> socialAdapter;

	/**
	 * @param tokenServices        issues the final access token
	 * @param clientDetailsService resolves the requesting client so its allowed grant types can be checked
	 * @param socialAdapter        provider-specific hooks: token validation, user lookup, sign-up, profile mapping
	 */
	public SocialTokenGranter(AuthorizationServerTokenServices tokenServices, ClientDetailsService clientDetailsService, SocialAdapter<T> socialAdapter) {
		this.clientDetailsService = clientDetailsService;
		this.tokenServices = tokenServices;
		this.socialAdapter = socialAdapter;
	}

	/**
	 * Issues an access token when {@code grantType} names a provider supported by the adapter.
	 *
	 * @param grantType    requested grant type; lower-cased and used as the social provider id
	 * @param tokenRequest the token request; its parameters must contain {@code puid} and {@code social_token}
	 * @return the issued token, or {@code null} when this granter does not handle {@code grantType}
	 * @throws InvalidClientException if the client is not authorized for the grant type
	 * @throws SocialException        if required parameters are missing, the account does not exist and
	 *                                auto-logon is off, or automatic sign-up fails
	 */
	@Override
	public OAuth2AccessToken grant(String grantType, TokenRequest tokenRequest) {
		if (grantType == null) {
			// TokenGranter contract: null means "not handled here", so other granters may be tried.
			return null;
		}
		// Locale.ROOT keeps the mapping stable regardless of the JVM default locale (e.g. Turkish dotless i).
		String normalizedGrantType = grantType.toLowerCase(Locale.ROOT);

		if (!socialAdapter.support(normalizedGrantType)) {
			return null;
		}

		String clientId = tokenRequest.getClientId();
		ClientDetails client = clientDetailsService.loadClientByClientId(clientId);
		validateGrantType(normalizedGrantType, client);

		if (logger.isDebugEnabled()) {
			logger.debug("Getting access token for: " + clientId);
		}

		return tokenServices.createAccessToken(getOAuth2Authentication(normalizedGrantType, tokenRequest));
	}

	/**
	 * Rejects the grant when the client declares a non-empty set of authorized grant types that does not
	 * contain {@code grantType}. A null or empty set allows every grant type.
	 *
	 * @throws InvalidClientException if the grant type is not authorized for the client
	 */
	protected void validateGrantType(String grantType, ClientDetails clientDetails) {
		Collection<String> authorizedGrantTypes = clientDetails.getAuthorizedGrantTypes();
		if (authorizedGrantTypes != null && !authorizedGrantTypes.isEmpty()
				&& !authorizedGrantTypes.contains(grantType)) {
			throw new InvalidClientException("Unauthorized grant type: " + grantType);
		}
	}

	/**
	 * Resolves the user behind a social token and builds their OAuth2 authentication.
	 * <p>
	 * The social token is validated through the adapter first. If no binding exists for {@code (pid, puid)},
	 * the user is signed up when auto-logon is enabled; otherwise a 404 {@link SocialException} is raised.
	 * For existing users, the stored profile is optionally refreshed on a best-effort basis.
	 *
	 * @param pid          social provider id (the normalized grant type)
	 * @param tokenRequest request whose parameters carry {@code puid} and {@code social_token}
	 * @return the authentication for the resolved or newly created user
	 * @throws SocialException with status 422 if {@code puid} or {@code social_token} is blank or sign-up fails
	 *                         validation, 404 if the account is unknown and auto-logon is off
	 */
	protected OAuth2Authentication getOAuth2Authentication(String pid, TokenRequest tokenRequest) {
		Map<String, String> requestParameters = tokenRequest.getRequestParameters();
		String puid = requestParameters.get("puid");
		String socialToken = requestParameters.get("social_token");

		if (!StringUtils.hasText(puid)) {
			throw new SocialException(HttpStatus.UNPROCESSABLE_ENTITY.value(), "puid must be supplied.");
		}
		if (!StringUtils.hasText(socialToken)) {
			throw new SocialException(HttpStatus.UNPROCESSABLE_ENTITY.value(), "social_token must be supplied.");
		}

		Map originalProfile = socialAdapter.validateSocialTokenAndGetUserOriginalProfile(pid, puid, socialToken);
		SocialId<T> socialId = socialAdapter.findByPidAndPuid(pid, puid);
		if (socialId == null) {
			UserProfile userProfile = socialAdapter.originalProfile2UserProfile(pid, puid, originalProfile);
			if (!socialAdapter.isAutoLogon()) {
				throw new SocialException(HttpStatus.NOT_FOUND.value(), "账号未初始化", userProfile);
			}
			return signupAndLogin(userProfile, tokenRequest);
		}

		if (socialAdapter.isRefreshUserProfile()) {
			try {
				UserProfile userProfile = socialAdapter.originalProfile2UserProfile(pid, puid, originalProfile);
				socialAdapter.saveUserProfile(socialId.getUser(), userProfile);
			} catch (Exception e) {
				// Best effort: a failed refresh must not block login, but it should not vanish silently.
				logger.warn("Failed to refresh user profile for pid=" + pid + ", puid=" + puid, e);
			}
		}

		return socialAdapter.createOAuth2Authentication(socialId.getUser(), tokenRequest.getClientId(), tokenRequest.getScope());
	}

	/**
	 * Signs up a new user from {@code userProfile}, puts their authentication into the current
	 * {@link SecurityContextHolder} context (auto-login), runs the adapter's post-sign-up hook, and returns
	 * the authentication. Persistence and validation failures are translated into {@link SocialException}.
	 */
	private OAuth2Authentication signupAndLogin(UserProfile userProfile, TokenRequest tokenRequest) {
		try {
			SocialId<T> socialId = socialAdapter.signup(userProfile);
			T user = socialId.getUser();
			OAuth2Authentication authentication = socialAdapter.createOAuth2Authentication(user, tokenRequest.getClientId(), tokenRequest.getScope());
			// Auto-login: expose the new user's authentication to the rest of this request.
			SecurityContextHolder.getContext().setAuthentication(authentication);
			socialAdapter.afterSignup(user);
			return authentication;
		} catch (ConstraintViolationException e) {
			throw validationFailure(e, userProfile);
		} catch (DataIntegrityViolationException e) {
			throw integrityFailure(e, userProfile);
		}
	}

	/**
	 * Maps bean-validation failures to a 422 {@link SocialException}.
	 * The exception carries a property-path to message map, and its headline message is the first
	 * reported violation (or a generic fallback when none are reported).
	 */
	private SocialException validationFailure(ConstraintViolationException e, UserProfile userProfile) {
		Map<String, String> errors = new LinkedHashMap<>();
		Set<ConstraintViolation<?>> constraintViolations = e.getConstraintViolations();
		if (constraintViolations != null) {
			for (ConstraintViolation<?> violation : constraintViolations) {
				errors.put(violation.getPropertyPath().toString(), violation.getMessage());
			}
		}
		String message = errors.isEmpty() ? VALIDATION_FAILED : errors.values().iterator().next();
		SocialException socialException = new SocialException(HttpStatus.UNPROCESSABLE_ENTITY.value(), message, userProfile);
		socialException.setErrors(errors);
		return socialException;
	}

	/**
	 * Maps a data-integrity failure to a {@link SocialException}.
	 * Duplicate-key and foreign-key messages become 422 with a user-facing text; anything else becomes 500
	 * carrying the most specific cause's message.
	 */
	private SocialException integrityFailure(DataIntegrityViolationException e, UserProfile userProfile) {
		// Unlike getRootCause(), getMostSpecificCause() never returns null (it falls back to e itself).
		String causeMessage = e.getMostSpecificCause().getMessage();
		if (causeMessage != null) {
			Matcher duplicate = DUPLICATE_ENTRY.matcher(causeMessage);
			if (duplicate.matches()) {
				String duplicatedValue = duplicate.group(1);
				String message = StringUtils.hasText(duplicatedValue) ? duplicatedValue + "已经存在" : VALIDATION_FAILED;
				return new SocialException(HttpStatus.UNPROCESSABLE_ENTITY.value(), message, userProfile);
			}
			if (causeMessage.startsWith(CONSTRAINT_PREFIX)) {
				return new SocialException(HttpStatus.UNPROCESSABLE_ENTITY.value(), "不能删除或更新关联实体，其他的资源引用了此实体", userProfile);
			}
		}
		// The original exception is not chained into SocialException, so keep it in the log for diagnosis.
		logger.warn("Unexpected data integrity violation during social sign-up", e);
		return new SocialException(HttpStatus.INTERNAL_SERVER_ERROR.value(), causeMessage, userProfile);
	}

}
