/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package app.pivo.android.plussdk;

import android.content.Context;
import android.graphics.Color;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Cap;
import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.PointF;
import android.graphics.Rect;
import android.graphics.RectF;
import android.os.SystemClock;
import android.util.Log;
import android.util.Pair;

import java.util.LinkedList;
import java.util.List;

import app.pivo.android.plussdk.Classifier.*;


/**
 * Keeps the most recent detection results and feeds camera frames to an {@link ActionTracker}.
 *
 * <p>What this class does, as far as this file shows:
 * <ul>
 *   <li>{@link #trackResults(List)} stores every detection whose box is at least
 *       {@code MIN_SIZE} wide and high;</li>
 *   <li>{@link #draw()} recomputes the frame-to-canvas transform and returns the user-selected
 *       target scaled by the current multiplier;</li>
 *   <li>{@link #track} seeds the {@link ActionTracker} with a region of interest and then pushes
 *       frames to it until a re-detection interval has elapsed;</li>
 *   <li>{@link #handleDetection} picks the proposed detection that overlaps a reference
 *       detection the most.</li>
 * </ul>
 *
 * <p>Only {@link #setFrameConfiguration} and {@link #trackResults} are {@code synchronized};
 * all other methods read and write the shared fields without locking.
 */
class MultiBoxTracker {

    private static final String TAG = MultiBoxTracker.class.getSimpleName();

    /**
     * Threshold used by {@link #handleDetection}: a candidate is only accepted when its
     * intersection-over-union with the reference is strictly greater than this value.
     */
    private static final float MAX_OVERLAP = 0.2f;

    /** Detections narrower or shorter than this (in frame coordinates) are not stored. */
    private static final float MIN_SIZE = 50.0f;

    /** Title that selects the shorter re-detection intervals in {@link #getDetectionTimeInterval}. */
    private static final String PERSON_TITLE = "person";

    private static final int[] COLORS = {
            Color.BLUE,
            Color.RED
    };

    private final Context mContext;
    private final ActionTracker mTracker;

    private final List<Pair<Float, RectF>> screenRects = new LinkedList<>();
    private final List<TrackedRecognition> trackedObjects = new LinkedList<>();
    // Configured in the constructor; nothing in this class draws with it at the moment.
    private final Paint boxPaint = new Paint();

    /** Set by {@link #draw()}; {@code null} until the first successful call. */
    private Matrix frameToCanvasMatrix;
    private int frameWidth;
    private int frameHeight;
    private int sensorOrientation;

    // Layout size as stored by setLayout(); note that setLayout() swaps its two arguments.
    private int cHeight;
    private int cWidth;

    private boolean frontCamera;

    /** Frame-to-canvas scale factor; 1.0 until {@link #draw()} computes a real value. */
    private float multiplier = 1.0f;

    // Not read or written anywhere in this class.
    private PointF center;

    private Rect targetToTrack;

    /** True between a successful {@code initCamera} call and the end of the detection interval. */
    private boolean trackingProcessing = false;
    private boolean isInitialized = false;

    /** {@code SystemClock.uptimeMillis()} of the previous {@link #track} call, 0 before the first. */
    private long mLastTime = 0;
    /** Milliseconds accumulated between {@link #track} calls since the last reset. */
    private long time = 0;

    /**
     * @param context       passed to {@code RamClassifier.classify} when choosing the
     *                      re-detection interval
     * @param objectTracker tracker that receives frames; when {@code null}, {@link #track}
     *                      never initializes and therefore never forwards a frame
     */
    public MultiBoxTracker(Context context, ActionTracker objectTracker) {
        mContext = context;
        mTracker = objectTracker;
        boxPaint.setStyle(Style.STROKE);
        boxPaint.setStrokeWidth(5.0f);
        boxPaint.setStrokeCap(Cap.ROUND);
        boxPaint.setStrokeJoin(Join.ROUND);
        boxPaint.setStrokeMiter(100);
    }

    /** @return {@code true} while the tracker is initialized and its interval has not expired */
    public boolean isTrackingProcessing() {
        return trackingProcessing;
    }

    /**
     * Sets the rectangle that {@link #draw()} scales and returns.
     *
     * @param target target rectangle, or {@code null} to make {@link #draw()} return {@code null}
     */
    public void setTargetToTrack(Rect target) {
        this.targetToTrack = target;
    }

    /**
     * @param frontCamera when {@code true}, {@link #track} mirrors the region of interest and
     *                    passes {@code 1} instead of {@code 0} as the camera flag to the tracker
     */
    public void setFrontCamera(boolean frontCamera) {
        this.frontCamera = frontCamera;
    }

    /**
     * Stores the camera frame geometry used by {@link #draw()}.
     *
     * @param width             frame width
     * @param height            frame height
     * @param sensorOrientation rotation in degrees; {@link #draw()} treats the frame as rotated
     *                          when {@code sensorOrientation % 180 == 90}
     */
    public synchronized void setFrameConfiguration(
            final int width, final int height, final int sensorOrientation) {
        frameWidth = width;
        frameHeight = height;
        this.sensorOrientation = sensorOrientation;
    }

    /**
     * Replaces the stored detections with the given results.
     *
     * @param results latest detections; {@code null} is treated like an empty list
     */
    public synchronized void trackResults(final List<Classifier.Recognition> results) {
        processResults(results);
    }

    private Matrix getFrameToCanvasMatrix() {
        return frameToCanvasMatrix;
    }

    /**
     * Stores the layout size used for scaling and front-camera mirroring.
     *
     * <p>NOTE(review): the arguments are stored swapped ({@code cWidth = height},
     * {@code cHeight = width}). This is kept exactly as it was; presumably the caller reports
     * the size in a different orientation than the canvas - confirm against the caller.
     */
    public void setLayout(int width, int height) {
        this.cWidth = height;
        this.cHeight = width;
    }

    /** @return the scale factor computed by the last successful {@link #draw()}, initially 1.0 */
    public float getMultiplier() {
        return multiplier;
    }

    /**
     * Recomputes the frame-to-canvas transform and returns the selected target in scaled
     * coordinates.
     *
     * <p>Nothing is recomputed (and {@code null} is returned) when no detections are stored, or
     * when the frame configuration or the layout size has not been set to positive values.
     *
     * @return the target set via {@link #setTargetToTrack} with every edge multiplied by
     *         {@link #getMultiplier()}, or {@code null} when there is nothing to return
     */
    public Rect draw() {
        if (trackedObjects.isEmpty()) {
            return null;
        }
        // Dividing by an unset (zero) frame size would yield an Infinity/NaN multiplier, and a
        // zero layout size a zero multiplier that track() later divides by.
        if (frameWidth <= 0 || frameHeight <= 0 || cWidth <= 0 || cHeight <= 0) {
            return null;
        }

        final boolean rotated = sensorOrientation % 180 == 90;
        // Frame dimensions as they appear after applying the sensor rotation.
        final int uprightWidth = rotated ? frameHeight : frameWidth;
        final int uprightHeight = rotated ? frameWidth : frameHeight;

        multiplier = Math.min(
                cHeight / (float) uprightHeight,
                cWidth / (float) uprightWidth);

        frameToCanvasMatrix = ImageUtils.getTransformationMatrix(
                frameWidth,
                frameHeight,
                (int) (multiplier * uprightWidth),
                (int) (multiplier * uprightHeight),
                sensorOrientation,
                false);

        final Rect target = targetToTrack;
        if (target == null) {
            return null;
        }
        return new Rect(
                (int) (target.left * multiplier),
                (int) (target.top * multiplier),
                (int) (target.right * multiplier),
                (int) (target.bottom * multiplier));
    }

    /**
     * Forwards one camera frame to the {@link ActionTracker}.
     *
     * <p>On the first frame of a tracking cycle the tracker is initialized with a region of
     * interest derived from {@code result}; every frame of the cycle (including that first one)
     * is then passed to {@code updateCameraFrame}. Once the time accumulated between calls
     * exceeds {@link #getDetectionTimeInterval}, the cycle ends and the next call starts a new one.
     *
     * <p>The call is a no-op (apart from time accounting) when {@link #draw()} has not produced
     * a transform yet, when {@code result} is {@code null}, or when a new cycle would have to
     * start but {@code result} has no location.
     *
     * @param data              frame bytes, handed to the tracker as {@code ImageFormat.YUV_420_888}
     * @param w                 frame width passed to the tracker
     * @param h                 frame height passed to the tracker
     * @param result            detection that seeds the tracker and selects the interval
     * @param rotation          rotation value passed through to the tracker
     * @param orientationLocked selects which axis is mirrored for the front camera; forwarded to
     *                          the tracker as {@code 0} when {@code true} and {@code 1} otherwise
     * @param label             not used by this method
     */
    public void track(final byte[] data, int w, int h, Classifier.Recognition result, int rotation, boolean orientationLocked, String label) {
        final long now = SystemClock.uptimeMillis();
        if (mLastTime != 0) {
            time += now - mLastTime;
        }
        mLastTime = now;

        final Matrix frameToCanvas = getFrameToCanvasMatrix();
        if (frameToCanvas == null || result == null) {
            return;
        }

        final int cameraFlag = frontCamera ? 1 : 0;
        // Inverted on purpose to match the original call: locked -> 0, unlocked -> 1.
        final int lockedFlag = orientationLocked ? 0 : 1;

        if (!isInitialized && mTracker != null) {
            final RectF location = result.getLocation();
            if (location == null) {
                return; // nothing to seed the tracker with
            }
            final Rect roi = toTrackerRoi(location, frameToCanvas, orientationLocked);
            trackingProcessing = true;
            mTracker.initCamera(data, ImageFormat.YUV_420_888, w, h, roi, cameraFlag, rotation, lockedFlag);
            isInitialized = true;
        }

        if (isInitialized) {
            mTracker.updateCameraFrame(data, ImageFormat.YUV_420_888, w, h, cameraFlag, rotation, lockedFlag);
            if (time > getDetectionTimeInterval(result)) {
                // Interval elapsed: end this cycle so the next call re-initializes the tracker.
                trackingProcessing = false;
                isInitialized = false;
                time = 0;
            }
        }
    }

    /**
     * Converts a detection box into the region of interest handed to {@code initCamera}.
     *
     * <p>Works on a copy, so the rectangle owned by the recognition is never modified (the
     * same approach {@link #processResults} takes).
     */
    private Rect toTrackerRoi(RectF location, Matrix frameToCanvas, boolean orientationLocked) {
        final RectF canvasRect = new RectF(location);
        frameToCanvas.mapRect(canvasRect);

        if (frontCamera) {
            if (!orientationLocked) {
                // Mirror horizontally around the stored layout width.
                final float previousRight = canvasRect.right;
                canvasRect.right = cWidth - canvasRect.left;
                canvasRect.left = cWidth - previousRight;
            } else {
                // Mirror vertically around the stored layout height.
                final float previousBottom = canvasRect.bottom;
                canvasRect.bottom = cHeight - canvasRect.top;
                canvasRect.top = cHeight - previousBottom;
            }
        }

        final Rect rounded = new Rect();
        canvasRect.round(rounded);

        // Undo the canvas scale factor applied by draw().
        return new Rect(
                (int) (rounded.left / multiplier),
                (int) (rounded.top / multiplier),
                (int) (rounded.right / multiplier),
                (int) (rounded.bottom / multiplier));
    }

    /**
     * Returns how long (in milliseconds) a tracking cycle may run before {@link #track} resets it.
     *
     * <p>Detections titled {@code "person"} get 1000 ms for RAM classes {@code A1}/{@code B1} and
     * 3000 ms otherwise; every other title (including {@code null}) gets 5000 ms.
     */
    private int getDetectionTimeInterval(Classifier.Recognition result) {
        // Constant-first equals() keeps a missing title from throwing.
        final boolean isPerson = PERSON_TITLE.equals(result.getTitle());
        switch (RamClassifier.classify(mContext)) {
            case A1:
            case B1:
                return isPerson ? 1000 : 5000;
            default:
                return isPerson ? 3000 : 5000;
        }
    }

    /**
     * Rebuilds {@code screenRects} and {@code trackedObjects} from the given detections.
     *
     * <p>Every result with a location contributes a screen-space rectangle to
     * {@code screenRects}; only those at least {@code MIN_SIZE} wide and high (in frame
     * coordinates) are stored in {@code trackedObjects}.
     */
    private void processResults(final List<Classifier.Recognition> results) {
        final List<TrackedRecognition> accepted = new LinkedList<>();
        screenRects.clear();

        if (results != null) {
            final Matrix frameToScreen = new Matrix(getFrameToCanvasMatrix());
            for (final Classifier.Recognition result : results) {
                if (result == null) {
                    continue;
                }
                final RectF frameRect = result.getLocation();
                if (frameRect == null) {
                    continue;
                }

                // mapRect(dst, src) leaves the source rectangle untouched.
                final RectF screenRect = new RectF();
                frameToScreen.mapRect(screenRect, frameRect);
                screenRects.add(new Pair<>(result.getConfidence(), screenRect));

                if (frameRect.width() < MIN_SIZE || frameRect.height() < MIN_SIZE) {
                    continue;
                }

                final TrackedRecognition tracked = new TrackedRecognition();
                tracked.detectionConfidence = result.getConfidence();
                tracked.location = new RectF(frameRect);
                tracked.title = result.getTitle();
                tracked.color = COLORS[0];
                accepted.add(tracked);
            }
        }

        // Swap in the new set only after the whole batch has been processed.
        trackedObjects.clear();
        trackedObjects.addAll(accepted);
    }

    /** Snapshot of one stored detection. */
    private static class TrackedRecognition {
        RectF location;
        float detectionConfidence;
        int color;
        String title;
    }

    /**
     * Picks the proposed detection that best overlaps a reference detection.
     *
     * @param proposedTargets candidate detections; {@code null} entries and entries without a
     *                        location are skipped
     * @param potential       reference detection
     * @return the candidate with the highest intersection-over-union against {@code potential},
     *         provided that value is strictly greater than {@code MAX_OVERLAP}; {@code null} when
     *         no candidate qualifies or when any input is missing
     */
    public Classifier.Recognition handleDetection(final List<Recognition> proposedTargets, Recognition potential) {
        if (potential == null || proposedTargets == null || proposedTargets.isEmpty()) {
            return null;
        }
        final RectF reference = potential.getLocation();
        if (reference == null) {
            return null;
        }
        final float referenceArea = reference.width() * reference.height();

        Recognition bestMatch = null;
        // Starting at the threshold folds "above MAX_OVERLAP" and "better than the best so far"
        // into a single strict comparison.
        float bestIou = MAX_OVERLAP;
        final RectF intersection = new RectF();

        for (final Recognition candidate : proposedTargets) {
            if (candidate == null) {
                continue;
            }
            final RectF box = candidate.getLocation();
            if (box == null || !intersection.setIntersect(box, reference)) {
                continue;
            }

            final float intersectArea = intersection.width() * intersection.height();
            final float unionArea = box.width() * box.height() + referenceArea - intersectArea;
            final float iou = intersectArea / unionArea;

            // A NaN ratio (degenerate boxes) fails this comparison and is ignored.
            if (iou > bestIou) {
                bestIou = iou;
                bestMatch = candidate;
            }
        }

        return bestMatch;
    }
}

