package com.walker.support.milvus.engine;

import com.walker.support.milvus.DataSet;
import com.walker.support.milvus.FieldType;
import com.walker.support.milvus.OperateService;
import com.walker.support.milvus.OutData;
import com.walker.support.milvus.Query;
import com.walker.support.milvus.Table;
import com.walker.support.milvus.util.FieldTypeUtils;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.ConnectParam;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.RpcStatus;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.collection.ReleaseCollectionParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DropIndexParam;
import io.milvus.response.SearchResultsWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Default {@link OperateService} implementation backed by a single {@link MilvusServiceClient}.
 *
 * <p>A "table" in this API is a Milvus collection. Every operation other than
 * {@link #connect(String, int)} and {@link #close()} requires a prior successful
 * {@code connect}; otherwise an {@link IllegalStateException} is thrown.
 *
 * <p>NOTE(review): the {@code client} field is not synchronized; confirm whether instances
 * are shared across threads.
 */
public class DefaultOperateService implements OperateService {

    protected final transient Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Suffix appended to a field name to derive the index name used by create/drop index. */
    private static final String INDEX_NAME_SUFFIX = "_index";

    /** Active client, or {@code null} when not connected or after {@link #close()}. */
    private MilvusServiceClient client = null;

    /**
     * Creates a new Milvus client for the given host/port, closing any existing client first.
     *
     * @param ip   server host
     * @param port server port
     * @return {@code true} if the client object was created, {@code false} if construction threw
     */
    @Override
    public boolean connect(String ip, int port) {
        if(this.client != null){
            logger.warn("MilvusServiceClient在运行，正在停止，准备创建新对象: {}, {}", ip, port);
            this.client.close();
            // Forget the closed client so a failed reconnect below leaves this service
            // "not connected" instead of holding a dead client that passes checkConnection().
            this.client = null;
        }
        try {
            this.client = new MilvusServiceClient(ConnectParam.newBuilder()
                    .withHost(ip)
                    .withPort(port)
                    .build());
            return true;
        } catch (Exception ex){
            logger.error("创建 MilvusServiceClient 错误：{}", ip, ex);
            return false;
        }
    }

    /**
     * Closes the current client, if any. Subsequent operations require a new {@code connect}.
     */
    @Override
    public void close() {
        if(this.client != null){
            this.client.close();
            this.client = null;
        }
    }

    /**
     * Creates a collection from the table definition.
     *
     * @param table table definition; must have a non-empty name and at least one field
     * @return {@code true} if Milvus reported success; {@code false} on invalid input or failure
     */
    @Override
    public boolean createTable(Table table) {
        this.checkConnection();
        if(table == null){
            logger.error("table 必须提供");
            return false;
        }
        List<FieldType> fieldList = table.getFieldTypes();
        if(fieldList == null || fieldList.isEmpty()){
            logger.error("未找到任何字段信息");
            return false;
        }
        String tableName = table.getCollectionName();
        if(isEmpty(tableName)){
            logger.error("表名必须提供：tableName");
            return false;
        }

        try{
            List<io.milvus.param.collection.FieldType> milvusFieldList = new ArrayList<>(fieldList.size());
            for(FieldType ft : fieldList){
                milvusFieldList.add(FieldTypeUtils.toMilvusFieldType(ft, table.getDimension()));
            }

            CreateCollectionParam param = CreateCollectionParam.newBuilder()
                    .withCollectionName(tableName)
                    .withDescription(table.getDescription())
                    .withShardsNum(table.getShardsNum())
                    .withFieldTypes(milvusFieldList)
                    .build();
            return this.checkStatusR(this.client.createCollection(param));
        } catch (Exception ex){
            logger.error("创建向量表失败：{}", tableName, ex);
            return false;
        }
    }

    /**
     * Drops a collection. Failures are logged; the interface gives no way to report them.
     *
     * @param tableName collection name; a null or empty name is logged and ignored
     */
    @Override
    public void dropTable(String tableName){
        this.checkConnection();
        if(isEmpty(tableName)){
            logger.error("表名必须提供：tableName");
            return;
        }
        R<RpcStatus> statusR = this.client.dropCollection(DropCollectionParam.newBuilder()
                .withCollectionName(tableName)
                .build());
        this.checkStatusR(statusR);
    }

    /**
     * Inserts a column-oriented data set (field name to column values) into a collection.
     *
     * @param dataSet data to insert; the partition is used only when non-empty
     * @return {@code true} if Milvus reported success; {@code false} on invalid input or failure
     */
    @Override
    public boolean insertDataSet(DataSet dataSet){
        this.checkConnection();
        if(dataSet == null){
            return false;
        }
        String tableName = dataSet.getTableName();
        if(isEmpty(tableName)){
            logger.error("表名必须提供：tableName");
            return false;
        }
        Map<String, List<?>> fields = dataSet.getFields();
        if(fields == null || fields.isEmpty()){
            logger.error("数据集合必须提供：fields");
            return false;
        }

        List<InsertParam.Field> fieldList = new ArrayList<>(fields.size());
        for(Map.Entry<String, List<?>> entry : fields.entrySet()){
            fieldList.add(new InsertParam.Field(entry.getKey(), entry.getValue()));
        }

        InsertParam.Builder builder = InsertParam.newBuilder()
                .withCollectionName(tableName)
                .withFields(fieldList);
        String partitionName = dataSet.getPartitionName();
        if(!isEmpty(partitionName)){
            builder.withPartitionName(partitionName);
        }

        R<MutationResult> statusR = this.client.insert(builder.build());
        boolean success = this.checkStatusR(statusR);
        if(success){
            // Success is not an error condition; the original logged it at ERROR level.
            logger.debug("insert 返回值：{}", statusR.getStatus());
        }
        return success;
    }

    /**
     * Creates an index named {@code fieldName + "_index"} on a vector field, using the L2 metric.
     * The request is sent with sync mode disabled.
     *
     * @param tableName  collection name
     * @param fieldName  field to index
     * @param indexType  one of IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, ANNOY, FLAT
     * @param indexParam extra index parameters passed to {@code withExtraParam}
     * @return {@code true} if Milvus reported success
     * @throws IllegalArgumentException if {@code indexType} is null or unsupported
     */
    @Override
    public boolean createIndex(String tableName, String fieldName, String indexType, String indexParam){
        this.checkConnection();
        IndexType resolvedType = toIndexType(indexType);

        CreateIndexParam param = CreateIndexParam.newBuilder()
                .withCollectionName(tableName)
                .withFieldName(fieldName)
                .withIndexName(fieldName + INDEX_NAME_SUFFIX)
                .withIndexType(resolvedType)
                .withMetricType(MetricType.L2)
                .withExtraParam(indexParam)
                .withSyncMode(false)
                .build();
        return checkStatusR(this.client.createIndex(param));
    }

    /**
     * Drops the index created by {@link #createIndex} for the given field.
     *
     * @return {@code true} if Milvus reported success
     */
    @Override
    public boolean dropIndex(String tableName, String fieldName){
        this.checkConnection();
        R<RpcStatus> statusR = this.client.dropIndex(DropIndexParam.newBuilder()
                .withCollectionName(tableName)
                .withIndexName(fieldName + INDEX_NAME_SUFFIX)
                .build());
        return checkStatusR(statusR);
    }

    /**
     * Loads a collection so it can be searched.
     *
     * @return {@code true} if Milvus reported success
     */
    @Override
    public boolean prepareSearch(String tableName){
        this.checkConnection();
        R<RpcStatus> statusR = this.client.loadCollection(LoadCollectionParam.newBuilder()
                .withCollectionName(tableName)
                .build());
        return checkStatusR(statusR);
    }

    /**
     * Runs a vector similarity search and collects IDs, scores and requested output fields
     * for the first search vector.
     *
     * @param query search definition; vectors, vector field name and output fields are required.
     *              The metric defaults to L2 when none is given.
     * @return the collected results, or {@code null} on invalid input or a failed search
     */
    @Override
    public OutData searchVector(Query query){
        this.checkConnection();
        if(query == null){
            logger.error("query 必须提供");
            return null;
        }
        List<List<Float>> searchVectors = query.getSearchVectors();
        if(searchVectors == null || searchVectors.isEmpty()){
            logger.error("未设置搜索向量条件：search_vectors");
            return null;
        }
        String vectorField = query.getVectorName();
        if(isEmpty(vectorField)){
            logger.error("未设置搜索字段名称：vectorField");
            return null;
        }
        List<String> outputFieldList = query.getOutFieldList();
        if(outputFieldList == null || outputFieldList.isEmpty()){
            logger.error("未设置输出字段名称：OutFieldList");
            return null;
        }
        // Previously a non-empty metric left this null, so any explicitly requested metric broke the search.
        MetricType metricType = resolveMetricType(query.getMetricType());
        if(metricType == null){
            return null;
        }

        SearchParam searchParam = SearchParam.newBuilder()
                .withCollectionName(query.getTableName())
                .withMetricType(metricType)
                .withOutFields(outputFieldList)
                .withTopK(query.getTopK())
                .withVectors(searchVectors)
                .withVectorFieldName(vectorField)
                .withParams(query.getSearchParam())
                .build();

        R<SearchResults> respSearch = this.client.search(searchParam);
        // Check the RPC status, not just null: on failure getData() is not usable.
        if(!checkStatusR(respSearch) || respSearch.getData() == null){
            logger.warn("未搜索到相似结果对象：{}", query);
            return null;
        }
        SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(respSearch.getData().getResults());

        // Originally intended (2022-08-26): drop results whose score is too low.
        // NOTE(review): no score threshold is applied here.
        OutData outData = new OutData();

        List<SearchResultsWrapper.IDScore> scoreList = wrapperSearch.getIDScore(0);
        logger.debug("search IDScore: {}", scoreList);
        if(scoreList != null){
            for(SearchResultsWrapper.IDScore idScore : scoreList){
                outData.addIdScore(idScore.getLongID(), idScore.getScore());
            }
        }
        for(String outField : outputFieldList){
            // Field values are assumed to be Long IDs, as in the original casts.
            @SuppressWarnings("unchecked")
            List<Long> values = (List<Long>) wrapperSearch.getFieldData(outField, 0);
            if(outField.equals("id")){
                outData.setKeyList(values);
            } else {
                // NOTE(review): with several non-"id" output fields, the last one wins.
                outData.setBusinessIdList(values);
            }
        }
        return outData;
    }

    /**
     * Releases a collection previously loaded by {@link #prepareSearch}. Failures are logged.
     */
    @Override
    public void releaseSearch(String tableName){
        this.checkConnection();
        R<RpcStatus> statusR = this.client.releaseCollection(ReleaseCollectionParam.newBuilder()
                .withCollectionName(tableName)
                .build());
        this.checkStatusR(statusR);
    }

    /**
     * Ensures a client exists.
     *
     * @throws IllegalStateException if {@link #connect} has not succeeded or {@link #close} was called
     */
    private void checkConnection(){
        if(this.client == null){
            throw new IllegalStateException("服务未连接，请先连接 milvus 服务");
        }
    }

    /**
     * Returns whether a Milvus response reports success, logging the failure message otherwise.
     *
     * @param statusR response returned by the client; may be {@code null}
     * @return {@code true} only when the status code equals {@code R.Status.Success}
     */
    private boolean checkStatusR(R<?> statusR){
        if(statusR == null){
            logger.error("Milvus 调用无返回结果");
            return false;
        }
        if(statusR.getStatus().intValue() == R.Status.Success.getCode()){
            return true;
        }
        logger.error("Milvus 调用失败：status={}, message={}", statusR.getStatus(), statusR.getMessage());
        return false;
    }

    /**
     * Maps a supported index type name to the Milvus {@link IndexType}.
     *
     * @throws IllegalArgumentException if the name is null or not one of the supported types
     */
    private static IndexType toIndexType(String indexType){
        if(indexType == null){
            throw new IllegalArgumentException("暂不支持其他索引类型：" + indexType);
        }
        switch (indexType){
            case "IVF_FLAT":
                return IndexType.IVF_FLAT;
            case "IVF_SQ8":
                return IndexType.IVF_SQ8;
            case "IVF_PQ":
                return IndexType.IVF_PQ;
            case "HNSW":
                return IndexType.HNSW;
            case "ANNOY":
                return IndexType.ANNOY;
            case "FLAT":
                return IndexType.FLAT;
            default:
                throw new IllegalArgumentException("暂不支持其他索引类型：" + indexType);
        }
    }

    /**
     * Resolves the metric requested by a query, defaulting to L2 when none is given.
     *
     * @param requested value of {@code Query#getMetricType()}; converted with {@code toString()}
     *                  so the lookup works whether it is a String or an enum
     * @return the metric, or {@code null} (after logging) when the name is not a {@link MetricType}
     */
    private MetricType resolveMetricType(Object requested){
        String name = requested == null ? "" : requested.toString().trim();
        if(name.isEmpty()){
            return MetricType.L2;
        }
        try {
            return MetricType.valueOf(name.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException ex){
            logger.error("不支持的度量类型：{}", name);
            return null;
        }
    }

    /** Returns {@code true} for a null or zero-length string. */
    private static boolean isEmpty(String value){
        return value == null || value.isEmpty();
    }
}
