"""clev2er.utils.cristal.waveform_quality.waveform_qc_checks.py
functions to perform waveform quality-control (QC) checks
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
# pylint: disable=too-many-statements

import numpy as np

# Identifiers for the individual waveform quality-control tests.
# waveform_qc_test stores, per waveform, the identifier of the first test that
# failed (0 means every test passed); within this module the values are used
# directly as codes rather than as bit positions.
# NOTE(review): these were previously described as "bit positions" -- confirm
# whether any downstream code shifts them into a bitmask before relying on that.
NOISE_TEST = 1
POWER_TEST = 2
VARIANCE_TEST = 3
PEAKINESS_TEST = 4
COHERENCE_TEST = 5


def normalize_waveform(waveform):
    """
    Scale a waveform so that its largest sample becomes 1.

    For a non-negative waveform (such as a power waveform) the result lies in
    [0, 1]. When the largest sample is exactly 0 there is nothing to scale by,
    so the input object itself is returned unchanged.

    Args:
        waveform (numpy.ndarray): 1D array of waveform samples.

    Returns:
        numpy.ndarray: ``waveform / max(waveform)``, or ``waveform`` itself when
        its maximum is 0.
    """
    peak = np.max(waveform)
    if peak != 0:
        return waveform / peak
    # A zero peak would mean dividing by zero, so hand the input back untouched.
    return waveform


def noise_power_test(waveform, noise_startgate, noise_endgate, noise_hi):
    """
    Check whether a waveform's noise floor is acceptably low.

    The waveform is first scaled by its peak (see ``normalize_waveform``), then the
    mean of the scaled samples ``w_norm[noise_startgate:noise_endgate]`` is compared
    against ``noise_hi``. The end gate is exclusive in the slice, although it must
    still be a valid index into the waveform.

    Args:
        waveform (numpy.ndarray): 1D array holding a single waveform.
        noise_startgate (int): First sample index of the noise window.
        noise_endgate (int): End (exclusive) sample index of the noise window.
        noise_hi (float): Maximum acceptable mean normalized noise power.

    Returns:
        bool: False if the mean normalized noise power exceeds ``noise_hi``,
        True otherwise.

    Raises:
        ValueError: If the waveform is empty or ``noise_hi`` is negative.
        IndexError: If the gate indices are out of order or out of range.
        TypeError: If the inputs have incompatible types.
    """
    if not waveform.size:
        raise ValueError("No waveform samples provided")
    if noise_hi < 0:
        raise ValueError("noise threshold cannot be negative")
    if noise_startgate < 0 or noise_startgate > noise_endgate or noise_endgate >= waveform.size:
        raise IndexError("Check the start and end gate indices.")
    # NOTE(review): noise_startgate == noise_endgate passes the check above but gives
    # an empty window; np.mean then returns NaN (with a RuntimeWarning) and the
    # comparison below makes the test pass. Confirm whether that should be rejected.
    try:
        w_norm = normalize_waveform(waveform)
        p_noise = np.mean(w_norm[noise_startgate:noise_endgate])
        return not p_noise > noise_hi

    except TypeError as e:
        raise TypeError("Check data types of inputs.") from e


def waveform_power_test(
    waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate
):
    """
    Check whether a waveform's signal power is sufficiently above its noise power.

    The mean of ``waveform[power_startgate:power_endgate]`` is divided by the mean of
    ``waveform[noise_startgate:noise_endgate]`` (plus a small epsilon), and the ratio
    is compared with ``power_lo``. End gates are exclusive in the slices but must
    still be valid indices into the waveform. The waveform is not normalized here.

    Args:
        waveform (numpy.ndarray): 1D array holding a single waveform.
        power_startgate (int): First sample index of the power window.
        power_endgate (int): End (exclusive) sample index of the power window.
        power_lo (float): Minimum acceptable power-to-noise ratio.
        noise_startgate (int): First sample index of the noise window.
        noise_endgate (int): End (exclusive) sample index of the noise window.

    Returns:
        bool: False if the power-to-noise ratio is below ``power_lo``, True otherwise.

    Raises:
        ValueError: If the waveform is empty or ``power_lo`` is negative.
        IndexError: If either pair of gate indices is out of order or out of range.
        TypeError: If the inputs have incompatible types.
    """
    if not waveform.size:
        raise ValueError("No waveform samples provided")
    if power_lo < 0:
        raise ValueError("power ratio threshold cannot be negative")
    if power_startgate < 0 or power_startgate > power_endgate or power_endgate >= waveform.size:
        raise IndexError("Check the power start and end gate indices.")
    if noise_startgate < 0 or noise_startgate > noise_endgate or noise_endgate >= waveform.size:
        raise IndexError("Check the noise start and end gate indices.")

    # Added to the denominator so an all-zero noise window does not divide by zero
    # (NumPy float division would otherwise yield inf/NaN with a warning, not raise).
    # NOTE(review): this is an absolute epsilon; if the waveform holds very small
    # physical values (e.g. powers in Watts), it could dominate the noise power and
    # bias the ratio low. Confirm the scale of the inputs passed here.
    epsilon = 1e-10
    try:
        p_waveform = np.mean(waveform[power_startgate:power_endgate])
        p_noise = np.mean(waveform[noise_startgate:noise_endgate])
        return not (p_waveform / (p_noise + epsilon)) < power_lo

    except TypeError as e:
        raise TypeError("Check data types of inputs.") from e


def waveform_variance_test(waveform, power_startgate, power_endgate, n_samples, variance_hi):
    """
    Check whether the spread of a waveform's normalized power window is acceptably low.

    The waveform is cast to float64 and scaled by its peak (see ``normalize_waveform``).
    The sum of squared deviations of ``w_norm[power_startgate:power_endgate]`` from the
    window's own mean is divided by ``n_samples`` and compared with ``variance_hi``.

    Args:
        waveform (numpy.ndarray): 1D array holding a single waveform.
        power_startgate (int): First sample index of the power window.
        power_endgate (int): End (exclusive) sample index of the power window; it must
            still be a valid index into ``waveform``.
        n_samples (int): Divisor applied to the sum of squared deviations. This is not
            the window length; waveform_qc_test passes the number of samples per
            waveform.
        variance_hi (float): Maximum acceptable variance.

    Returns:
        bool: False if the computed variance exceeds ``variance_hi``, True otherwise.

    Raises:
        ValueError: If the waveform is empty, ``n_samples`` is not positive or exceeds
            the waveform size, or ``variance_hi`` is negative.
        IndexError: If the gate indices are out of order or out of range.
        TypeError: If the inputs have incompatible types.
    """
    if not waveform.size:
        raise ValueError("No waveform samples provided")
    if n_samples <= 0:
        raise ValueError("number of samples cannot be negative")
    if n_samples > waveform.size:
        raise ValueError("number of samples cannot be larger than waveform size")
    if variance_hi < 0:
        raise ValueError("variance threshold cannot be negative")
    if power_startgate < 0 or power_startgate > power_endgate or power_endgate >= waveform.size:
        raise IndexError("Check the start and end gate indices.")

    try:
        w_norm = normalize_waveform(waveform.astype(np.float64))
        window = w_norm[power_startgate:power_endgate]
        p_waveform = np.mean(window)
        # NOTE(review): normalized by n_samples rather than the window length
        # (power_endgate - power_startgate); confirm this matches the algorithm spec.
        var_waveform = np.sum((window - p_waveform) ** 2) / n_samples
        return not var_waveform > variance_hi

    except TypeError as e:
        raise TypeError("Check data types of inputs.") from e


def peakiness_test(waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi):
    """
    Check whether a waveform's peakiness lies within an acceptable range.

    With the waveform scaled by its peak (see ``normalize_waveform``), peakiness is
    computed as::

        (n_samples - trackingpoint) * max(w_norm[:n_samples])
            / (mean(w_norm) * n_samples + 1e-10)

    and the test passes unless it is below ``peakiness_lo`` or above ``peakiness_hi``.

    Args:
        waveform (numpy.ndarray): 1D array holding a single waveform.
        n_samples (int): Number of leading samples searched for the peak, also used
            in the scaling of the peakiness value.
        trackingpoint (int): Tracking point index; it enters the formula only via
            ``n_samples - trackingpoint`` and must lie in [0, waveform.size].
        peakiness_lo (float): Minimum acceptable peakiness value.
        peakiness_hi (float): Maximum acceptable peakiness value.

    Returns:
        bool: False if the peakiness is outwith [peakiness_lo, peakiness_hi],
        True otherwise.

    Raises:
        ValueError: If the waveform is empty, ``n_samples`` is not positive or exceeds
            the waveform size, or either threshold is negative.
        IndexError: If ``trackingpoint`` lies outside the waveform.
        TypeError: If the inputs have incompatible types.
    """
    if not waveform.size:
        raise ValueError("No waveform samples provided")
    if n_samples <= 0:
        raise ValueError("number of samples cannot be negative")
    if n_samples > waveform.size:
        raise ValueError("number of samples cannot be larger than waveform size")
    if trackingpoint < 0 or trackingpoint > waveform.size:
        raise IndexError("trackingpoint cannot be outside the waveform")
    if peakiness_lo < 0 or peakiness_hi < 0:
        raise ValueError("peakiness thresholds cannot be negative")

    # Keeps the denominator non-zero for an all-zero waveform; NumPy float division
    # would otherwise yield inf/NaN with a warning rather than raise.
    epsilon = 1e-10
    try:
        w_norm = normalize_waveform(waveform)
        peak_value = np.max(w_norm[:n_samples])
        # NOTE(review): the mean covers the whole waveform while the peak covers only
        # the first n_samples; these agree when n_samples == waveform.size (as
        # waveform_qc_test passes). Confirm intent for other callers.
        mean_value = np.mean(w_norm)
        peakiness = (n_samples - trackingpoint) * (peak_value / (mean_value * n_samples + epsilon))
        return not (peakiness < peakiness_lo or peakiness > peakiness_hi)

    except TypeError as e:
        raise TypeError("Check data types of inputs.") from e


def coherence_test(coherence, coherence_lo):
    """
    Check whether a waveform's peak interferometric coherence reaches a minimum.

    Args:
        coherence (numpy.ndarray or array-like): Coherence values for one waveform.
            Non-array inputs such as lists are converted with ``np.asanyarray``
            (ndarray subclasses such as masked arrays are kept as-is).
        coherence_lo (float): Minimum acceptable value of ``max(coherence)``.

    Returns:
        bool: False if ``max(coherence) < coherence_lo``, True otherwise.

    Raises:
        ValueError: If no coherence samples are provided or ``coherence_lo`` is
            negative.
    """
    coherence = np.asanyarray(coherence)
    if coherence.size == 0:
        raise ValueError("No coherence samples provided")
    if coherence_lo < 0:
        raise ValueError("Coherence threshold cannot be negative")

    max_coherence = np.max(coherence)
    return not max_coherence < coherence_lo


def waveform_qc_test(
    noise_endgate,
    noise_hi,
    noise_startgate,
    peakiness_hi,
    peakiness_lo,
    power_endgate,
    power_lo,
    power_startgate,
    trackingpoint,
    variance_hi,
    waveforms,
    waveforms_scaled,
    coherence=None,
    coherence_lo=None,
):
    """
    Run the waveform quality-control tests on every waveform in a batch.

    For each waveform the tests are applied in a fixed order -- noise power,
    waveform power, variance, peakiness and (optionally) coherence -- and evaluation
    stops at the first test that fails.

    Args:
        noise_endgate (int): End (exclusive) index of the noise window.
        noise_hi (float): Maximum acceptable normalized noise power.
        noise_startgate (int): Start index of the noise window.
        peakiness_hi (float): Maximum acceptable peakiness.
        peakiness_lo (float): Minimum acceptable peakiness.
        power_endgate (int): End (exclusive) index of the power window.
        power_lo (float): Minimum acceptable power-to-noise ratio.
        power_startgate (int): Start index of the power window.
        trackingpoint (int): Tracking point index used in the peakiness calculation.
        variance_hi (float): Maximum acceptable variance of the power window.
        waveforms (numpy.ndarray): 2D array (n_waveforms, n_samples) of waveforms with
            waveform power in Watts; used for the noise, variance and peakiness tests.
        waveforms_scaled (numpy.ndarray): Waveforms indexed in step with ``waveforms``;
            row ``i`` is used for the waveform power test.
        coherence (numpy.ndarray, optional): 2D array of interferometric coherence with
            one row per waveform, shape (time, samples_ov). Only used (Ku band in
            interferometric mode) when ``coherence_lo`` is also given.
        coherence_lo (float, optional): Minimum acceptable peak coherence.

    Returns:
        tuple:
            - waveforms_ok (numpy.ndarray): bool array; True where the waveform passed
              every test, False where it is unsuitable.
            - waveform_qc_results (numpy.ndarray): int array holding, per waveform, the
              code of the first failing test (NOISE_TEST, POWER_TEST, VARIANCE_TEST,
              PEAKINESS_TEST or COHERENCE_TEST), or 0 if all tests passed.

    Raises:
        ValueError: If coherence testing is enabled but ``coherence`` is empty, or if
            an individual test rejects its inputs.
        IndexError: If an individual test rejects its gate indices.
    """
    if not waveforms.size:
        return np.array([], dtype=bool), np.array([], dtype=int)

    n_waveforms, n_samples = np.shape(waveforms)
    waveforms_ok = np.ones(n_waveforms, dtype=bool)
    waveform_qc_results = np.zeros(n_waveforms, dtype=int)

    use_coherence = coherence is not None and coherence_lo is not None
    # Validate once up front instead of re-checking on every loop iteration.
    if use_coherence and coherence.size == 0:
        raise ValueError("No coherence samples provided")

    for i, waveform in enumerate(waveforms):
        if not noise_power_test(waveform, noise_startgate, noise_endgate, noise_hi):
            failed_test = NOISE_TEST
        elif not waveform_power_test(
            waveforms_scaled[i],
            power_startgate,
            power_endgate,
            power_lo,
            noise_startgate,
            noise_endgate,
        ):
            failed_test = POWER_TEST
        elif not waveform_variance_test(
            waveform, power_startgate, power_endgate, n_samples, variance_hi
        ):
            failed_test = VARIANCE_TEST
        elif not peakiness_test(waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi):
            failed_test = PEAKINESS_TEST
        elif use_coherence and not coherence_test(coherence[i], coherence_lo):
            failed_test = COHERENCE_TEST
        else:
            continue  # every test passed; keep the defaults (True, 0)

        waveforms_ok[i] = False
        waveform_qc_results[i] = failed_test

    return waveforms_ok, waveform_qc_results
