#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Create a dilated Antarctic mask
    - bedmachine v2 grounded + floating 500m grid mask
    - dilated by 10km
    - saved as TBD format
"""

import argparse
import datetime
import sys

import numpy as np
from netCDF4 import Dataset  # pylint: disable=E0611
from scipy import ndimage

from clev2er.utils.masks.masks import Mask

# pylint: disable=too-many-statements
# pylint: disable=too-many-locals


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command line arguments for the tool.

    Returns:
        argparse.Namespace: ``antarctica`` / ``greenland`` (1 or None) and
        ``distance`` (dilation distance in km, an int >= 0).

    Exits via ``sys.exit`` (or argparse's own error handling) when no region
    is selected, when both regions are selected, or when the distance is
    negative.
    """
    parser = argparse.ArgumentParser(
        description="Create a dilated Antarctic or Greenland BedMachine grid mask (NetCDF4)"
    )

    # -a and -g select different source masks, output files and projections;
    # accepting both at once produced an Antarctic mask in the Greenland file,
    # so they are mutually exclusive.
    region = parser.add_mutually_exclusive_group()
    region.add_argument(
        "--antarctica",
        "-a",
        help="create the Antarctic mask (BedMachine v2)",
        required=False,
        action="store_const",
        const=1,
    )
    region.add_argument(
        "--greenland",
        "-g",
        help="create the Greenland mask (BedMachine v3)",
        required=False,
        action="store_const",
        const=1,
    )
    parser.add_argument(
        "--distance",
        "-d",
        help="dilation distance in km",
        required=False,
        default=10,
        type=int,
    )

    args = parser.parse_args()

    if not args.antarctica and not args.greenland:
        sys.exit("Must have either -a or -g")
    if args.distance < 0:
        sys.exit("Dilation distance must be >= 0 km")

    return args


def _grid_coordinates(start: float, num: int, step: float) -> np.ndarray:
    """Return exactly ``num`` evenly spaced coordinates starting at ``start``.

    ``np.arange(start, start + num * step, step)`` can return ``num + 1``
    values when the floating-point stop value rounds up, which would not fit
    a dimension of length ``num``; scaling integer indices avoids that.
    """
    return start + np.arange(num) * step


def _dilate_mask(mask_data: np.ndarray, distance_m: float, binsize_m: float) -> np.ndarray:
    """Dilate a 2-D binary mask by ``distance_m`` metres on a grid of ``binsize_m`` cells.

    The distance is rounded to a whole number of cells ``n``. The structuring
    element is a 4-connected cross iterated ``n`` times, i.e. a diamond
    (city-block distance <= n), so the reach along grid diagonals is shorter
    than ``distance_m``.

    Returns:
        np.ndarray: boolean array with the same shape as ``mask_data``. If the
        distance rounds to zero cells the mask is returned undilated;
        ``ndimage.iterate_structure`` with fewer than 2 iterations returns the
        plain cross, which would otherwise still dilate by one cell.
    """
    n_cells = round(distance_m / binsize_m)
    if n_cells < 1:
        return np.asarray(mask_data).astype(bool)

    structure = ndimage.iterate_structure(
        ndimage.generate_binary_structure(2, 1),  # 2-D, 4-connectivity cross
        n_cells,
    ).astype(np.bool_)
    return ndimage.binary_dilation(mask_data, structure=structure)


def main() -> None:
    """Create a dilated BedMachine grid mask and write it as a NetCDF4 file.

    Reads the Antarctic (BedMachine v2, ``-a``) or Greenland (BedMachine v3,
    ``-g``) grid mask. Cells whose surface type is ice-free land, grounded ice
    or floating ice (plus Lake Vostok for Antarctica) become 1, all others 0.
    That mask is dilated by ``--distance`` km and written to
    ``ant_dilated_grid_mask.nc`` or ``grn_dilated_grid_mask.nc`` in the
    current working directory.
    """
    args = _parse_args()
    is_antarctica = bool(args.antarctica)

    # BedMachine mask values:
    #   Antarctica: ocean(0) ice_free_land(1) grounded_ice(2) floating_ice(3)
    #               Lake Vostok(4)
    #   Greenland:  ocean(0) ice_free_land(1) grounded_ice(2) floating_ice(3)
    #               non-Greenland land(4)
    if is_antarctica:
        thismask = Mask("antarctica_bedmachine_v2_grid_mask")
        output_file_name = "ant_dilated_grid_mask.nc"
        included_values = (1, 2, 3, 4)
        projection = "epsg:3031"
        description = (
            "Dilated Antarctic Polar Stereographic Grid Mask of Grounded"
            " and Floating Ice in EPSG:3031"
        )
        source_version = "Bedmachine v2"
    else:
        thismask = Mask("greenland_bedmachine_v3_grid_mask")
        output_file_name = "grn_dilated_grid_mask.nc"
        included_values = (1, 2, 3)
        projection = "epsg:3413"
        description = (
            "Dilated Greenland Polar Stereographic Grid Mask of Grounded"
            " and Floating Ice in EPSG:3413"
        )
        source_version = "Bedmachine v3"

    # Grid dimensions and metadata
    num_x = thismask.num_x
    num_y = thismask.num_y
    min_x = thismask.minxm  # meters
    min_y = thismask.minym  # meters
    binsize = thismask.binsize  # meters

    print("setting mask values..")
    mask_grid = thismask.mask_grid
    # The output grid is indexed (x, y); transpose the source grid if its
    # shape does not already match.
    if mask_grid.shape != (num_x, num_y):
        print("Transposing..")
        mask_grid = np.transpose(mask_grid)
    if mask_grid.shape != (num_x, num_y):
        sys.exit(
            f"mask grid shape {thismask.mask_grid.shape} does not match grid "
            f"dimensions ({num_x}, {num_y})"
        )

    mask_data = np.isin(mask_grid, included_values).astype("i1")

    print("dilating..")
    dilated_mask = _dilate_mask(mask_data, args.distance * 1000, binsize)
    print("done..")

    # The mask is computed before the file is opened, so a failure above leaves
    # no partial file. The context manager closes the file even if a write fails.
    with Dataset(output_file_name, "w", format="NETCDF4") as ds:
        ds.createDimension("x", num_x)
        ds.createDimension("y", num_y)

        x = ds.createVariable("x", "f4", ("x",))
        x.units = "m"
        x.standard_name = "distance"
        x.long_name = "grid locations in x direction"

        y = ds.createVariable("y", "f4", ("y",))
        y.units = "m"
        y.standard_name = "distance"
        y.long_name = "grid locations in y direction"

        # 'i1' (8-bit integer) keeps the compressed boolean grid small
        mask = ds.createVariable(
            "mask",
            "i1",
            ("x", "y"),
            zlib=True,
            complevel=5,
            shuffle=True,
        )
        mask.units = "none"
        mask.description = "grid of boolean values to indicate inside mask (1) or outside mask(0)"

        x[:] = _grid_coordinates(min_x, num_x, binsize)
        y[:] = _grid_coordinates(min_y, num_y, binsize)
        mask[:] = dilated_mask.astype("i1")

        # Global attributes
        current_datetime = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        ds.description = description
        ds.history = "Created " + current_datetime
        ds.projection = projection
        ds.min_x = min_x
        ds.min_y = min_y
        ds.posting = binsize
        ds.dilation_distance_km = args.distance
        ds.source = (
            f"{source_version} surface type: grounded, floating ice, ice-free land,"
            f" dilated by {args.distance}km"
        )
        ds.processing_centre = "MSSL-UCL"

    print("NetCDF file created successfully.")


if __name__ == "__main__":
    main()
