from datetime import datetime
import os
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import scipy.optimize
from scipy import interpolate
import scipy.signal


def get_angles(data, degree=True):
    """Estimate yaw, pitch and roll from accelerometer and magnetometer columns.

    Pitch and roll are derived from the accelerometer components
    (``a_x``, ``a_y``, ``a_z``); yaw is a tilt-compensated heading computed
    from the magnetometer components (``m_x``, ``m_y``, ``m_z``) rotated by
    that pitch and roll.

    Args:
        data: object exposing ``a_x``, ``a_y``, ``a_z``, ``m_x``, ``m_y`` and
            ``m_z`` as attributes (e.g. a pandas DataFrame with those columns).
        degree: return degrees when True (default), radians otherwise.

    Returns:
        (N, 3) array whose columns are yaw, pitch and roll, in that order.
    """
    ax, ay, az = data.a_x, data.a_y, data.a_z
    mx, my, mz = data.m_x, data.m_y, data.m_z

    pitch = np.arctan2(ax, np.sqrt(ay**2 + az**2))
    roll = np.arctan2(-ay, np.sqrt(ax**2 + az**2))

    cos_roll = np.cos(roll)
    sin_roll = np.sin(roll)
    sin_pitch = np.sin(pitch)

    # Magnetic vector projected back onto the horizontal plane.
    heading_y = -my * cos_roll - mz * sin_roll
    heading_x = mx * np.cos(pitch) + my * sin_roll * sin_pitch + mz * cos_roll * sin_pitch
    yaw = np.arctan2(heading_y, heading_x)

    angles = np.column_stack([yaw, pitch, roll])
    return angles * 180 / np.pi if degree else angles


def ellipsoid_fit(X):
    """Fit a general (rotated, off-centre) ellipsoid to 3-D points.

    Fits the algebraic model
        a x^2 + b y^2 + c z^2 + 2d xy + 2e xz + 2f yz + 2g x + 2h y + 2i z = 1
    in the least-squares sense, then recovers the geometric centre, the
    semi-axis lengths and the axis directions.

    Args:
        X: (N, 3) array of points. At least 9 points are needed for a unique
            fit; with fewer a message is printed and a (non-unique) solution
            is still returned.

    Returns:
        center: (3,) centre of the fitted ellipsoid.
        radii: (3, 1) semi-axis lengths; ``radii[k]`` goes with ``evecs[:, k]``.
            NaN entries appear if the fitted quadric is not an ellipsoid.
        evecs: (3, 3) orthonormal matrix whose columns are the axis directions.
        v: (9,) algebraic coefficients ``[a, b, c, d, e, f, g, h, i]``.

    Raises:
        ValueError: if ``X`` is not an (N, 3) array.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2 or X.shape[1] != 3:
        raise ValueError(f"expected an (N, 3) array of points, got shape {X.shape}")
    if X.shape[0] < 9:
        print('Must have at least 9 points to fit a unique ellipsoid')

    x = X[:, 0:1]
    y = X[:, 1:2]
    z = X[:, 2:3]
    D = np.concatenate((X**2, 2*x*y, 2*x*z, 2*y*z, 2*X), axis=1)

    # Solve D v = 1 directly: same least-squares solution as the normal
    # equations D^T D v = D^T 1, without squaring the condition number.
    v = np.linalg.lstsq(D, np.ones(D.shape[0]), rcond=None)[0]

    # Algebraic (homogeneous 4x4) form of the quadric.
    Q = np.array([[v[0], v[3], v[4], v[6]],
                  [v[3], v[1], v[5], v[7]],
                  [v[4], v[5], v[2], v[8]],
                  [v[6], v[7], v[8], -1.0]])

    # Centre: the point where the gradient of the quadric vanishes.
    center = np.linalg.lstsq(-Q[:3, :3], v[6:9], rcond=None)[0]

    # Translate the quadric so that it is centred at the origin.
    T = np.eye(4)
    T[3, :3] = center
    R = T @ Q @ T.T

    # R[:3, :3] is symmetric, so eigh returns real eigenvalues and an
    # orthonormal eigenbasis (np.linalg.eig guarantees neither).
    evals, evecs = np.linalg.eigh(R[:3, :3] / -R[3, 3])
    radii = np.sqrt(1.0 / evals).reshape(-1, 1)
    return center, radii, evecs, v


def sphereFit(magxyz):
    """Fit a sphere to 3-D points with linear least squares.

    Uses the linearisation ``|p|^2 = 2 c . p + k`` with ``k = r^2 - |c|^2``,
    which is linear in the unknowns ``(c, k)``.

    Args:
        magxyz: (N, 3+) array; only the first three columns are used.

    Returns:
        center: (3,) fitted sphere centre.
        radius: fitted sphere radius (NaN if the fit is degenerate).
    """
    pts = np.asarray(magxyz, dtype=float)[:, :3]
    design = np.column_stack((2.0 * pts, np.ones(pts.shape[0])))
    target = np.sum(pts**2, axis=1)

    # rcond=None matches the call in ellipsoid_fit and silences the legacy
    # rcond FutureWarning on older NumPy.
    coeffs = np.linalg.lstsq(design, target, rcond=None)[0]
    center = coeffs[:3]
    radius = np.sqrt(coeffs[3] + np.dot(center, center))
    return center, radius


def mag_calibration(calib_dataframe):
    """Estimate a magnetometer offset and 3x3 compensation matrix.

    Two-stage fit on the ``m_x``, ``m_y``, ``m_z`` columns:
      1. a 4-parameter sphere fit (``get_EllipsoidCenter_MagnFieldStr``) gives
         a coarse centre used to recenter the data;
      2. a 9-parameter ellipsoid fit (``ellipsoid_fit``) on the recentered data
         gives the residual centre, semi-axes and axis directions.

    Args:
        calib_dataframe: DataFrame containing ``m_x``, ``m_y``, ``m_z`` columns.

    Returns:
        offset: (3,) vector to subtract from raw samples.
        comp: (3, 3) matrix ``V diag(1/r) V^-1``. Applied to an offset-corrected
            sample (as ``compensate`` does), it maps the fitted ellipsoid onto
            the unit sphere. May contain NaN if the fit is not an ellipsoid.
    """
    mag = calib_dataframe[['m_x', 'm_y', 'm_z']].values

    # Step 1: coarse centre from the sphere fit (field strength not needed).
    precal_center, _ = get_EllipsoidCenter_MagnFieldStr(mag)

    # Step 2: recenter the samples.
    mag_centered = mag - precal_center

    # Step 3: full ellipsoid fit on the recentered samples.
    e_center, e_radii, e_eigenvecs, _ = ellipsoid_fit(mag_centered)

    # Step 4: compensation matrix. Use the true inverse of the eigenvector
    # matrix rather than its transpose: it equals the transpose when the basis
    # is orthonormal, and stays correct when an eigen-solver returns a
    # non-orthogonal basis (e.g. for near-repeated semi-axes).
    scale = np.diag(1.0 / e_radii.ravel())
    comp = e_eigenvecs @ scale @ np.linalg.inv(e_eigenvecs)

    offset = precal_center + e_center
    return offset, comp


def compensate(data, comp_matrix, offset):
    """Apply an offset and a 3x3 compensation matrix to every sample.

    Computes ``comp_matrix @ (sample - offset)`` for each row, vectorised as
    ``(data - offset) @ comp_matrix.T``.

    Args:
        data: (N, 3) array-like of samples.
        comp_matrix: (3, 3) compensation matrix.
        offset: (3,) offset subtracted before the matrix is applied.

    Returns:
        (N, 3) array of compensated samples.

    Raises:
        ValueError: if ``data`` is not an (N, 3) array.
    """
    data = np.asarray(data)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if data.ndim != 2 or data.shape[1] != 3:
        raise ValueError(f"expected an (N, 3) array of samples, got shape {data.shape}")
    return (data - offset) @ np.asarray(comp_matrix).T

def proj(data, radius, offset):
    """Recenter samples on ``offset`` and scale each row to unit length.

    Args:
        data: (N, 3) array of samples.
        radius: accepted for interface compatibility but not used; rows are
            normalised to length 1, not to ``radius``.
        offset: centre subtracted from every row before normalising.

    Returns:
        (N, 3) array of unit-length row vectors (NaN for a row equal to
        ``offset``).
    """
    centered = data - offset
    lengths = np.linalg.norm(centered, axis=1, keepdims=True)
    return centered / lengths

def get_EllipsoidCenter_MagnFieldStr(MagValues):
    """Estimate the data centre and field strength with a 4-parameter sphere fit.

    Solves the linear least-squares model
        mx^2 + my^2 + mz^2 = p0*mx + p1*my + p2*mz + p3
    so that the centre is ``V = p[:3] / 2`` and the radius (field strength)
    is ``B = sqrt(p3 + |V|^2)``.

    Args:
        MagValues: (N, 3) array of raw (MagX, MagY, MagZ) samples.

    Returns:
        Tuple ``(V, B)``: ``V`` is the (3,) centre of the fitted sphere and
        ``B`` its radius.

    Raises:
        ValueError: if ``MagValues`` is not an (N, 3) array.
    """
    mag = np.asarray(MagValues, dtype=float)
    if mag.ndim != 2 or mag.shape[1] != 3:
        raise ValueError(f"expected an (N, 3) array of samples, got shape {mag.shape}")

    design = np.column_stack((mag, np.ones(mag.shape[0])))
    # Target must be 1-D: an (N, 1) target minus an (N,) prediction would
    # broadcast to an (N, N) residual.
    target = np.sum(mag**2, axis=1)

    # The model is linear in p, so solve it exactly instead of iterating.
    coeffs = np.linalg.lstsq(design, target, rcond=None)[0]

    # centre of the fitted sphere
    V = 0.5 * coeffs[:3]
    # magnetic field strength (sphere radius)
    B = np.sqrt(coeffs[3] + np.dot(V, V))
    return V, B


def filter_outliers(df, threshold=10):
    """Replace isolated spikes and dips in the magnetometer columns.

    For each of ``m_x``, ``m_y`` and ``m_z``, a sample is treated as an outlier
    when its absolute value is a local maximum or minimum that differs from
    both neighbours by at least ``threshold`` (``scipy.signal.find_peaks``
    threshold semantics). Outliers are replaced by linear interpolation
    between the surrounding non-outlier samples. Positions are used
    throughout, so any DataFrame index works.

    Args:
        df: DataFrame with ``m_x``, ``m_y``, ``m_z`` columns. Modified in place.
        threshold: minimum vertical distance to both neighbours for a sample
            to count as an outlier (default 10).

    Returns:
        The same DataFrame object, with outliers replaced.
    """
    for mag_type in ('m_x', 'm_y', 'm_z'):
        values = df[mag_type].to_numpy(dtype=float)
        magnitude = np.abs(values)
        spikes = scipy.signal.find_peaks(magnitude, threshold=threshold)[0]
        dips = scipy.signal.find_peaks(-magnitude, threshold=threshold)[0]
        outliers = np.concatenate((spikes, dips))
        if outliers.size == 0:
            continue

        keep = np.ones(values.size, dtype=bool)
        keep[outliers] = False
        kept_idx = np.flatnonzero(keep)

        # find_peaks never reports the first/last sample, so every outlier
        # lies strictly inside the range of kept positions.
        cleaned = values.copy()
        cleaned[outliers] = np.interp(outliers, kept_idx, values[kept_idx])
        df[mag_type] = cleaned
    return df

def make_plot(dataset, data, title, save_path):
    """Save a scatter plot of each angle column of ``data`` against time.

    Args:
        dataset: table providing the x-axis values under the ``'time'`` key.
        data: 2-D array with one angle series per column; the legend labels
            the columns yaw, pitch, roll in that order.
        title: figure title.
        save_path: destination passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 4), dpi=200)
    time = dataset['time']
    for series in data.T:
        ax.scatter(time, series, s=1)
    ax.legend(['yaw', 'pitch', 'roll'])
    # One tick every 30 degrees over the full [-180, 180] range.
    ax.set_yticks(np.linspace(-180, 180, num=13, endpoint=True))
    ax.grid('on')
    ax.set_title(title)
    ax.set_ylabel('Angle (°)')
    ax.set_xlabel('Time (by 10ms)')
    fig.savefig(save_path)
    plt.close(fig)


def _read_imu_csv(path, labels):
    """Load an IMU CSV file as a float DataFrame whose columns are ``labels``.

    The file is first parsed as header-less. If pandas raises ``ValueError``
    while converting fields to float (e.g. because the first row is a text
    header), it is re-read with row 0 treated as a header that is replaced
    by ``labels``.
    """
    try:
        return pd.read_csv(path, delimiter=',', names=labels, dtype='float')
    except ValueError:
        return pd.read_csv(path, header=0, names=labels, delimiter=',', dtype='float')


def main(args, dataset_path):
    """Compute yaw/pitch/roll for ``dataset_path`` and write results to ``args.outdirpath``.

    Always writes ``<prefix>_angles_mag_not_calib.csv`` (plus a matching .png
    when ``args.save_plot`` is set). When ``args.calibration`` is given, the
    magnetometer is calibrated from that file. The 9-parameter ellipsoid
    result is used when it is entirely finite; otherwise the 4-parameter
    sphere fit is used. The calibrated angles and IMU data are then written
    as ``<prefix>_angles_mag_calib_<tag>.csv`` and
    ``<prefix>_IMUs_mag_calib_<tag>.csv`` (plus the .png when requested).

    Args:
        args: parsed CLI namespace with ``outdirpath``, ``calibration`` and
            ``save_plot`` attributes.
        dataset_path: path of the IMU CSV file to process.
    """
    prefix = os.path.splitext(os.path.basename(dataset_path))[0]
    out_dir = args.outdirpath

    def out_path(suffix):
        # Build output paths from the prefix only; never rewrite the directory.
        return os.path.join(out_dir, prefix + suffix)

    # Create the output directory
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir, exist_ok=True)
        print("output folder created: " + os.path.abspath(out_dir))

    # get IMU data
    labels = ['time', 'a_x', 'a_y', 'a_z', 'g_x', 'g_y', 'g_z', 'm_x', 'm_y', 'm_z']
    magneto = ['m_x', 'm_y', 'm_z']
    dataset = _read_imu_csv(dataset_path, labels)
    print(dataset.head())
    dataset = filter_outliers(dataset)

    angles_notcalib = get_angles(dataset)
    df_angles_notcalib = pd.DataFrame(angles_notcalib, columns=['Yaw', 'Pitch', 'Roll'])
    df_angles_notcalib.to_csv(out_path('_angles_mag_not_calib.csv'))

    if args.save_plot:
        make_plot(dataset, angles_notcalib, 'Angles without Magnetometer Calibration',
                  out_path('_angles_mag_not_calib.png'))

    if not args.calibration:
        return

    calib_df = filter_outliers(_read_imu_csv(args.calibration, labels))
    calib_df.to_csv(out_path('_calib_dataset_mag_filtered_not_calib.csv'))

    offset, comp_matrix = mag_calibration(calib_df)
    # The 9-parameter result is only usable if every coefficient is finite:
    # a single NaN/inf (e.g. from a negative eigenvalue in the ellipsoid fit)
    # would corrupt every compensated sample.
    optimal_success = bool(np.all(np.isfinite(comp_matrix)) and np.all(np.isfinite(offset)))

    if optimal_success:
        print("\n/!\\ Optimal Calibration was a Success /!\\\n")
        print("Magnetometer calibration results : \n")
        print("offset : ")
        print(offset)
        print("compensation Matrix : ")
        print(comp_matrix)
        dataset[magneto] = compensate(dataset[magneto].values, comp_matrix, offset)
        title = 'Angles after OPTIMAL Magnetometer Calibration (9 params)'
        tag = 'optimal'
    else:
        print("\n/!\\ Optimal Calibration was a Failure /!\\\n")
        print("Calibration with 4 parameters instead or 9")
        precal_center, magfield = sphereFit(calib_df[magneto].values)
        print("center : ", precal_center)
        print("magfield", magfield)
        dataset[magneto] = proj(dataset[magneto].values, magfield, precal_center)
        title = 'Angles after Suboptimal Magnetometer Calibration (only 4 params)'
        tag = 'suboptimal'

    angles = get_angles(dataset)

    # Paths are defined regardless of --save_plot (the CSVs are always written).
    if args.save_plot:
        make_plot(dataset, angles, title, out_path('_angles_mag_calib_' + tag + '.png'))

    df_angles = pd.DataFrame(np.column_stack((dataset['time'].values, angles)),
                             columns=['time', 'Yaw', 'Pitch', 'Roll'])
    df_angles.to_csv(out_path('_angles_mag_calib_' + tag + '.csv'))
    dataset.to_csv(out_path('_IMUs_mag_calib_' + tag + '.csv'), index=False)


if __name__ == '__main__':
    # Command-line entry point: parse arguments and run main().
    parser = argparse.ArgumentParser(description='This script requires Numpy, Matplotlib, Scipy, and Pandas. \n'
                                                 'This script turns raw data of 9axis IMU to Yaw Pitch and Roll Angles. '
                                                 'A calibration of the magnetometer can be done if calibration dataset is provided. '
                                                 'Input dataset(s) must be in .csv file format with exactly these columns, in this order (a header row is optional): ["time","a_x","a_y","a_z","g_x","g_y","g_z","m_x","m_y","m_z"]. \n'
                                                 'The script saves angles in .csv format, plus one or two plotted figures (.png, with --save_plot) showing the result without & with calibration of the magnetometer if it is done during the process. The saved files are located in the folder given by --outdirpath (default "Angle").')

    parser.add_argument('input', type=str, help='path of the .csv dataset file, with ["time","a_x","a_y","a_z","g_x","g_y","g_z","m_x","m_y","m_z"] columns')
    parser.add_argument('--outdirpath', type=str, default='Angle', help='Output directory path')
    parser.add_argument('--calibration', type=str, help='path of the .csv CALIBRATION dataset file, with ["time","a_x","a_y","a_z","g_x","g_y","g_z","m_x","m_y","m_z"] columns. It should be the recording of the IMU data while tilting the board around every axis, and during a sufficient amount of time to get lots of samples.')
    parser.add_argument('--save_plot', action='store_true', help='save the angle plots as png files (without magnetometer calibration, and also with it when --calibration is given)')

    args = parser.parse_args()

    if not args.calibration:
        print("\nWARNING : Magnetometer won't be calibrated because calibration dataset filepath was not provided.\n")
    main(args, args.input)
