import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import os
import argparse
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm
import pandas as pd

# Colour maps used to draw boxes, one per source, matching the figure title:
# ground-truth annotations in blues, YOLO detections in greens. Each is built
# with a 50-entry lookup table.
# NOTE: `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed
# in 3.9; `pyplot.get_cmap(name, lut)` is the supported equivalent.
colors = plt.get_cmap('Blues', 50)
colors_yolo = plt.get_cmap('Greens', 50)


def arg_directory(path):
    """argparse ``type=`` callback that accepts only an existing directory.

    Returns *path* unchanged when it names a directory. Otherwise it raises
    ``argparse.ArgumentTypeError`` so that argparse reports a usage error.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'`{path}` is not a valid path')
    return path

_BOX_COLUMNS = ['label', 'x', 'y', 'w', 'h']


def _read_boxes(label_path):
    """Load a space-separated label file into a ``label, x, y, w, h`` DataFrame.

    A missing or empty file yields an empty frame, so the image can still be
    drawn with the other source's boxes. (Presumably YOLO writes no label
    file for an image with no detections; confirm against the exporter.)
    """
    try:
        boxes = pd.read_csv(label_path, sep=' ', header=None)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return pd.DataFrame(columns=_BOX_COLUMNS)
    # Keep only the first five fields. Extra trailing fields (e.g. a
    # confidence score) would otherwise be shifted into the index by
    # read_csv(names=...) and mislabel every column.
    boxes = boxes.iloc[:, :5]
    boxes.columns = _BOX_COLUMNS
    return boxes


def _to_cv_color(palette, label):
    """Map a class id to an (R, G, B) tuple of 0-255 ints for OpenCV drawing.

    Colormaps are callables, not sequences, and return RGBA floats in [0, 1].
    The drawing target here is a uint8 image, so the channels are scaled to
    0-255 and alpha is dropped.
    """
    rgba = palette(int(label))
    return tuple(int(round(255 * channel)) for channel in rgba[:3])


def _draw_boxes(image, boxes, palette):
    """Draw every box in *boxes* onto *image* in place, each with a class-id tab.

    Box coordinates are box centre and size, normalised by image width/height.
    """
    H, W = image.shape[0], image.shape[1]
    for row in boxes.itertuples(index=False):
        x, y, w, h = row.x * W, row.y * H, row.w * W, row.h * H
        left, right = int(x - 0.5 * w), int(x + 0.5 * w)
        top, bottom = int(y - 0.5 * h), int(y + 0.5 * h)
        color = _to_cv_color(palette, row.label)

        # 1-px outline of the box itself.
        cv2.rectangle(image, pt1=(left, bottom), pt2=(right, top), color=color, thickness=1)
        # Filled 20-px tab just below the box holds the class id, leaving the
        # box interior (and any box drawn earlier) visible.
        cv2.rectangle(image, pt1=(left, bottom), pt2=(right, bottom + 20), color=color, thickness=-1)
        cv2.putText(image, str(row.label), (left, bottom + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)


def overlay_annotations(image_path, annotation_path, detection_path, output_directory):
    """Draw annotation and detection boxes on one image and save it as a figure.

    Annotation boxes use the ``colors`` (Blues) palette and detection boxes
    use ``colors_yolo`` (Greens). The result is saved with matplotlib as
    ``<output_directory>/<image basename>.jpg``, titled with that legend.

    The image is skipped (nothing written) when it cannot be read, when a
    label file is malformed, or when neither label file contains any box.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    base_name = os.path.splitext(os.path.basename(image_path))[0]

    try:
        annotations = _read_boxes(annotation_path)
        detections = _read_boxes(detection_path)
    except pd.errors.ParserError:
        # Malformed label file: skip this image rather than abort the batch.
        return
    if annotations.empty and detections.empty:
        return

    # Annotations first, then detections, so detections are drawn on top.
    _draw_boxes(image, annotations, colors)
    _draw_boxes(image, detections, colors_yolo)

    os.makedirs(output_directory, exist_ok=True)
    output_path = os.path.join(output_directory, f'{base_name}.jpg')
    try:
        plt.imshow(image)  # RGB array: any cmap would be ignored
        plt.title('Blues : ANNOTATION; Greens : YOLO DETECTION', loc='center')
        plt.savefig(output_path)
    finally:
        # Always release the implicit figure, even if saving fails.
        plt.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Overlay ground-truth annotations (blue) and YOLO detections (green) '
                    'on .jpg images and save one figure per image.',
    )
    parser.add_argument('-p', '--path_to_data', type=str, required=True,
                        help='Path of the folder that contains the .jpg files (*/set/images/); '
                             'sub-folders are searched as well')
    parser.add_argument('-s', '--detection', type=str, required=True,
                        help='Path to the folder containing the .txt detections (*/exp/labels)')
    parser.add_argument('-d', '--directory', type=arg_directory, required=True,
                        help='Directory to which the overlayed images will be stored')
    parser.add_argument('-a', '--annotation', type=str, required=True,
                        help='Path to the folder containing the .txt annotations (*/train/labels/)')
    args = parser.parse_args()

    path = args.path_to_data
    directory = args.directory
    detection = args.detection
    annotation = args.annotation

    output_dir = os.path.abspath(directory)
    # '**' with recursive=True matches zero or more directories, so images
    # directly inside `path` and inside any of its sub-folders are found.
    image_files = sorted(
        p for p in glob.glob(os.path.join(path, '**', '*.jpg'), recursive=True)
        # Never feed previously written overlays back in as inputs.
        if os.path.dirname(os.path.abspath(p)) != output_dir
    )

    for image_path in image_files:
        # Label files share the image's basename, with a .txt extension.
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        annotation_path = os.path.join(annotation, f'{base_name}.txt')
        detection_path = os.path.join(detection, f'{base_name}.txt')

        overlay_annotations(image_path, annotation_path, detection_path, directory)
