package cool.mtc.swagger;

import cool.mtc.swagger.model.*;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import springfox.documentation.builders.*;
import springfox.documentation.service.*;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spring.web.plugins.Docket;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Predicate;

/**
 * Builds springfox {@link Docket} instances from {@link SwaggerProperties} and registers
 * them as singleton beans in the application's bean factory.
 *
 * <p>When no groups are configured a single Docket is built from the root properties;
 * otherwise one Docket is built per group, with each group inheriting any setting it
 * leaves unset from the root properties.
 *
 * <p>The configuration is active unless {@code mtc.swagger.enabled} is set to something
 * other than {@code true}.
 *
 * @author 明河
 */
@Configuration
@ConditionalOnProperty(value = "mtc.swagger.enabled", havingValue = "true", matchIfMissing = true)
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class DocketConfiguration {
    private final BeanFactory beanFactory;
    private final SwaggerProperties swaggerProperties;

    /** Prefix of the bean name under which every generated Docket is registered. */
    private static final String DOCKET_BEAN_PREFIX = "swagger-spring-boot-starter-docket-";

    /**
     * Creates the Docket(s) described by the swagger properties and registers each one as a
     * singleton named {@link #DOCKET_BEAN_PREFIX} followed by a hash of its group name.
     *
     * <p>The method returns nothing; it exists for its registration side effect.
     * NOTE(review): this relies on the container accepting a {@code void} {@code @Bean}
     * method - confirm before changing the Spring version.
     *
     * @throws IllegalStateException if the injected bean factory is not a
     *                               {@link DefaultListableBeanFactory}
     */
    @Bean
    public void customDocket() {
        List<Docket> docketList = new ArrayList<>();
        if (ObjectUtils.isEmpty(swaggerProperties.getGroups())) {
            docketList.add(this.loadSwaggerRootConfig(swaggerProperties));
        } else {
            docketList.addAll(this.loadSwaggerGroupConfig(swaggerProperties));
        }
        if (!(beanFactory instanceof DefaultListableBeanFactory)) {
            throw new IllegalStateException("Cannot register Docket beans: expected a DefaultListableBeanFactory but got "
                    + beanFactory.getClass().getName());
        }
        DefaultListableBeanFactory defaultListableBeanFactory = (DefaultListableBeanFactory) beanFactory;
        for (Docket docket : docketList) {
            // The group name is presumably null when none was configured (TODO confirm against
            // springfox); hash a null array in that case instead of dereferencing the name.
            String groupName = docket.getGroupName();
            char[] groupNameChars = null == groupName ? null : groupName.toCharArray();
            String beanName = DOCKET_BEAN_PREFIX + ObjectUtils.nullSafeHashCode(groupNameChars);
            defaultListableBeanFactory.registerSingleton(beanName, docket);
        }
    }

    /**
     * Builds the single Docket described by the root-level properties.
     *
     * @param swaggerProperties the root swagger properties
     * @return the Docket built from them
     */
    private Docket loadSwaggerRootConfig(SwaggerProperties swaggerProperties) {
        return this.loadSwaggerDocketConfig(swaggerProperties);
    }

    /**
     * Builds one Docket per configured group.
     *
     * <p>Each group is first completed with the values inherited from the root properties
     * (see {@link #handleGroupData}), which mutates the group object.
     *
     * @param swaggerProperties the root swagger properties holding the groups
     * @return the Dockets, in the iteration order of the groups
     */
    private List<Docket> loadSwaggerGroupConfig(SwaggerProperties swaggerProperties) {
        List<Docket> docketList = new ArrayList<>();
        for (SwaggerGroup group : swaggerProperties.getGroups()) {
            this.handleGroupData(swaggerProperties, group);
            docketList.add(this.loadSwaggerDocketConfig(group));
        }
        return docketList;
    }

    /**
     * Translates one docket definition into a springfox {@link Docket}.
     *
     * @param swaggerDocket the docket definition (root properties or a group)
     * @return the configured Docket
     */
    private Docket loadSwaggerDocketConfig(SwaggerDocket swaggerDocket) {
        String basePackage = swaggerDocket.getBasePackage();
        Docket docket = new Docket(DocumentationType.OAS_30)
                .enable(swaggerDocket.isEnabled())
                .groupName(swaggerDocket.getName())
                .apiInfo(this.handleApiInfo(swaggerDocket.getApi()))
                .globalRequestParameters(this.handleGlobalParameter(swaggerDocket.getGlobalParameters()))
                .apiDescriptionOrdering(this.test())
                .select()
                // A base package without text is treated as "not configured" (the same test
                // handleGroupData uses), so every handler is accepted instead of failing on null.
                .apis(StringUtils.hasText(basePackage)
                        ? RequestHandlerSelectors.basePackage(basePackage)
                        : RequestHandlerSelectors.any())
                .paths(this.handlePaths(swaggerDocket.getPathPatterns(), swaggerDocket.getExcludePathPatterns()))
                .build();
        // Docket.tags / Docket.servers take "first, rest..." so the arrays are split accordingly.
        Tag[] tags = this.handleTag(swaggerDocket.getTags());
        if (tags.length > 0) {
            docket.tags(tags[0], Arrays.copyOfRange(tags, 1, tags.length));
        }
        Server[] servers = this.handleServer(swaggerDocket.getServers());
        if (servers.length > 0) {
            docket.servers(servers[0], Arrays.copyOfRange(servers, 1, servers.length));
        }
        return docket;
    }

    /**
     * Returns the ordering applied to API descriptions: descending by description text.
     *
     * <p>Descriptions that are {@code null} do not break the comparison; they sort before
     * all non-null descriptions.
     *
     * @return a comparator ordering {@link ApiDescription}s by description, reversed
     */
    public Comparator<ApiDescription> test() {
        Comparator<ApiDescription> ascending =
                Comparator.comparing(ApiDescription::getDescription, Comparator.nullsLast(Comparator.naturalOrder()));
        return ascending.reversed();
    }

    /**
     * Completes a group with the settings it inherits from the root properties.
     *
     * <p>For the API info and its contact, every field that is {@code null} on the group is
     * filled from the root object. Base package, path patterns, exclude path patterns, global
     * parameters, tags and servers are taken from the root when the group's own value is
     * blank/empty. The group object is modified in place.
     *
     * @param swaggerProperties the root swagger properties to inherit from
     * @param group             the group to complete
     */
    private void handleGroupData(SwaggerProperties swaggerProperties, SwaggerGroup group) {
        // api info
        SwaggerApiInfo swaggerApiInfo = swaggerProperties.getApi();
        SwaggerApiInfo groupApiInfo = group.getApi();
        if (null == groupApiInfo) {
            group.setApi(swaggerApiInfo);
        } else if (null != swaggerApiInfo) {
            // Nothing can be inherited when the root has no api info, hence the guard above.
            for (Field field : groupApiInfo.getClass().getDeclaredFields()) {
                // The contact is merged field by field below rather than copied wholesale.
                if (field.getType().equals(SwaggerContact.class)) {
                    continue;
                }
                this.copyData(field, groupApiInfo, swaggerApiInfo);
            }
            SwaggerContact swaggerContact = swaggerApiInfo.getContact();
            SwaggerContact groupContact = groupApiInfo.getContact();
            if (null == groupContact) {
                groupApiInfo.setContact(swaggerContact);
            } else if (null != swaggerContact) {
                for (Field contactField : groupContact.getClass().getDeclaredFields()) {
                    this.copyData(contactField, groupContact, swaggerContact);
                }
            }
        }
        // basePackage
        if (!StringUtils.hasText(group.getBasePackage())) {
            group.setBasePackage(swaggerProperties.getBasePackage());
        }
        // pathPatterns
        if (ObjectUtils.isEmpty(group.getPathPatterns())) {
            group.setPathPatterns(swaggerProperties.getPathPatterns());
        }
        // excludePathPatterns
        if (ObjectUtils.isEmpty(group.getExcludePathPatterns())) {
            group.setExcludePathPatterns(swaggerProperties.getExcludePathPatterns());
        }
        // globalParameters
        if (ObjectUtils.isEmpty(group.getGlobalParameters())) {
            group.setGlobalParameters(swaggerProperties.getGlobalParameters());
        }
        // tags
        if (ObjectUtils.isEmpty(group.getTags())) {
            group.setTags(swaggerProperties.getTags());
        }
        // servers
        if (ObjectUtils.isEmpty(group.getServers())) {
            group.setServers(swaggerProperties.getServers());
        }
    }

    /**
     * Fills one field of {@code receiver} from {@code fallback} when the receiver's value is
     * {@code null}; a non-null value on the receiver is never overwritten.
     *
     * @param field    the field to inspect; it is made accessible
     * @param receiver the object whose field is filled when it is {@code null}
     * @param fallback the object supplying the value; nothing happens when it is {@code null}
     */
    private void copyData(Field field, Object receiver, Object fallback) {
        if (null == fallback) {
            return;
        }
        field.setAccessible(true);
        try {
            if (null == field.get(receiver)) {
                field.set(receiver, field.get(fallback));
            }
        } catch (IllegalAccessException ignored) {
            // Deliberately ignored: the field was made accessible above, so this is not expected
            // to occur; if it does, the field simply keeps its current value.
        }
    }

    /**
     * Converts the configured API info into a springfox {@link ApiInfo}.
     *
     * @param swaggerApiInfo the configured API info, may be {@code null}
     * @return the converted info, or {@link ApiInfo#DEFAULT} when none is configured
     */
    private ApiInfo handleApiInfo(SwaggerApiInfo swaggerApiInfo) {
        if (null == swaggerApiInfo) {
            return ApiInfo.DEFAULT;
        }
        return new ApiInfoBuilder()
                .title(swaggerApiInfo.getTitle())
                .description(swaggerApiInfo.getDescription())
                .version(swaggerApiInfo.getVersion())
                .license(swaggerApiInfo.getLicense())
                .licenseUrl(swaggerApiInfo.getLicenseUrl())
                .termsOfServiceUrl(swaggerApiInfo.getTermsOfServiceUrl())
                .contact(this.handleContact(swaggerApiInfo.getContact()))
                .build();
    }

    /**
     * Converts the configured contact into a springfox {@link Contact}.
     *
     * @param swaggerContact the configured contact, may be {@code null}
     * @return the converted contact, or {@link ApiInfo#DEFAULT_CONTACT} when none is configured
     */
    private Contact handleContact(SwaggerContact swaggerContact) {
        if (null == swaggerContact) {
            return ApiInfo.DEFAULT_CONTACT;
        }
        return new Contact(swaggerContact.getName(), swaggerContact.getUrl(), swaggerContact.getEmail());
    }

    /**
     * Converts the configured global parameters into springfox request parameters.
     *
     * @param globalParameters the configured parameters, may be {@code null} or empty
     * @return a new mutable list with one request parameter per configured parameter
     */
    private List<RequestParameter> handleGlobalParameter(List<SwaggerParameter> globalParameters) {
        if (ObjectUtils.isEmpty(globalParameters)) {
            return new ArrayList<>();
        }
        List<RequestParameter> parameterList = new ArrayList<>(globalParameters.size());
        for (SwaggerParameter parameter : globalParameters) {
            RequestParameter requestParameter = new RequestParameterBuilder()
                    .name(parameter.getName())
                    .in(parameter.getType())
                    .description(parameter.getDescription())
                    .required(parameter.isRequired())
                    .build();
            parameterList.add(requestParameter);
        }
        return parameterList;
    }

    /**
     * Converts the configured tags into springfox {@link Tag}s.
     *
     * @param swaggerTags the configured tags, may be {@code null} or empty
     * @return an array with one Tag per configured tag, in the same order; never {@code null}
     */
    private Tag[] handleTag(List<SwaggerTag> swaggerTags) {
        if (ObjectUtils.isEmpty(swaggerTags)) {
            return new Tag[]{};
        }
        Tag[] tags = new Tag[swaggerTags.size()];
        for (int i = 0; i < swaggerTags.size(); i++) {
            SwaggerTag swaggerTag = swaggerTags.get(i);
            tags[i] = new Tag(swaggerTag.getName(), swaggerTag.getDescription(), swaggerTag.getOrder());
        }
        return tags;
    }

    /**
     * Converts the configured servers into springfox {@link Server}s.
     *
     * @param swaggerServers the configured servers, may be {@code null} or empty
     * @return an array with one Server per configured server, in the same order; never {@code null}
     */
    private Server[] handleServer(List<SwaggerServer> swaggerServers) {
        if (ObjectUtils.isEmpty(swaggerServers)) {
            return new Server[]{};
        }
        Server[] servers = new Server[swaggerServers.size()];
        for (int i = 0; i < swaggerServers.size(); i++) {
            SwaggerServer swaggerServer = swaggerServers.get(i);
            servers[i] = new ServerBuilder()
                    .url(swaggerServer.getUrl())
                    .description(swaggerServer.getDescription())
                    .build();
        }
        return servers;
    }

    /**
     * Builds the path filter used while scanning for endpoints.
     *
     * <p>A path is accepted when it matches at least one include pattern (or when no include
     * pattern is configured) and matches none of the exclude patterns. Patterns are Ant-style.
     *
     * @param pathPatterns        patterns of the paths to include, may be {@code null} or empty
     * @param excludePathPatterns patterns of the paths to exclude, may be {@code null} or empty
     * @return the combined predicate
     */
    private Predicate<String> handlePaths(List<String> pathPatterns, List<String> excludePathPatterns) {
        // Paths to scan: everything when nothing is configured, otherwise the union of the patterns.
        Predicate<String> includePathPredicate;
        if (ObjectUtils.isEmpty(pathPatterns)) {
            includePathPredicate = PathSelectors.any();
        } else {
            includePathPredicate = PathSelectors.none();
            for (String pathPattern : pathPatterns) {
                includePathPredicate = includePathPredicate.or(PathSelectors.ant(pathPattern));
            }
        }
        // Paths to leave out of the scan: the union of the exclude patterns.
        Predicate<String> excludePathPredicate = PathSelectors.none();
        if (!ObjectUtils.isEmpty(excludePathPatterns)) {
            for (String excludePathPattern : excludePathPatterns) {
                excludePathPredicate = excludePathPredicate.or(PathSelectors.ant(excludePathPattern));
            }
        }
        return includePathPredicate.and(excludePathPredicate.negate());
    }
}
