"""clev2er.utils.cristal.reloc_sarin.reloc_sarin.py
Module containing functions to calculate the location and elevation of SARIn points of
closest approach (POCA) based upon the range and angle of arrival.
"""
import numpy as np
from pyproj import Proj, Transformer


def geodetic_to_ecef(lat, lon, alt):
    """
    Convert WGS84 geodetic coordinates to ECEF (Earth-Centered Earth-Fixed) coordinates.

    Note the argument order is (lat, lon, alt), whereas the underlying pyproj transformer
    is configured with ``always_xy=True`` and therefore receives (lon, lat, alt).

    Args:
        lat (array-like): Latitudes in degrees, shape (n,).
        lon (array-like): Longitudes in degrees, shape (n,).
        alt (array-like): Altitudes in meters relative to the WGS84 ellipsoid, shape (n,).

    Returns:
        tuple:
            np.ndarray: X-coordinates in ECEF (meters), shape (n,).
            np.ndarray: Y-coordinates in ECEF (meters), shape (n,).
            np.ndarray: Z-coordinates in ECEF (meters), shape (n,).
    """
    geographic = Proj(proj="latlong", ellps="WGS84", datum="WGS84")
    geocentric = Proj(proj="geocent", ellps="WGS84", datum="WGS84")

    # always_xy=True fixes the input axis order to (lon, lat) regardless of CRS definition.
    to_ecef = Transformer.from_proj(geographic, geocentric, always_xy=True)

    ecef_x, ecef_y, ecef_z = to_ecef.transform(  # pylint: disable=E0633
        xx=lon, yy=lat, zz=alt, radians=False
    )
    return ecef_x, ecef_y, ecef_z


def ecef_to_geodetic(x, y, z):
    """
    Convert ECEF (Earth-Centered Earth-Fixed) coordinates to WGS84 geodetic coordinates.

    Args:
        x (array-like): X-coordinates in ECEF (meters), shape (n,).
        y (array-like): Y-coordinates in ECEF (meters), shape (n,).
        z (array-like): Z-coordinates in ECEF (meters), shape (n,).

    Returns:
        tuple:
            lat (array-like): Latitudes in degrees, shape (n,).
            lon (array-like): Longitudes in degrees, wrapped with ``% 360`` into the
                0-360 range, shape (n,).
            height (array-like): Heights in meters relative to the WGS84 ellipsoid,
                shape (n,).
    """
    geographic = Proj(proj="latlong", ellps="WGS84", datum="WGS84")
    geocentric = Proj(proj="geocent", ellps="WGS84", datum="WGS84")

    # always_xy=True makes transform() return (lon, lat, height), longitude first.
    to_geodetic = Transformer.from_proj(geocentric, geographic, always_xy=True)

    lon_deg, lat_deg, height_m = to_geodetic.transform(  # pylint: disable=E0633
        xx=x, yy=y, zz=z, radians=False
    )

    # Longitudes are reported in the 0-360 convention rather than a signed range.
    wrapped_lon = lon_deg % 360
    return lat_deg, wrapped_lon, height_m


def calculate_crf_points(base_vector_filt, echo_to_bsa_filt, r_filt):
    """
    Calculate the CRF (Coordinate Reference Frame) points by solving the intersection of
    a circle and a plane.

    For each point, a circle is centred at ``base_vector * r * sin(angle)`` with radius
    ``r * cos(angle)``, lying in the plane orthogonal to the baseline vector. It is
    intersected with the CRF plane Y = 0. Writing the offset from the circle centre as
    ``(dx, dy, dz)``, with ``dy = -center_y`` forced by the Y = 0 constraint, the code solves::

        dx * bx + dz * bz = by * center_y             # offset orthogonal to the baseline
        dx**2 + dz**2 = (r * cos(angle))**2 - center_y**2

    It substitutes ``dz = (by * center_y - bx * dx) / bz`` into the second equation and
    solves the resulting quadratic in ``dx``. Of the two roots, the one giving the larger
    X coordinate is kept. When the rows of ``base_vector_filt`` are unit vectors, every
    returned point lies at distance ``r`` from the origin.

    Rows with no real intersection (negative discriminant) produce NaN coordinates, and a
    zero Z baseline component causes a division by zero.

    Args:
        base_vector_filt (numpy.ndarray): The baseline vector used in the calculation,
                                            shape (n, 3). Each row represents a 3D vector
                                            (x, y, z).
        echo_to_bsa_filt (numpy.ndarray): The echo to boresight angle in radians, shape (n,).
        r_filt (numpy.ndarray): The range for each point, shape (n,).

    Returns:
        numpy.ndarray: The calculated CRF points, shape (n, 3). Each row holds the X, Y and
                        Z coordinates of a point; Y is always 0.

    Raises:
        ValueError: If an input is not a numeric numpy array, if ``base_vector_filt`` is not
            2-D with 3 columns, if the angle/range arrays are not 1-D, or if the lengths
            differ.
    """

    # Validate all inputs are numeric numpy arrays
    for name, array in [
        ("base_vector_filt", base_vector_filt),
        ("echo_to_bsa_filt", echo_to_bsa_filt),
        ("r_filt", r_filt),
    ]:
        if not isinstance(array, np.ndarray):
            raise ValueError(f"{name} must be a numpy array.")
        if not np.issubdtype(array.dtype, np.number):
            raise ValueError(f"{name} must contain numeric values.")

    # Check dimensionality before indexing .shape. Malformed inputs (e.g. a 1-D baseline
    # or a 0-d angle) then raise ValueError rather than IndexError, and (n, 1) column
    # vectors cannot silently broadcast to (n, n).
    if base_vector_filt.ndim != 2:
        raise ValueError(
            "base_vector_filt must have shape (n, 3), where n is the number of points."
        )
    for name, array in [("echo_to_bsa_filt", echo_to_bsa_filt), ("r_filt", r_filt)]:
        if array.ndim != 1:
            raise ValueError(f"{name} must be a 1D array of shape (n,).")

    # Shape validation: Ensure that all arrays have the same length (n)
    n = base_vector_filt.shape[0]
    if echo_to_bsa_filt.shape[0] != n:
        raise ValueError(
            "echo_to_bsa_filt must have the same number of elements as base_vector_filt."
        )
    if r_filt.shape[0] != n:
        raise ValueError("r_filt must have the same number of elements as base_vector_filt.")

    if base_vector_filt.shape[1] != 3:
        raise ValueError(
            "base_vector_filt must have shape (n, 3), where n is the number of points."
        )

    sin_angle = np.sin(echo_to_bsa_filt)
    cos_angle = np.cos(echo_to_bsa_filt)

    # Circle centres: baseline direction scaled by r * sin(angle)
    crf_center_coordinates = base_vector_filt * (r_filt * sin_angle)[:, np.newaxis]
    center_x = crf_center_coordinates[:, 0]
    center_y = crf_center_coordinates[:, 1]
    center_z = crf_center_coordinates[:, 2]

    # A: squared radius left for the (dx, dz) offsets once dy = -center_y is fixed.
    # B: right-hand side of the orthogonality constraint dx * bx + dz * bz = B.
    circle_plane_param_a = r_filt**2 * cos_angle**2 - center_y**2
    circle_plane_param_b = base_vector_filt[:, 1] * center_y

    base_vec_x = base_vector_filt[:, 0]
    base_vec_z = base_vector_filt[:, 2]
    base_vec_z_sq = base_vec_z**2

    # Substituting dz = (B - bx * dx) / bz into dx**2 + dz**2 = A gives
    # quad_a * dx**2 - lin_coef * dx + quad_c = 0.
    quad_a = base_vec_x**2 / base_vec_z_sq + 1.0
    lin_coef = 2.0 * circle_plane_param_b * base_vec_x / base_vec_z_sq
    quad_c = circle_plane_param_b**2 / base_vec_z_sq - circle_plane_param_a

    # A negative discriminant (no intersection) yields NaN from np.sqrt.
    sqrt_disc = np.sqrt(lin_coef**2 - 4.0 * quad_a * quad_c)
    d_xplus = (lin_coef + sqrt_disc) / (2.0 * quad_a)
    d_xminus = (lin_coef - sqrt_disc) / (2.0 * quad_a)

    # Matching dz for each root from the orthogonality constraint
    d_zplus = (circle_plane_param_b - base_vec_x * d_xplus) / base_vec_z
    d_zminus = (circle_plane_param_b - base_vec_x * d_xminus) / base_vec_z

    # Absolute CRF coordinates: circle centre + offset
    d_xxplus = d_xplus + center_x
    d_xxminus = d_xminus + center_x
    d_zzplus = d_zplus + center_z
    d_zzminus = d_zminus + center_z

    # Choose the solution with the larger x-coordinate. NaN rows compare False and take
    # the minus branch, which is NaN as well.
    mask_plus = d_xxplus > d_xxminus

    # Y column is left at 0: the points lie in the CRF plane Y = 0.
    crf_points = np.zeros_like(crf_center_coordinates)
    crf_points[:, 0] = np.where(mask_plus, d_xxplus, d_xxminus)
    crf_points[:, 2] = np.where(mask_plus, d_zzplus, d_zzminus)

    return crf_points


def compute_poca_geodetic_coords(crf_axes, crf_points, efc_cogs):
    """
    Compute the geodetic coordinates (latitude, longitude, elevation) of POCA points
    for multiple waveforms.

    Each CRF point is rotated into the ECF frame by its waveform's CRF axes matrix,
    offset by the corresponding ECF centre-of-gravity position, and finally converted
    to WGS84 geodetic coordinates.

    Args:
        crf_axes (np.ndarray): Array of shape (n, 3, 3); each 3x3 matrix maps a CRF vector
            into the ECF frame for one waveform.
        crf_points (np.ndarray): Array of shape (n, 3); each row is a point in the CRF.
        efc_cogs (np.ndarray): Array of shape (n, 3); each row is the ECF centre-of-gravity
            position for one waveform.

    Returns:
        tuple: Latitude (degrees), longitude (degrees, 0-360 convention) and elevation
        (meters) of the POCA points, each of shape (n,).

    Raises:
        ValueError: If an input is not a numeric numpy array of the expected shape, or the
            inputs disagree on the number of waveforms.
    """
    required_shapes = (
        ("crf_axes", crf_axes, (None, 3, 3)),
        ("crf_points", crf_points, (None, 3)),
        ("efc_cogs", efc_cogs, (None, 3)),
    )
    for label, candidate, shape_spec in required_shapes:
        if not isinstance(candidate, np.ndarray):
            raise ValueError(f"{label} must be a numpy array.")
        shape_matches = (
            candidate.ndim == len(shape_spec) and candidate.shape[1:] == shape_spec[1:]
        )
        if not shape_matches:
            raise ValueError(f"{label} must have shape {shape_spec}.")
        if not np.issubdtype(candidate.dtype, np.number):
            raise ValueError(f"{label} must contain numeric values.")

    # All three inputs must describe the same number of waveforms.
    waveform_counts = {crf_axes.shape[0], crf_points.shape[0], efc_cogs.shape[0]}
    if len(waveform_counts) != 1:
        raise ValueError("Input arrays must have compatible dimensions (same number of waveforms).")

    # Per-waveform matrix-vector product (crf_axes[k] @ crf_points[k]) then translate
    # by the satellite position to get absolute ECF coordinates.
    rotated = np.einsum("nij,nj->ni", crf_axes, crf_points)
    poca_ecf = rotated + efc_cogs

    x_ecf, y_ecf, z_ecf = poca_ecf[:, 0], poca_ecf[:, 1], poca_ecf[:, 2]
    lat_poca, lon_poca, elev_poca = ecef_to_geodetic(x_ecf, y_ecf, z_ecf)
    return lat_poca, lon_poca, elev_poca


def get_crf_in_efc(lon, lat, alt, vel_vec):
    """
    Calculate the CRF axes in the Earth-Centered Fixed (ECF) frame for multiple waveforms.

    For each waveform the axes are built as:
        * axis 1: unit vector from the satellite position (lat, lon, alt) towards its
          nadir point (lat, lon, 0);
        * axis 2: the component of the normalised velocity vector orthogonal to axis 1,
          normalised;
        * axis 3: the cross product axis 2 x axis 1.

    A zero velocity, or a velocity parallel to axis 1, leads to a division by zero (NaN
    axes) for that waveform.

    Args:
        lon (array-like): Longitudes of the waveforms in degrees, shape (n,).
        lat (array-like): Latitudes of the waveforms in degrees, shape (n,).
        alt (array-like): Altitudes of the waveforms in meters, shape (n,).
        vel_vec (array-like): Velocity vectors in ECF coordinates, shape (n, 3).

    Returns:
        tuple:
            np.ndarray: CRF axes matrices, shape (n, 3, 3); column k of each matrix is
                axis k + 1 expressed in ECF.
            np.ndarray: Satellite positions in ECF coordinates, shape (n, 3).

    Raises:
        TypeError: If any input does not contain numeric values.
        ValueError: If ``vel_vec`` is not (n, 3), if lon/lat/alt are not 1-D arrays of
            equal length, or if ``vel_vec`` has a different number of rows than them.
    """

    # Ensure inputs are arrays
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    alt = np.asarray(alt)
    vel_vec = np.asarray(vel_vec)

    # Check that the arrays contain numeric types
    if not all(np.issubdtype(arr.dtype, np.number) for arr in (lon, lat, alt, vel_vec)):
        raise TypeError("All inputs must contain numeric values")

    # Validate shapes. ndim is checked before indexing .shape so malformed inputs raise
    # ValueError instead of IndexError.
    if vel_vec.ndim != 2 or vel_vec.shape[1] != 3:
        raise ValueError("vel_vec must have shape (n, 3)")
    if lon.ndim != 1 or lat.ndim != 1 or alt.ndim != 1:
        raise ValueError("lon, lat, and alt must be 1D arrays of shape (n,)")

    n = lon.shape[0]
    if not lat.shape[0] == alt.shape[0] == n:
        raise ValueError("lon, lat, and alt must have the same length")
    # Without this check a (1, 3) velocity would silently broadcast to every waveform.
    if vel_vec.shape[0] != n:
        raise ValueError("vel_vec must have the same number of rows as lon, lat, and alt")

    # Convert Geodetic to ECF for nadir (alt = 0) and satellite position (alt = alt)
    nad = np.zeros((n, 3))
    nad[:, 0], nad[:, 1], nad[:, 2] = geodetic_to_ecef(lat, lon, np.zeros_like(lat))
    efc_cog = np.zeros((n, 3))
    efc_cog[:, 0], efc_cog[:, 1], efc_cog[:, 2] = geodetic_to_ecef(lat, lon, alt)

    # Axis 1: normalised satellite -> nadir vector
    sat_nad_vec = nad - efc_cog
    ad_crf_axis1 = sat_nad_vec / np.linalg.norm(sat_nad_vec, axis=1, keepdims=True)

    # Normalise velocity vectors
    ad_efc_nv = vel_vec / np.linalg.norm(vel_vec, axis=1, keepdims=True)

    # Axis 2: remove the axis-1 component from the velocity direction (Gram-Schmidt step)
    scal_prod = np.sum(ad_crf_axis1 * ad_efc_nv, axis=1, keepdims=True)
    ad_temp_vect = ad_efc_nv - scal_prod * ad_crf_axis1
    ad_crf_axis2 = ad_temp_vect / np.linalg.norm(ad_temp_vect, axis=1, keepdims=True)

    # Axis 3: cross product of Axis 2 and Axis 1
    ad_crf_axis3 = np.cross(ad_crf_axis2, ad_crf_axis1)

    # Stack the axes as matrix columns: crf_axes[:, :, k] is axis k + 1
    crf_axes = np.stack((ad_crf_axis1, ad_crf_axis2, ad_crf_axis3), axis=2)

    # Return CRF axes and ECF satellite positions
    return crf_axes, efc_cog


def reloc_sarin(
    base_vector_filt, echo_to_bsa_filt, r_filt, lon_filt, lat_filt, alt_filt, velocity_vector_filt
):
    """
    Relocate filtered SARIn measurements to the geodetic coordinates of their POCA.

    The pipeline has three stages:
        1. intersect the range/angle circle with the CRF plane to get CRF points;
        2. build the satellite CRF orientation and position in the ECF frame;
        3. rotate/translate the CRF points into ECF and convert them to geodetic.

    Args:
        base_vector_filt (np.ndarray): Filtered baseline vectors, shape (n, 3).
        echo_to_bsa_filt (np.ndarray): Filtered echo to boresight angles in radians,
            shape (n,).
        r_filt (np.ndarray): Filtered ranges, shape (n,).
        lon_filt (array-like): Filtered satellite longitudes in degrees, shape (n,).
        lat_filt (array-like): Filtered satellite latitudes in degrees, shape (n,).
        alt_filt (array-like): Filtered satellite altitudes in meters, shape (n,).
        velocity_vector_filt (array-like): Filtered velocity vectors in ECF, shape (n, 3).

    Returns:
        tuple: ``(lat_poca, lon_poca, elev_poca)`` geodetic coordinates of the POCA points.
    """
    # Stage 1: circle-plane intersection in the CRF
    points_in_crf = calculate_crf_points(base_vector_filt, echo_to_bsa_filt, r_filt)

    # Stage 2: CRF orientation and satellite position in the ECF frame
    axes_in_ecf, sat_positions_ecf = get_crf_in_efc(
        lon_filt, lat_filt, alt_filt, velocity_vector_filt
    )

    # Stage 3: POCA geodetic coordinates
    poca_lat, poca_lon, poca_elev = compute_poca_geodetic_coords(
        axes_in_ecf, points_in_crf, sat_positions_ecf
    )
    return poca_lat, poca_lon, poca_elev
