"""
Helper functions for updating CLEV2ER LIG shared_dict for CRISTAL
"""

import numpy as np
from numpy.typing import NDArray


def update_shared_dict_validity_mask(shared_dict: dict, band: str, mask: NDArray[np.bool_]) -> None:
    """Set ku,ka/valid/bool_mask_to_orig to False where the input mask is False.

    Also recomputes ku,ka/valid/indices_to_orig (positions that remain True in
    bool_mask_to_orig) and ku,ka/valid/num_valid (the count of those positions).
    Validity can only be removed, never restored: elements already False in
    bool_mask_to_orig stay False regardless of ``mask``.

    Args:
        shared_dict (dict): chain shared dictionary; ``shared_dict[band]["valid"]``
            must already contain a ``bool_mask_to_orig`` array, which is modified
            in place.
        band (str): band which is either 'ku' or 'ka'
        mask (np.ndarray[bool]): array indicating validity of the measurements.
            Any array-like is accepted and interpreted as boolean (non-zero ->
            valid); it must have the same number of elements as
            bool_mask_to_orig.

    Raises:
        ValueError: if ``mask`` and bool_mask_to_orig differ in size.
    """
    valid = shared_dict[band]["valid"]
    bool_mask = valid["bool_mask_to_orig"]

    # Coerce to a real boolean array: `~` on an integer array is a bitwise NOT
    # (e.g. ~1 == -2), which would silently turn into fancy indexing below.
    mask = np.asarray(mask, dtype=bool)

    if bool_mask.size != mask.size:
        raise ValueError(
            f"{band}/valid/bool_mask_to_orig "
            f'({shared_dict[band]["valid"]["bool_mask_to_orig"].size})'
            f" has different size to mask ({mask.size})"
        )

    # Only the sizes were checked, so align the shape for boolean indexing.
    mask = mask.reshape(bool_mask.shape)

    # In-place update so any other references to this array see the change.
    bool_mask[~mask] = False

    valid["indices_to_orig"] = np.nonzero(bool_mask)[0]
    valid["num_valid"] = valid["indices_to_orig"].size
