"""clev2er.utils.breakpoints.breakpoint_files.py

Functions to support writing of breakpoint files
"""

import logging
import os
from typing import Any, List, Union

import numpy as np
from netCDF4 import Dataset, Dimension  # pylint: disable=E0611

# pylint: disable=too-many-arguments
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals


def create_netcdf_file(file_path, data_dict):
    """Create a NetCDF4 file from the contents of a (possibly nested) dictionary.

    Each entry of ``data_dict`` is written according to its type:

    * nested ``dict`` -> a NetCDF group, populated recursively.
    * ``list`` / ``tuple`` / ``np.ndarray`` with one or more dimensions -> a
      variable, with one newly created dimension per axis. Empty arrays are skipped.
    * ``int`` / ``float`` / ``str``, numpy integer/floating scalars and 0-d numpy
      arrays -> an attribute of the enclosing file or group.

    Boolean values (Python ``bool``, ``np.bool_`` and boolean arrays) are stored
    as integers (True -> 1, False -> 0).

    Args:
        file_path (str): path of the NetCDF file to create (opened in "w" mode).
        data_dict (dict): dictionary containing 1 or more levels of data to write.

    Raises:
        ValueError: if a value has a type that cannot be written (e.g. ``None``).

    Returns:
        None
    """

    def preprocess_dict(
        data: Union[dict, list, tuple, np.ndarray, bool, Any],
    ) -> Union[dict, list, np.ndarray, Any]:
        """
        Recursively convert booleans in a dictionary, list or tuple to integers.

        This function traverses dictionaries, lists and tuples, converting any boolean
        values (Python ``bool``, ``np.bool_`` scalars or boolean numpy arrays) to
        integers: `True` becomes `1` and `False` becomes `0`. Tuples are returned as
        lists. Other data types are returned unchanged.

        Args:
            data (Union[dict, list, tuple, np.ndarray, bool, Any]): The input data which
            can be a dictionary, list, tuple, numpy array, or scalar value.

        Returns:
            Union[dict, list, np.ndarray, Any]: The processed data with all boolean values
            converted to integers. Dict/list structure is preserved.

        Example:
            >>> preprocess_dict({'a': True, 'b': [False, np.array([True, False])]})
            {'a': 1, 'b': [0, array([1, 0])]}

        """
        if isinstance(data, dict):
            return {key: preprocess_dict(value) for key, value in data.items()}
        if isinstance(data, (list, tuple)):
            return [preprocess_dict(item) for item in data]
        if isinstance(data, np.ndarray) and data.dtype == "bool":
            return data.astype(int)  # Convert boolean arrays to integers
        # np.bool_ is not a subclass of bool, so it must be listed explicitly
        if isinstance(data, (bool, np.bool_)):
            return int(data)  # Convert boolean scalars to integers (True -> 1, False -> 0)
        return data

    def create_variables(
        ncfile: Dataset,
        parent_key: str,
        data: dict[str, Any],
        dim_sizes: List[int],
        dim_names: List[str],
        dims: List[Dimension],
    ) -> None:
        """
        Recursively create NetCDF groups, variables and attributes from a dictionary.

        * dict values become groups (processed recursively).
        * lists/tuples are converted to numpy arrays; arrays of any rank >= 1 become
          variables with one new dimension per axis (empty arrays are skipped, boolean
          arrays are stored as integers).
        * 0-d arrays are unwrapped to their single value and handled as scalars.
        * int/float/str and numpy integer/floating scalars become attributes.

        Args:
            ncfile (netCDF4.Dataset): The open NetCDF file (or group) where variables
            and attributes will be written.
            parent_key (str): Prefix joined with "_" to each key when non-empty.
            data (Dict[str, Any]): A dictionary containing the data. The dictionary can
            be nested, and its keys are used as variable, attribute or group names.
            dim_sizes (List[int]): Sizes of all dimensions created so far (shared across
            groups).
            dim_names (List[str]): Names of all dimensions created so far (shared across
            groups); its length provides the next "dim<N>" index.
            dims (List[netCDF4.Dimension]): Dimensions created in this group.

        Raises:
            ValueError: If a value has an unsupported type (anything other than the
            types listed above, e.g. None).

        Returns:
            None
        """
        for key, value in data.items():
            # Include parent_key for nested levels
            current_key = f"{parent_key}_{key}" if parent_key else key

            # Convert lists/tuples to numpy arrays
            if isinstance(value, (list, tuple)):
                value = np.array(value)

            # A 0-d array holds a single value; previously it matched no array rank
            # and was silently dropped, so unwrap it and store it as a scalar.
            if isinstance(value, np.ndarray) and value.ndim == 0:
                value = value.item()

            # NetCDF has no boolean type: store boolean scalars as integers
            if isinstance(value, (bool, np.bool_)):
                value = int(value)

            # Recursively handle nested dictionaries (subgroups)
            if isinstance(value, dict):
                subgroup = ncfile.createGroup(current_key)
                subgroup_dims: List[Dimension] = []
                create_variables(subgroup, "", value, dim_sizes, dim_names, subgroup_dims)

            elif isinstance(value, np.ndarray):
                # Skip empty arrays
                if value.size == 0:
                    continue

                # Handle boolean arrays by converting them to integers
                if value.dtype == "bool":
                    value = value.astype(int)

                # One dimension per axis, named dim<N> with N taken from the running
                # count of dimensions created so far (works for any rank >= 1).
                first_index = len(dim_names)
                axis_names = [f"dim{first_index + axis}" for axis in range(value.ndim)]
                for axis_name, axis_size in zip(axis_names, value.shape):
                    if axis_name not in ncfile.dimensions:
                        ncfile.createDimension(axis_name, axis_size)
                        dim_names.append(axis_name)
                        dim_sizes.append(axis_size)
                        dims.append(ncfile.dimensions[axis_name])

                axis_dims = tuple(ncfile.dimensions[name] for name in axis_names)
                var = ncfile.createVariable(
                    current_key, value.dtype.str, dimensions=axis_dims
                )
                var[:] = value

            # Handle scalar values (Python and numpy ints and floats)
            elif isinstance(value, (int, float, np.integer, np.floating)):
                ncfile.setncattr(current_key, value)

            # Handle strings
            elif isinstance(value, str):
                ncfile.setncattr(current_key, value)

            # Unsupported data types
            else:
                raise ValueError(
                    f"Unsupported data type for variable '{current_key}': {type(value)}"
                )

    with Dataset(file_path, "w") as ncfile:
        dim_sizes: List[int] = []
        dim_names: List[str] = []
        dims: List[Dimension] = []

        preproc_dict = preprocess_dict(data_dict)
        create_variables(ncfile, "", preproc_dict, dim_sizes, dim_names, dims)


def write_breakpoint_file(
    config: dict, shared_dict: dict, log: logging.Logger, breakpoint_alg_name: str
) -> str:
    """Write a NetCDF breakpoint file containing the contents of the shared dictionary.

    The file is written to ``config["breakpoint_files"]["default_dir"]`` when that
    setting is present and non-empty, otherwise to ``/tmp``. The directory is
    created if it does not already exist.

    The file name is ``<l1b file stem>_bkp_<breakpoint_alg_name>.nc`` when
    ``shared_dict`` contains ``"l1b_file_name"``, otherwise
    ``breakpoint_<breakpoint_alg_name>.nc``.

    Args:
        config (dict): chain config file
        shared_dict (dict): shared working dictionary
        log (logging.Logger): current logger instance to use
        breakpoint_alg_name (str): name of the algorithm after which the bp is set
    Returns:
        (str) : path of breakpoint file
    """

    # form breakpoint dir path, tolerating a missing/empty "breakpoint_files"
    # section or "default_dir" entry (previously a KeyError/TypeError)
    breakpoint_dir = "/tmp"
    breakpoint_config = config.get("breakpoint_files") or {}
    if breakpoint_config.get("default_dir"):
        breakpoint_dir = breakpoint_config["default_dir"]

    # the NetCDF file cannot be created inside a directory that does not exist
    os.makedirs(breakpoint_dir, exist_ok=True)

    if "l1b_file_name" in shared_dict:
        l1b_stem = os.path.splitext(os.path.basename(shared_dict["l1b_file_name"]))[0]
        basename = f"{l1b_stem}_bkp_{breakpoint_alg_name}.nc"
    else:
        basename = f"breakpoint_{breakpoint_alg_name}.nc"
    filename = os.path.join(breakpoint_dir, basename)

    log.info("breakpoint file: %s", filename)
    create_netcdf_file(filename, shared_dict)
    return filename
