"""
Module: pytest of clev2er.utils.cristal.waveform_quality.waveform_qc_checks

This module contains pytest functions to test various waveform quality control (QC) checks 
including noise power, waveform power, variance, peakiness, and coherence. Each function tests
different aspects of waveform QC using the corresponding QC functions from the `waveform_qc_checks` 
module.

---------------------------------------------------------------------------------------------------
test_noise_power
---------------------------------------------------------------------------------------------------
test_noise_power(waveform_name): Tests noise_power_test on predefined waveforms, comparing results 
to thresholds.

test_noise_power_empty(): Tests noise_power_test on an empty waveform, expecting a ValueError.

test_noise_power_invalid_params(waveform, noise_startgate, noise_endgate, noise_hi): Tests invalid 
params, expecting IndexError or ValueError.

test_noise_power_type_error(waveform, noise_startgate, noise_endgate, noise_hi): Tests with 
incorrect data types, expecting an exception.

---------------------------------------------------------------------------------------------------
test_waveform_power
---------------------------------------------------------------------------------------------------
test_waveform_power(waveform_name): Tests waveform_power_test on predefined waveforms, comparing 
power threshold results.

test_waveform_power_empty(): Tests waveform_power_test on an empty waveform, expecting a ValueError.

test_waveform_power_invalid_params(waveform, power_startgate, power_endgate, power_lo, 
noise_startgate, noise_endgate): Tests invalid params, expecting IndexError or ValueError.

test_waveform_power_type_error(waveform, power_startgate, power_endgate, power_lo, noise_startgate, 
noise_endgate): Tests with incorrect data types, expecting an exception.

---------------------------------------------------------------------------------------------------
test_waveform_variance
---------------------------------------------------------------------------------------------------
test_variance(waveform_name): Tests waveform_variance_test on predefined waveforms, checking 
variance threshold results.

test_waveform_variance_empty(): Tests waveform_variance_test on an empty waveform, expecting a 
ValueError.

test_waveform_variance_invalid_params(waveform, power_startgate, power_endgate, n_samples, 
variance_hi, expected_exception): Tests invalid params, expecting IndexError or ValueError.

test_waveform_variance_type_error(waveform, power_startgate, power_endgate, n_samples, 
variance_hi): Tests with incorrect data types, expecting an exception.

---------------------------------------------------------------------------------------------------
test_peakiness
---------------------------------------------------------------------------------------------------
test_peakiness(waveform_name): Tests peakiness_test on predefined waveforms, comparing peakiness 
thresholds.

test_peakiness_empty_waveform(): Tests peakiness_test on an empty waveform, expecting a ValueError.

test_peakiness_invalid_params(waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi, 
expected_exception): Tests invalid params, expecting IndexError or ValueError.

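test_peakiness_type_error(waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi): Tests
with incorrect data types, expecting an exception.

---------------------------------------------------------------------------------------------------
test_coherence
---------------------------------------------------------------------------------------------------
test_coherence(waveform_name): Tests coherence_test on the predefined coherence arrays, comparing
coherence threshold results.

test_coherence_empty_array(): Tests coherence_test on an empty array, expecting a ValueError.

test_coherence_invalid_params(coherence, coherence_lo, expected_exception): Tests invalid params,
expecting a ValueError.

test_coherence_type_error(coherence, coherence_lo): Tests with incorrect data types, expecting an
exception.

---------------------------------------------------------------------------------------------------
test_waveform_qc
---------------------------------------------------------------------------------------------------
test_waveform_qc(qc_set): Tests waveform_qc_test on stacks of the predefined waveforms, comparing
the per-waveform pass flags and QC result codes with the expected values.

test_waveform_qc_empty_waveforms(): Tests waveform_qc_test on empty waveform arrays, expecting
empty outputs.
"""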

import numpy as np
import pytest

from clev2er.utils.cristal.waveform_quality.waveform_qc_checks import (
    coherence_test,
    noise_power_test,
    peakiness_test,
    waveform_power_test,
    waveform_qc_test,
    waveform_variance_test,
)

# pylint: disable=too-many-arguments,missing-function-docstring

# Every test collected from this module carries the ``non_core`` marker.
pytestmark = [pytest.mark.non_core]

# Gate windows and tracking point shared by the tests below. Each (start, end) gate pair is
# passed unchanged to the QC function under test.
POWER_STARTGATE, POWER_ENDGATE = 0, 5
NOISE_STARTGATE, NOISE_ENDGATE = 0, 3
TRACKING_POINT = 2

# Reference waveforms and the threshold cases each single-check test runs against them.
#   "waveform"        - the echo waveform passed to the power/noise/variance/peakiness tests.
#   "coherence"       - the coherence array passed to coherence_test.
#   "noise_hi", "power_lo", "variance_hi", "coherence_lo"
#                     - (threshold, expected) pairs; ``expected`` is the boolean the matching
#                       *_test function must return for this waveform at that threshold.
#   "peakiness_lo_hi" - (peakiness_lo, peakiness_hi, expected) triples for peakiness_test.
# Every case in a list is distinct, so each one exercises a different threshold.
waveform_params = {
    "waveform_1": {
        "waveform": np.array([0.5, 0.5, 1.0, 1.0, 0.5, 0.5]),
        "coherence": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        "noise_hi": [(0.0, False), (0.6, False), (0.7, True), (1.0, True)],
        "power_lo": [(5.0, False), (1.1, False), (0.0, True), (0.9, True)],
        "variance_hi": [(0.0, False), (0.04, False), (1.0, True), (0.06, True)],
        "coherence_lo": [(1.0, False), (0.65, False), (0, True), (0.55, True)],
        "peakiness_lo_hi": [
            (3, 5, False),
            (1.1, 5, False),
            (0, 0.5, False),
            (3, 0.9, False),
            (0, 5, True),
            (0.9, 5, True),
            (0.5, 2, True),  # a window enclosing the 0.9..1.1 band bracketed by the cases here
            (0, 1.1, True),
        ],
    },
    "waveform_2": {
        "waveform": np.array([1, 1, 4, 4, 2, 2]),
        "coherence": np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.3]),
        "noise_hi": [(0.0, False), (0.4, False), (0.6, True), (1, True)],
        "power_lo": [(5.0, False), (1.2, False), (0.0, True), (1.0, True)],
        "variance_hi": [(0.0, False), (0.09, False), (0.099, True), (1.0, True)],
        "coherence_lo": [(1.0, False), (0.95, False), (0, True), (0.85, True)],
        "peakiness_lo_hi": [
            (5, 100, False),
            (1.2, 100, False),
            (0, 100, True),
            (1, 100, True),
            (0, 0.5, False),
            (0, 1, False),
            (0.5, 2, True),  # a window enclosing the 1..1.2 band bracketed by the cases here
            (0, 1.2, True),
        ],
    },
}

# Row orders used to stack the reference waveforms for waveform_qc_test.
_W1_W2 = ("waveform_1", "waveform_2")
_W2_W1 = ("waveform_2", "waveform_1")


def _stacked(field, order):
    """Return ``waveform_params[name][field]`` for each name in ``order`` as one 2-D array."""
    return np.array([waveform_params[name][field] for name in order])


# Parameter sets for waveform_qc_test. Each set runs a stack of the reference waveforms (one row
# per waveform, in the order given) through every QC check at once. Per-threshold comments give
# the single-check outcome for each waveform.
#   "expected_bool"       - per-waveform overall pass flag.
#   "expected_qc_results" - per-waveform QC code: 0 where every check passes, 2 where only the
#                           power check fails (sets 3 and 4), 1 where every check fails (set 1).
#                           NOTE(review): 1 is presumably the code of the first (noise) check;
#                           confirm against waveform_qc_test.
qc_params = [
    {
        "waveforms": _stacked("waveform", _W1_W2),
        "coherence": _stacked("coherence", _W1_W2),
        "noise_hi": 0.4,  # both FALSE
        "power_lo": 2,  # both FALSE
        "variance_hi": 0,  # both FALSE
        "peakiness_lo": 2,  # both FALSE
        "peakiness_hi": 0.5,  # both FALSE
        "coherence_lo": 1,  # both FALSE
        "expected_bool": np.array([False, False]),  # both fail
        "expected_qc_results": np.array([1, 1]),  # every check fails
    },
    {
        "waveforms": _stacked("waveform", _W1_W2),
        "coherence": _stacked("coherence", _W1_W2),
        "noise_hi": 3,  # both TRUE
        "power_lo": 0.5,  # both TRUE
        "variance_hi": 2,  # both TRUE
        "peakiness_lo": 0.5,  # both TRUE
        "peakiness_hi": 2,  # both TRUE
        "coherence_lo": 0.2,  # both TRUE
        "expected_bool": np.array([True, True]),  # both pass
        "expected_qc_results": np.array([0, 0]),  # every check passes
    },
    {
        "waveforms": _stacked("waveform", _W1_W2),
        "coherence": _stacked("coherence", _W1_W2),
        "noise_hi": 3,  # both TRUE
        "power_lo": 1.05,  # w1 FALSE, w2 TRUE
        "variance_hi": 2,  # both TRUE
        "peakiness_lo": 0.5,  # both TRUE
        "peakiness_hi": 2,  # both TRUE
        "coherence_lo": 0.2,  # both TRUE
        "expected_bool": np.array([False, True]),  # w1 fails, w2 passes
        "expected_qc_results": np.array([2, 0]),  # w1 fails only the power check
    },
    {
        # Same thresholds as the previous set, with the row order swapped (w2 first, w1 second).
        "waveforms": _stacked("waveform", _W2_W1),
        "coherence": _stacked("coherence", _W2_W1),
        "noise_hi": 3,  # both TRUE
        "power_lo": 1.05,  # w2 TRUE, w1 FALSE
        "variance_hi": 2,  # both TRUE
        "peakiness_lo": 0.5,  # both TRUE
        "peakiness_hi": 2,  # both TRUE
        "coherence_lo": 0.2,  # both TRUE
        "expected_bool": np.array([True, False]),  # w2 (row 0) passes, w1 (row 1) fails
        "expected_qc_results": np.array([0, 2]),  # w1 (row 1) fails only the power check
    },
]


@pytest.mark.parametrize("waveform_name", list(waveform_params))
def test_noise_power(waveform_name):
    """Check noise_power_test against every (noise_hi, expected) case of one reference waveform.

    The noise window comes from the module constants NOISE_STARTGATE and NOISE_ENDGATE.
    """
    case = waveform_params[waveform_name]
    signal = case["waveform"]

    for threshold, wanted in case["noise_hi"]:
        got = noise_power_test(signal, NOISE_STARTGATE, NOISE_ENDGATE, threshold)
        assert got == wanted, (
            f"Noise power test failed for waveform: {signal} "
            f"with noise threshold: {threshold}. Expected: {wanted}, Got: {got}"
        )


def test_noise_power_empty():
    """noise_power_test must reject an empty waveform with ValueError."""
    empty = np.array([])
    with pytest.raises(ValueError):
        noise_power_test(empty, 0, 1, 0.3)


@pytest.mark.parametrize(
    "waveform, noise_startgate, noise_endgate, noise_hi",
    [
        (np.array([1, 4, 5, 2, 7, 8]), -1, 1, 0.3),  # negative start gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 20, 0.3),  # end gate far past the waveform
        (np.array([1, 4, 5, 2, 7, 8]), 2, 6, 0.3),  # end gate equal to the waveform length
        (np.array([1, 4, 5, 2, 7, 8]), 6, 2, 0.3),  # start gate after end gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, -20),  # negative noise threshold
    ],
)
def test_noise_power_invalid_params(waveform, noise_startgate, noise_endgate, noise_hi):
    """Bad gates or a negative threshold must raise IndexError or ValueError."""
    accepted_errors = (IndexError, ValueError)
    with pytest.raises(accepted_errors):
        noise_power_test(waveform, noise_startgate, noise_endgate, noise_hi)


@pytest.mark.parametrize(
    "waveform, noise_startgate, noise_endgate, noise_hi",
    [
        ("not an array", 1, 2, 0.3),  # waveform is a string
        (np.array([1, 4, 5, 2, 7, 8]), "start", 2, 0.3),  # start gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, "end", 0.3),  # end gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, "high"),  # threshold is a string
    ],
)
def test_noise_power_type_error(waveform, noise_startgate, noise_endgate, noise_hi):
    """Arguments of the wrong type must make noise_power_test raise (any exception)."""
    call_args = (waveform, noise_startgate, noise_endgate, noise_hi)
    with pytest.raises(Exception):
        noise_power_test(*call_args)


@pytest.mark.parametrize("waveform_name", list(waveform_params))
def test_waveform_power(waveform_name):
    """Check waveform_power_test against every (power_lo, expected) case of one reference waveform.

    Both the power window and the noise window come from the module-level gate constants.
    """
    case = waveform_params[waveform_name]
    signal = case["waveform"]
    power_gates = (POWER_STARTGATE, POWER_ENDGATE)
    noise_gates = (NOISE_STARTGATE, NOISE_ENDGATE)

    for threshold, wanted in case["power_lo"]:
        got = waveform_power_test(signal, *power_gates, threshold, *noise_gates)
        assert got == wanted, (
            f"Waveform power test failed for waveform: {signal} with power threshold:"
            f" {threshold}. Expected: {wanted}, Got: {got}"
        )


def test_waveform_power_empty():
    """waveform_power_test must reject an empty waveform with ValueError."""
    empty = np.array([])
    with pytest.raises(ValueError):
        waveform_power_test(empty, 0, 1, 0.3, 0, 1)


@pytest.mark.parametrize(
    "waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate",
    [
        (np.array([1, 4, 5, 2, 7, 8]), -1, 1, 0.3, 0, 1),  # negative power start gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 20, 0.3, 0, 1),  # power end gate past the waveform
        (np.array([1, 4, 5, 2, 7, 8]), 2, 6, 0.3, 0, 1),  # power end gate == waveform length
        (np.array([1, 4, 5, 2, 7, 8]), 6, 2, 0.3, 0, 1),  # power start gate > end gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, -1.5, 0, 1),  # negative power threshold
        (np.array([1, 4, 5, 2, 7, 8]), 0, 1, 0.3, -1, 1),  # negative noise start gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 3, 0.3, 0, 50),  # noise end gate past the waveform
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, 0.3, 0, 6),  # noise end gate == waveform length
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, 0.3, 5, 1),  # noise start gate > end gate
    ],
)
def test_waveform_power_invalid_params(
    waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate
):
    """Bad power/noise gates or a negative threshold must raise IndexError or ValueError."""
    accepted_errors = (IndexError, ValueError)
    call_args = (waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate)
    with pytest.raises(accepted_errors):
        waveform_power_test(*call_args)


@pytest.mark.parametrize(
    "waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate",
    [
        ("not an array", 1, 2, 0.3, 1, 2),  # waveform is a string
        (np.array([1, 4, 5, 2, 7, 8]), "start", 2, 0.3, 1, 2),  # power start gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, "end", 0.3, 1, 2),  # power end gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, "low", 1, 2),  # power threshold is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, 0.3, "noise_start", 2),  # noise start is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, 0.3, 1, "noise_end"),  # noise end is a string
    ],
)
def test_waveform_power_type_error(
    waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate
):
    """Arguments of the wrong type must make waveform_power_test raise (any exception)."""
    call_args = (waveform, power_startgate, power_endgate, power_lo, noise_startgate, noise_endgate)
    with pytest.raises(Exception):
        waveform_power_test(*call_args)


@pytest.mark.parametrize("waveform_name", list(waveform_params))
def test_variance(waveform_name):
    """Check waveform_variance_test against every (variance_hi, expected) case of one waveform.

    The whole waveform is used as the sample count, with the module-level power gates.
    """
    case = waveform_params[waveform_name]
    signal = case["waveform"]
    sample_count = len(signal)

    for threshold, wanted in case["variance_hi"]:
        got = waveform_variance_test(
            signal, POWER_STARTGATE, POWER_ENDGATE, sample_count, threshold
        )
        assert got == wanted, (
            f"Variance test failed for waveform: {signal} "
            f"with variance threshold: {threshold}. Expected: {wanted}, "
            f"Got: {got}"
        )


def test_waveform_variance_empty():
    """waveform_variance_test must reject an empty waveform with ValueError."""
    empty = np.array([])
    with pytest.raises(ValueError):
        waveform_variance_test(empty, 0, 1, 1, 0.5)


@pytest.mark.parametrize(
    "waveform, power_startgate, power_endgate, n_samples, variance_hi, expected_exception",
    [
        (np.array([1, 4, 5, 2, 7, 8]), -1, 2, 5, 0.5, IndexError),  # Invalid start gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 10, 5, 0.5, IndexError),  # Invalid end gate
        (np.array([1, 4, 5, 2, 7, 8]), 6, 2, 5, 0.5, IndexError),  # Start gate > end gate
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, -5, 0.5, ValueError),  # Negative number of samples
        (
            np.array([1, 4, 5, 2, 7, 8]),
            2,
            4,
            10,
            0.5,
            ValueError,
        ),  # Number of samples > waveform size
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, 0, 0.5, ValueError),  # Zero number of samples
        (np.array([1, 4, 5, 2, 7, 8]), 2, 4, 5, -0.5, ValueError),  # Negative variance threshold
    ],
)
def test_waveform_variance_invalid_params(
    waveform, power_startgate, power_endgate, n_samples, variance_hi, expected_exception
):
    """Each malformed argument must raise the exception type paired with it above."""
    call_args = (waveform, power_startgate, power_endgate, n_samples, variance_hi)
    with pytest.raises(expected_exception):
        waveform_variance_test(*call_args)


@pytest.mark.parametrize(
    "waveform, power_startgate, power_endgate, n_samples, variance_hi",
    [
        ("not an array", 0, 2, 5, 0.5),  # waveform is a string
        (np.array([1, 4, 5, 2, 7, 8]), "start", 2, 5, 0.5),  # start gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, "end", 5, 0.5),  # end gate is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, "samples", 0.5),  # sample count is a string
        (np.array([1, 4, 5, 2, 7, 8]), 1, 2, 5, "threshold"),  # threshold is a string
    ],
)
def test_waveform_variance_type_error(
    waveform, power_startgate, power_endgate, n_samples, variance_hi
):
    """Arguments of the wrong type must make waveform_variance_test raise (any exception)."""
    call_args = (waveform, power_startgate, power_endgate, n_samples, variance_hi)
    with pytest.raises(Exception):
        waveform_variance_test(*call_args)


@pytest.mark.parametrize("waveform_name", list(waveform_params))
def test_peakiness(waveform_name):
    """Check peakiness_test against every (lo, hi, expected) case of one reference waveform.

    The whole waveform is used as the sample count, with TRACKING_POINT as the tracking point.
    """
    case = waveform_params[waveform_name]
    signal = case["waveform"]
    sample_count = len(signal)

    for lower, upper, wanted in case["peakiness_lo_hi"]:
        got = peakiness_test(signal, sample_count, TRACKING_POINT, lower, upper)
        assert got == wanted, (
            f"Peakiness test failed for waveform: {signal} "
            f"with thresholds: {lower}, {upper}. Expected: {wanted}, "
            f"Got: {got}"
        )


def test_peakiness_empty_waveform():
    """peakiness_test must reject an empty waveform with ValueError."""
    empty = np.array([])
    with pytest.raises(ValueError):
        peakiness_test(empty, 1, 0, 0.5, 1.5)


@pytest.mark.parametrize(
    "waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi, expected_exception",
    [
        (np.array([1, 2, 3, 4, 5, 6]), -1, 0, 0.5, 1.5, ValueError),  # Negative number of samples
        (
            np.array([1, 2, 3, 4, 5, 6]),
            10,
            0,
            0.5,
            1.5,
            ValueError,
        ),  # Number of samples > waveform size
        (np.array([1, 2, 3, 4, 5, 6]), 6, 7, 0.5, 1.5, IndexError),  # Tracking point out of bounds
        (np.array([1, 2, 3, 4, 5, 6]), 6, -1, 0.5, 1.5, IndexError),  # Negative tracking point
        (
            np.array([1, 2, 3, 4, 5, 6]),
            6,
            0,
            -0.5,
            1.5,
            ValueError,
        ),  # Negative lower peakiness threshold
        (
            np.array([1, 2, 3, 4, 5, 6]),
            6,
            0,
            0.5,
            -1.5,
            ValueError,
        ),  # Negative upper peakiness threshold
    ],
)
def test_peakiness_invalid_params(
    waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi, expected_exception
):
    """Each malformed argument must raise the exception type paired with it above."""
    call_args = (waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi)
    with pytest.raises(expected_exception):
        peakiness_test(*call_args)


@pytest.mark.parametrize(
    "waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi",
    [
        ("not an array", 5, 0, 0.5, 1.5),  # waveform is not an array
        (np.array([1, 2, 3, 4, 5, 6]), "samples", 0, 0.5, 1.5),  # n_samples is not an int
        (np.array([1, 2, 3, 4, 5, 6]), 6, "tracking", 0.5, 1.5),  # trackingpoint is not an int
        (np.array([1, 2, 3, 4, 5, 6]), 6, 0, "low threshold", 1.5),  # peakiness_lo is not a float
        (np.array([1, 2, 3, 4, 5, 6]), 6, 0, 0.5, "high threshold"),  # peakiness_hi is not a float
    ],
)
def test_peakiness_type_error(waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi):
    """Arguments of the wrong type must make peakiness_test raise (any exception)."""
    call_args = (waveform, n_samples, trackingpoint, peakiness_lo, peakiness_hi)
    with pytest.raises(Exception):
        peakiness_test(*call_args)


@pytest.mark.parametrize("waveform_name", list(waveform_params))
def test_coherence(waveform_name):
    """Check coherence_test against every (coherence_lo, expected) case of one reference array."""
    case = waveform_params[waveform_name]
    values = case["coherence"]

    for threshold, wanted in case["coherence_lo"]:
        got = coherence_test(values, threshold)
        assert got == wanted, (
            f"Coherence test failed for coherence: {values} "
            f"with coherence threshold: {threshold}. Expected: {wanted}, Got: {got}"
        )


def test_coherence_empty_array():
    """coherence_test must reject an empty coherence array with ValueError."""
    empty = np.array([])
    with pytest.raises(ValueError):
        coherence_test(empty, 0.5)


@pytest.mark.parametrize(
    "coherence, coherence_lo, expected_exception",
    [
        (np.array([0.1, 0.3, 0.5, 0.7]), -0.5, ValueError),  # Negative coherence threshold
    ],
)
def test_coherence_invalid_params(coherence, coherence_lo, expected_exception):
    """A malformed threshold must raise the exception type paired with it above."""
    call_args = (coherence, coherence_lo)
    with pytest.raises(expected_exception):
        coherence_test(*call_args)


@pytest.mark.parametrize(
    "coherence, coherence_lo",
    [
        ("not an array", 0.5),  # coherence is not an array
        (np.array([0.1, 0.2, 0.3]), "low"),  # coherence_lo is not a float
    ],
)
def test_coherence_type_error(coherence, coherence_lo):
    """Arguments of the wrong type must make coherence_test raise (any exception)."""
    call_args = (coherence, coherence_lo)
    with pytest.raises(Exception):
        coherence_test(*call_args)


@pytest.mark.parametrize("qc_set", qc_params)
def test_waveform_qc(qc_set):
    """Run waveform_qc_test on a stack of waveforms and compare both of its outputs.

    ``qc_set`` supplies the thresholds, the waveform and coherence stacks, and the expected
    per-waveform pass flags (``expected_bool``) and QC codes (``expected_qc_results``).
    The same waveform stack is passed as both ``waveforms`` and ``waveforms_scaled``.
    """
    bool_result, qc_results = waveform_qc_test(
        noise_endgate=NOISE_ENDGATE,
        noise_hi=qc_set["noise_hi"],
        noise_startgate=NOISE_STARTGATE,
        peakiness_hi=qc_set["peakiness_hi"],
        peakiness_lo=qc_set["peakiness_lo"],
        power_endgate=POWER_ENDGATE,
        power_lo=qc_set["power_lo"],
        power_startgate=POWER_STARTGATE,
        trackingpoint=TRACKING_POINT,
        variance_hi=qc_set["variance_hi"],
        waveforms=qc_set["waveforms"],
        waveforms_scaled=qc_set["waveforms"],
        coherence=qc_set["coherence"],
        coherence_lo=qc_set["coherence_lo"],
    )

    # `(a == b).all()` broadcasts, so a wrong-length result (e.g. one flag instead of two)
    # could still "match". np.array_equal also requires the shapes to agree.
    assert np.array_equal(bool_result, qc_set["expected_bool"]), (
        f"Waveform QC test failed for parameters: {qc_set}. "
        f"Expected: {qc_set['expected_bool']}, Got: {bool_result}"
    )

    assert np.array_equal(qc_results, qc_set["expected_qc_results"]), (
        f"Waveform QC test results do not match for parameters: {qc_set}. "
        f"Expected QC Results: {qc_set['expected_qc_results']}, Got: {qc_results}"
    )


def test_waveform_qc_empty_waveforms():
    """An empty waveform stack must yield empty pass-flag and QC-code arrays.

    Arguments are passed by keyword (the same names test_waveform_qc uses), so the call does not
    depend on the positional order of waveform_qc_test's parameters. All thresholds and gates are
    zero because no waveform is ever tested.
    """
    result, qc_results = waveform_qc_test(
        noise_endgate=0,
        noise_hi=0,
        noise_startgate=0,
        peakiness_hi=0,
        peakiness_lo=0,
        power_endgate=0,
        power_lo=0,
        power_startgate=0,
        trackingpoint=0,
        variance_hi=0,
        waveforms=np.array([]),
        waveforms_scaled=np.array([]),
    )
    assert np.array_equal(result, np.array([], dtype=bool))
    assert np.array_equal(qc_results, np.array([], dtype=int))
