package cn.myeasyai.face;

import cn.myeasyai.temple.AllTemple;
import cn.myeasyai.temple.MyTemples;
import cn.myeasyai.temple.Temple;
import cn.myeasyai.temple.TempleMessage;
import cn.myeasyai.tools.Tools;
import org.dromara.easyai.config.RZ;
import org.dromara.easyai.entity.Box;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveCenter.NerveManager;
import org.dromara.easyai.nerveEntity.SensoryNerve;
import org.dromara.easyai.pso.PSO;
import org.dromara.easyai.tools.NMS;
import org.dromara.easyai.tools.Picture;
import org.dromara.easyai.yolo.*;

import java.util.*;

/**
 * 人脸信息
 *
 * @author lidapeng
 */
public class Face {
    /**
     * Minimum confidence a decoded window prediction must exceed to become a candidate box.
     */
    private static final double MIN_BOX_TRUST = 0.49;

    /**
     * Detection configuration (window size, IoU thresholds, regularization). Public, so callers may read or modify it.
     */
    public YoloConfig yoloConfig = new YoloConfig();
    private final MatrixOperation matrixOperation = new MatrixOperation();
    /**
     * Facial-feature localization network; set by {@link #insertModel} or {@link #facePositionStudy}.
     */
    private TypeBody trTypeBody;
    private final int faceWidth;
    private final int faceHeight;
    private final int lastXSize;
    private final int lastYSize;
    private final int pictureHeight;
    private final int pictureWidth;
    private final ThreeChannelMatrix avg;
    /**
     * LBP feature of the average face (null when no average face was supplied).
     */
    private final Matrix avgLBP;
    private final double trustTh;
    private final int moveSize;

    /**
     * Creates a face locator.
     *
     * @param faceTestConfig picture size, trust threshold, PSO move size and feature-grid size
     * @param avg            average face used for the second-pass correction and the distance check;
     *                       may be null, in which case {@link #look} refuses to run
     * @throws Exception if extracting the LBP feature of {@code avg} fails
     */
    public Face(FaceTestConfig faceTestConfig, ThreeChannelMatrix avg) throws Exception {
        pictureWidth = faceTestConfig.getPictureWidth();
        trustTh = faceTestConfig.getTrustTh();
        moveSize = faceTestConfig.getMoveSize();
        pictureHeight = faceTestConfig.getPictureHeight();
        // The detection window, and the normalized face crop, are a third of the picture in each dimension.
        faceWidth = pictureWidth / 3;
        faceHeight = pictureHeight / 3;
        lastXSize = faceTestConfig.getLastXSize();
        lastYSize = faceTestConfig.getLastYSize();
        yoloConfig.setContainIouTh(0.2);
        yoloConfig.setWindowWidth(faceWidth);
        yoloConfig.setWindowHeight(faceHeight);
        yoloConfig.setRegular(0.01);
        yoloConfig.setRegularModel(RZ.L1);
        this.avg = avg;
        avgLBP = avg != null ? cropFeatureRegion(avg).getLBPMatrix() : null;
    }

    /**
     * Loads a previously exported model; only the first type model (the position network) is used.
     *
     * @param faceModel model produced by {@link #getModel()}
     * @throws Exception if rebuilding the network fails
     */
    public void insertModel(FaceModel faceModel) throws Exception {
        YoloModel yoloModel = faceModel.getYoloModel();
        List<TypeModel> typeModels = yoloModel.getTypeModels();
        trTypeBody = Tools.getTypeBody(typeModels.get(0), yoloConfig);
    }

    /**
     * Exports the current position network as a {@link FaceModel}.
     *
     * @return a model containing exactly one type model
     * @throws IllegalStateException if no model has been loaded or trained yet
     * @throws Exception             if serializing the network fails
     */
    public FaceModel getModel() throws Exception {
        requireModel();
        FaceModel faceModel = new FaceModel();
        YoloModel yoloModel = new YoloModel();
        faceModel.setYoloModel(yoloModel);
        List<TypeModel> typeModels = new ArrayList<>();
        typeModels.add(getTypeModel(trTypeBody));
        yoloModel.setTypeModels(typeModels);
        return faceModel;
    }

    /**
     * Copies the identifying data, size limits and conv model of a {@link TypeBody} into a {@link TypeModel}.
     */
    private TypeModel getTypeModel(TypeBody typeBody) throws Exception {
        TypeModel typeModel = new TypeModel();
        typeModel.setTypeID(typeBody.getTypeID());
        typeModel.setMappingID(typeBody.getMappingID());
        typeModel.setMinHeight(typeBody.getMinHeight());
        typeModel.setMinWidth(typeBody.getMinWidth());
        typeModel.setMaxWidth(typeBody.getMaxWidth());
        typeModel.setMaxHeight(typeBody.getMaxHeight());
        typeModel.setPositionModel(typeBody.getPositonNerveManager().getConvModel());
        return typeModel;
    }

    /**
     * Decodes a window prediction into an absolute box.
     *
     * <p>The offsets are inverted with the same normalization used when training
     * ({@code dist = (windowOrigin - center) / pictureHeight} in {@link #containSample}).
     *
     * @return the box, or null if it falls outside {@code [0, maxX) x [0, maxY)} or the trust
     * does not exceed {@link #MIN_BOX_TRUST}
     */
    private Box getBox(int i, int j, int maxX, int maxY, TypeBody typeBody, double distX, double distY
            , double pWidth, double pHeight, double trust) {
        double centerX = i - distX * pictureHeight;
        double centerY = j - distY * pictureHeight;
        int width = (int) typeBody.getRealWidth(pWidth);
        int height = (int) typeBody.getRealHeight(pHeight);
        // Integer halving is intentional: it matches the original decoding.
        int realX = (int) (centerX - height / 2);
        int realY = (int) (centerY - width / 2);
        boolean inside = realX >= 0 && realY >= 0 && realX + height <= maxX && realY + width <= maxY;
        if (!inside || !(trust > MIN_BOX_TRUST)) { // written as !(a > b) so a NaN trust is still rejected
            return null;
        }
        Box box = new Box();
        box.setX(realX);
        box.setY(realY);
        box.setxSize(height);
        box.setySize(width);
        box.setConfidence(trust);
        box.setTypeID(typeBody.getTypeID());
        return box;
    }

    /**
     * Normalizes a picture to {@code pictureHeight} rows ({@code getX()}) by {@code pictureWidth} columns ({@code getY()}).
     *
     * @param threeChannelMatrix input picture
     * @return the input itself when it already has that size, otherwise a scaled and padded copy
     * @throws Exception if scaling fails
     */
    public ThreeChannelMatrix uniform(ThreeChannelMatrix threeChannelMatrix) throws Exception {
        if (threeChannelMatrix.getX() == pictureHeight && threeChannelMatrix.getY() == pictureWidth) {
            return threeChannelMatrix;
        }
        return scaleAndPad(threeChannelMatrix, pictureHeight, pictureWidth);
    }

    /**
     * Locates the face in a picture and checks it against the average face.
     *
     * <p>Error codes: 0 = face found (FaceMessage attached); 1 = no candidate box; 3 = a box was
     * found but its LBP distance to the average face is not below {@code trustTh}.
     *
     * @param face          picture to analyse
     * @param eventID       event id forwarded to the position network
     * @param secondExplore PSO population size / iteration count for the second-pass correction
     * @return the result; never null
     * @throws IllegalStateException if no model has been loaded or trained
     * @throws Exception             if no average face was configured, or the network fails
     */
    public ErrorMessage look(ThreeChannelMatrix face, long eventID, int secondExplore) throws Exception {
        if (avg == null) {
            throw new Exception("没有配置平均脸");
        }
        requireModel();
        ThreeChannelMatrix th = uniform(face); // scale the picture to the standard size
        List<Box> trBoxes = detectCandidates(th, eventID);
        Box box = null;
        if (!trBoxes.isEmpty()) {
            NMS nms = new NMS(yoloConfig.getIouTh());
            box = getTrustBox(nms.start(trBoxes));
        }
        ErrorMessage errorMessage = new ErrorMessage();
        if (box == null) {
            errorMessage.setErrorCode(1);
            errorMessage.setErrorMessage("找不到人脸，没有合法的人脸");
            return errorMessage;
        }
        FaceMessage myFaceMessage = secondCorrect(box, th, secondExplore); // second-pass localization
        if (myFaceMessage.getDist() < trustTh) {
            errorMessage.setErrorCode(0);
            errorMessage.setErrorMessage("正常");
            errorMessage.setFaceMessage(myFaceMessage);
        } else {
            errorMessage.setErrorCode(3);
            errorMessage.setErrorMessage("找不到人脸,照片曝光过高，或者过暗，或者阴阳脸");
            errorMessage.setFaceMessage(null);
        }
        return errorMessage;
    }

    /**
     * Slides a non-overlapping detection window over the normalized picture and collects the boxes the
     * position network accepts.
     */
    private List<Box> detectCandidates(ThreeChannelMatrix th, long eventID) throws Exception {
        int x = th.getX();
        int y = th.getY();
        int width = yoloConfig.getWindowWidth();
        int height = yoloConfig.getWindowHeight();
        List<SensoryNerve> sensoryNerves = trTypeBody.getPositonNerveManager().getSensoryNerves();
        List<Box> boxes = new ArrayList<>();
        for (int i = 0; i <= x - height; i += height) {
            for (int j = 0; j <= y - width; j += width) {
                PositionBack positionBack = new PositionBack();
                ThreeChannelMatrix window = th.cutChannel(i, j, height, width);
                Tools.study(eventID, sensoryNerves, window, false, null, positionBack);
                Box candidate = getBox(i, j, x, y, trTypeBody, positionBack.getDistX(), positionBack.getDistY()
                        , positionBack.getWidth(), positionBack.getHeight(), positionBack.getTrust());
                if (candidate != null) {
                    boxes.add(candidate);
                }
            }
        }
        return boxes;
    }

    /**
     * Returns the box with the highest confidence, or null if the list is empty or no confidence is above 0.
     */
    private Box getTrustBox(List<Box> myTrBoxes) {
        double maxTrust = 0;
        Box myBox = null;
        for (Box box : myTrBoxes) {
            double trust = box.getConfidence();
            if (trust > maxTrust) {
                maxTrust = trust;
                myBox = box;
            }
        }
        return myBox;
    }

    /**
     * Refines the coarse box position with PSO and measures the refined crop against the average face.
     *
     * <p>The top-left corner is searched within {@code moveSize} of the coarse box, clamped so the crop stays
     * inside {@code pic}. The objective is {@link SecondPosition}, built from the average-face feature
     * ({@link #getMatrixE}); presumably it scores how well a crop matches the average face — TODO confirm.
     * The best crop is normalized to {@code faceHeight x faceWidth}, and its feature band is kept. The LBP
     * Euclidean distance of that band to {@link #avgLBP} is stored in the result.
     */
    private FaceMessage secondCorrect(Box box, ThreeChannelMatrix pic, int secondExplore) throws Exception {
        int bx = pic.getX();
        int by = pic.getY();
        int x = box.getX();
        int y = box.getY();
        int xSize = box.getxSize();
        int ySize = box.getySize();
        int minX = Math.max(x - moveSize, 0);
        int minY = Math.max(y - moveSize, 0);
        int maxX = x + moveSize;
        if (maxX >= bx - xSize) {
            maxX = bx - xSize - 1;
        }
        int maxY = y + moveSize;
        if (maxY >= by - ySize) {
            maxY = by - ySize - 1;
        }
        // The conservative "-1" clamp above can push an upper bound below its lower bound when the box
        // touches the picture edge (e.g. the face fills the picture, or moveSize == 0). For moveSize >= 0 the
        // lower bound never exceeds the box corner, which getBox keeps inside the picture, so collapsing the
        // range onto it still yields a valid crop instead of handing PSO inverted borders.
        maxX = Math.max(maxX, minX);
        maxY = Math.max(maxY, minY);
        double[] minBorder = {minX, minY};
        double[] maxBorder = {maxX, maxY};
        Matrix avgFeature = getMatrixE(this.avg);
        SecondPosition secondPosition = new SecondPosition(avgFeature, pic, this, xSize, ySize);
        PSO pso = new PSO(2, minBorder, maxBorder, secondExplore, secondExplore, secondPosition, 0.5, 2, 2,
                false, moveSize / 2D, 2);
        pso.start();
        double[] bestPosition = pso.getAllBest();
        int realX = (int) bestPosition[0];
        int realY = (int) bestPosition[1];
        ThreeChannelMatrix faceRegion = pic.cutChannel(realX, realY, xSize, ySize);
        ThreeChannelMatrix featureRegion = cropFeatureRegion(scaleAndPad(faceRegion, faceHeight, faceWidth));
        FaceMessage faceMessage = new FaceMessage();
        faceMessage.setChannel(featureRegion);
        Matrix lbpMatrix = featureRegion.getLBPMatrix();
        faceMessage.setDist(matrixOperation.getEDistByMatrix(lbpMatrix, avgLBP));
        faceMessage.setFeature(lbpMatrix);
        return faceMessage;
    }

    /**
     * Returns a randomly shuffled copy of the samples; the argument is left unmodified.
     */
    private List<TempleMessage> anySort(List<TempleMessage> sentences) {
        List<TempleMessage> shuffled = new ArrayList<>(sentences);
        Collections.shuffle(shuffled);
        return shuffled;
    }

    /**
     * Normalizes a face crop to {@code faceHeight x faceWidth} and returns its coarse color feature.
     *
     * @param picture face crop
     * @return the result of {@link #getMatrixE} on the normalized crop
     * @throws Exception if scaling fails
     */
    public Matrix lookFace(ThreeChannelMatrix picture) throws Exception {
        return getMatrixE(scaleAndPad(picture, faceHeight, faceWidth));
    }

    /**
     * Builds a {@code lastXSize x lastYSize} grid of summed R+G+B cell averages and applies softmax to it.
     *
     * <p>The cell size is {@code dimension / gridSize + 1}. The {@code +1} keeps every cell index inside the
     * grid, but trailing cells may stay unfilled: for example, x = 100 and lastXSize = 10 give a step of 11,
     * so only rows 0..8 are written. Unfilled cells keep the initial value of a new {@link Matrix}
     * (presumably 0 — confirm).
     *
     * @param subMatrix three-channel image
     * @return softmax of the cell grid
     * @throws Exception if matrix operations fail
     */
    public Matrix getMatrixE(ThreeChannelMatrix subMatrix) throws Exception {
        int x = subMatrix.getX();
        int y = subMatrix.getY();
        int XStep = x / lastXSize + 1;
        int YStep = y / lastYSize + 1;
        Matrix matrixR = subMatrix.getMatrixR();
        Matrix matrixG = subMatrix.getMatrixG();
        Matrix matrixB = subMatrix.getMatrixB();
        Matrix myMatrix = new Matrix(lastXSize, lastYSize);
        for (int i = 0; i <= x - XStep; i += XStep) {
            for (int j = 0; j <= y - YStep; j += YStep) {
                double tr = matrixR.getSonOfMatrix(i, j, XStep, YStep).getAVG();
                double tg = matrixG.getSonOfMatrix(i, j, XStep, YStep).getAVG();
                double tb = matrixB.getSonOfMatrix(i, j, XStep, YStep).getAVG();
                double value = tr + tg + tb;
                int tx = i / XStep;
                int ty = j / YStep;
                myMatrix.setNub(tx, ty, value);
            }
        }
        return matrixOperation.softMaxByMatrix(myMatrix);
    }

    /**
     * Trains the facial-feature position network from annotated samples.
     *
     * <p>The network built by {@code Temple.readALlTemple} replaces the current one. It is then trained on every
     * sample in random order, and progress is printed to stdout.
     *
     * @param allTemples annotation records
     * @throws Exception if reading a picture or training fails
     */
    public void facePositionStudy(List<AllTemple> allTemples) throws Exception {
        Temple temple = new Temple();
        MyTemples templeMessages = temple.readALlTemple(allTemples, yoloConfig);
        List<TempleMessage> temples = anySort(templeMessages.getTempleMessages()); // annotation samples
        TypeBody positionBody = templeMessages.getTrTypeBody(); // facial-feature position network
        this.trTypeBody = positionBody;
        NMS nms = new NMS(yoloConfig.getContainIouTh());
        double size = temples.size();
        int index = 0;
        for (TempleMessage templeMessage : temples) {
            index++;
            double jin = index / size * 100;
            System.out.println("训练进度=====================" + jin + "%");
            String fileName = templeMessage.getFileName();
            YoloBody trYoloBody = templeMessage.getTrYoloBody(); // annotated face box
            ThreeChannelMatrix threeChannelMatrix = Picture.getThreeMatrix(fileName, true);
            studyPicture(threeChannelMatrix, trYoloBody, positionBody, nms);
        }
        System.out.println("训练完毕=================");
    }

    /**
     * Trains the position network on one picture.
     *
     * <p>A window of the configured size slides with half-window overlap. Every window whose overlap with
     * the annotation passes {@link #containSample} is fitted with the encoded targets: keys 1..5 hold
     * distX, distY, width, height and trust.
     */
    private void studyPicture(ThreeChannelMatrix picture, YoloBody yoloBody, TypeBody typeBody, NMS nms) throws Exception {
        int x = picture.getX();
        int y = picture.getY();
        int width = yoloConfig.getWindowWidth();
        int height = yoloConfig.getWindowHeight();
        // Half-window stride; at least 1 so that windows narrower than 2 px cannot stall the loop.
        int widthStep = Math.max(1, (int) (width * 0.5));
        int heightStep = Math.max(1, (int) (height * 0.5));
        NerveManager trPositionManager = typeBody.getPositonNerveManager();
        List<SensoryNerve> sensoryNerves = trPositionManager.getSensoryNerves();
        for (int i = 0; i <= x - height; i += heightStep) {
            for (int j = 0; j <= y - width; j += widthStep) {
                Box testBox = new Box();
                testBox.setX(i);
                testBox.setY(j);
                testBox.setxSize(height);
                testBox.setySize(width);
                YoloMessage yoloMessage = containSample(nms, yoloBody, testBox, typeBody);
                if (yoloMessage != null) { // only windows that qualify are cut and fitted
                    ThreeChannelMatrix small = picture.cutChannel(i, j, height, width);
                    Map<Integer, Double> positionE = new HashMap<>();
                    positionE.put(1, yoloMessage.getDistX());
                    positionE.put(2, yoloMessage.getDistY());
                    positionE.put(3, yoloMessage.getWidth());
                    positionE.put(4, yoloMessage.getHeight());
                    positionE.put(5, yoloMessage.getTrust());
                    Tools.study(1, sensoryNerves, small, true, positionE, null);
                }
            }
        }
    }

    /**
     * Encodes the training target for one window.
     *
     * <p>The annotation's x/y are swapped onto the matrix axes: Box.x is the row and takes YoloBody.y.
     * Offsets are normalized by {@code pictureHeight}. Trust is 1 when the annotation's center lies inside
     * the window (edges inclusive), else 0.
     *
     * @return the target, or null if the overlap ratio does not exceed {@code containIouTh}
     */
    private YoloMessage containSample(NMS nms, YoloBody yoloBody, Box testBox, TypeBody typeBody) {
        Box box = new Box();
        box.setX(yoloBody.getY());
        box.setY(yoloBody.getX());
        box.setxSize(yoloBody.getHeight());
        box.setySize(yoloBody.getWidth());
        double iou = nms.getSRatio(testBox, box, false);
        if (!(iou > yoloConfig.getContainIouTh())) {
            return null;
        }
        int centerX = box.getX() + box.getxSize() / 2;
        int centerY = box.getY() + box.getySize() / 2;
        double distX = (double) (testBox.getX() - centerX) / pictureHeight;
        double distY = (double) (testBox.getY() - centerY) / pictureHeight;
        double height = typeBody.getOneHeight(box.getxSize());
        double width = typeBody.getOneWidth(box.getySize());
        boolean centerInside = centerX >= testBox.getX() && centerX <= (testBox.getX() + testBox.getxSize())
                && centerY >= testBox.getY() && centerY <= (testBox.getY() + testBox.getySize());
        YoloMessage yoloMessage = new YoloMessage();
        yoloMessage.setWidth(width);
        yoloMessage.setHeight(height);
        yoloMessage.setDistX(distX);
        yoloMessage.setDistY(distY);
        yoloMessage.setTrust(centerInside ? 1 : 0);
        yoloMessage.setMappingID(typeBody.getMappingID());
        return yoloMessage;
    }

    /**
     * Throws if no position network has been loaded or trained yet.
     */
    private void requireModel() {
        if (trTypeBody == null) {
            throw new IllegalStateException("人脸定位模型未初始化，请先调用insertModel或facePositionStudy");
        }
    }

    /**
     * Scales a picture with {@code scale(true, targetWidth)} and pads it to {@code targetHeight x targetWidth}
     * via {@code Tools.fillColor}. Falls back to the scaled picture when fillColor returns null.
     */
    private static ThreeChannelMatrix scaleAndPad(ThreeChannelMatrix source, int targetHeight, int targetWidth)
            throws Exception {
        ThreeChannelMatrix scaled = source.scale(true, targetWidth);
        ThreeChannelMatrix padded = Tools.fillColor(scaled, targetHeight, targetWidth);
        return padded != null ? padded : scaled;
    }

    /**
     * Cuts the band used for LBP comparison: rows starting at 10% of the height, spanning 60% of the height,
     * across all columns.
     */
    private static ThreeChannelMatrix cropFeatureRegion(ThreeChannelMatrix matrix) throws Exception {
        int rows = (int) (matrix.getX() * 0.6);
        return matrix.cutChannel((int) (matrix.getX() * 0.1), 0, rows, matrix.getY());
    }
}
