# -*- coding: utf-8 -*-
"""
Created on Mon Jul  7 09:54:33 2025

@author: Charlotte.Kastoun
"""
#~~~~~~~~~~LIBRARIES~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import urllib.request
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings('ignore')

import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D

#~~~~~~~~~~EXPLANATIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# download_data: downloads the chlorophyll (chl) data from ERDDAP using ERDDAP-generated URLs.
#                The data are split into two time chunks (roughly 5 years or less each)
#                because downloading everything in one request crashed.

# create_ds: actually loads in the netcdf data into python as an xarray DataSet, 
#            then concatenates the DataSet for each chunk of 5 years into 1 ds.   

# examine_data: prints a quick summary of the Dataset you just loaded and created

# create_basicmap: creates a map of the chl data at a given timestep (preliminary visualization).
#                  adapted from ERDDAP tutorials on github 

# find_pixels: finds the nearest pixel given the lat and lon of a station. Then finds a 3x3
#              box around that pixel

# station_mean: calculates the spatial mean chl for the 3x3 box of pixels at that station for each timestep

# winter_weeks: isolates the data only for the first 15 weeks of the year

# plot_timeseries: plots a basic timeseries of chlorophyll with time (week and year) on x axis and chl on y

# plot_timeseries_weekly: creates a timeseries with week of year on x and chl on y. 
#                         each line is a diff color and represents each year.  

#~~~~~~~~~~FUNCTIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def download_data():
    """Fetch the ERDDAP chlorophyll subsets and save them as local NetCDF files.

    Two requests are made, newest chunk first:
      * 2020-01-07 .. 2025-05-27 -> chlorophyll20-25.nc
      * 2016-04-28 .. 2019-12-31 -> chlorophyll16-20.nc
    Both request latitude 40 -> 37 and longitude -77 -> -76. Files are written
    relative to the current working directory. The range is split into chunks
    because a single large request crashed.
    """
    chunks = (
        # 2020 - 2025
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnOLCImultisensorCHLeastcoast7Day.nc?chlora%5B(2020-01-07):1:(2025-05-27T00:00:00Z)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "chlorophyll20-25.nc"),
        # 2016 - 2019
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnOLCImultisensorCHLeastcoast7Day.nc?chlora%5B(2016-04-28):1:(2019-12-31)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "chlorophyll16-20.nc"),
    )
    for source_url, target_file in chunks:
        urllib.request.urlretrieve(source_url, target_file)

def create_ds():
    """Open both downloaded NetCDF chunks and join them along ``time``.

    Returns
    -------
    xarray.Dataset
        The 2016-2019 chunk followed by the 2020-2025 chunk, concatenated on
        the ``time`` dimension. CF decoding is enabled when opening.
    """
    newer_chunk = xr.open_dataset('chlorophyll20-25.nc', decode_cf=True)
    older_chunk = xr.open_dataset('chlorophyll16-20.nc', decode_cf=True)

    # older chunk first so the combined time axis runs from 2016 into 2020+
    return xr.concat([older_chunk, newer_chunk], dim="time")
    
def examine_data(dataset):
    """Print a quick structural summary of the chlorophyll Dataset.

    Shows the coordinate names, the data-variable names, the shape of
    ``chlora``, the full time axis, and the first and last latitude values.
    Comparing the first and last latitude tells whether the coordinate is
    ascending or descending, which matters when slicing the Dataset later.
    """
    coord_names = list(dataset.coords)
    variable_names = list(dataset.data_vars)
    print('The coordinates variables:', coord_names, '\n')
    print('The data variables:', variable_names, '\n')

    print('Shape/Structure of chlora variable:')
    print(dataset.chlora.shape)

    print('\n Dates for each time step: ')
    print(dataset.time)

    lats = dataset.latitude
    for label, position in (('First latitude value', 0), ('Last latitude value', -1)):
        print(label, lats[position].item())
    
def create_basicmap(data, timestep):
    """Draw a filled-contour map of chlorophyll at one time step, with stations marked.

    Adapted from the ERDDAP tutorials on GitHub. Prints the min/max chlorophyll
    of the chosen time step, builds a reversed ``plasma`` colormap with one
    colour per 0.05-wide contour level, overlays the sampling stations as
    labelled black dots, and shows the figure.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset whose ``chlora`` variable is indexed (time, latitude, longitude)
        with matching ``time``/``latitude``/``longitude`` coordinates.
    timestep : int
        Position along the time axis to plot.

    Raises
    ------
    ValueError
        If the chosen time step contains no finite chlorophyll values.
    """
    frame = data.chlora[timestep]  # 2-D (latitude, longitude) slice

    min_chlora = np.nanmin(frame)
    max_chlora = np.nanmax(frame)

    print("minimum chlorophyll value: ", min_chlora)
    print("maximum chlorophyll value: ", max_chlora)

    # an all-NaN frame (e.g. full cloud cover) gives NaN min/max, which would
    # otherwise fail inside np.arange with an unhelpful message
    if not (np.isfinite(min_chlora) and np.isfinite(max_chlora)):
        raise ValueError(f'no valid chlorophyll values at timestep {timestep}')

    # contour levels every 0.05 from floor(min) up to floor(max) + 1,
    # with one colour of the reversed plasma map per level
    levs = np.arange(np.floor(min_chlora), np.floor(max_chlora) + 1, 0.05)
    colors = plt.cm.plasma(np.linspace(0, 1, len(levs)))[::-1]
    cm = LinearSegmentedColormap.from_list("chlorophyll_cmap", colors, N=len(levs))

    plt.contourf(data.longitude, data.latitude, frame, levs, cmap=cm)
    plt.colorbar()

    # sampling stations: name -> (longitude, latitude)
    stations = {
        'ET5.2': (-76.0587, 38.5807),
        'LE2.2': (-76.598, 38.1576),
        'WT5.1': (-76.52254, 39.21309),
        'WE4.2': (-76.38634, 37.24181),
        'RET4.3': (-76.78889, 37.50869),
    }
    for name, (lon, lat) in stations.items():
        plt.scatter(lon, lat, c='black')
        plt.annotate(name, (lon, lat), xytext=(3, 3), textcoords='offset points', fontsize=8)

    # title reports the actual time step shown; include the day because the
    # data are 7-day composites and several fall within one month
    date_label = data.time[timestep].dt.strftime('%d %b %Y').item()
    plt.title(f"Chlorophyll-a at timestep {timestep} ({date_label})")
    plt.show()

def find_pixels(station, lat, lon, data, half_width=1):
    """Select the box of pixels centred on the grid cell nearest a station.

    The nearest grid cell is the one whose latitude/longitude has the smallest
    absolute difference from the station's coordinates. The box extends
    ``half_width`` pixels on every side of that cell. The default gives a 3x3 box.

    Parameters
    ----------
    station : str
        Station label, used in the progress and error messages.
    lat, lon : float
        Station latitude and longitude.
    data : xarray.Dataset
        Gridded dataset with ``latitude`` and ``longitude`` coordinates.
    half_width : int, optional
        Pixels on each side of the nearest pixel (default 1 -> 3x3 box).

    Returns
    -------
    xarray.Dataset
        ``data`` restricted to the (2*half_width+1) x (2*half_width+1) box.

    Raises
    ------
    ValueError
        If ``half_width`` is negative, or if the box would extend past the
        edge of the grid.
    """
    if half_width < 0:
        raise ValueError(f'half_width must be >= 0, got {half_width}')

    lats = data.latitude.values
    lons = data.longitude.values

    # index of the nearest existing latitude / longitude in the dataset
    lat_index = int(np.abs(lats - lat).argmin())
    lon_index = int(np.abs(lons - lon).argmin())

    if lat_index < half_width or lon_index < half_width:
        raise ValueError(f'{station}: lat or lon out of lower bounds')
    if lat_index + half_width > len(lats) - 1 or lon_index + half_width > len(lons) - 1:
        raise ValueError(f'{station}: lat or lon out of upper bounds')

    # select by position: exactly the neighbouring pixels, without
    # re-matching floating-point coordinate values
    station_data = data.isel(
        latitude=slice(lat_index - half_width, lat_index + half_width + 1),
        longitude=slice(lon_index - half_width, lon_index + half_width + 1),
    )
    print(station + ' data successfully selected')
    return station_data

def station_mean(station, data):
    """Average ``data`` over its latitude and longitude dimensions.

    NaN pixels are ignored (``skipna=True``), so a partly covered box still
    yields a mean from its remaining pixels. Any other dimensions (e.g. time)
    are kept.
    """
    box_average = data.mean(dim=['latitude', 'longitude'], skipna=True)
    print(station + ' spatial mean successfully found')
    return box_average

def winter_weeks(station, data, n_weeks=15):
    """Keep only the time steps in the first ``n_weeks`` weeks of the calendar year.

    Weeks are counted from 1 January: days 1-7 are week 1, days 8-14 week 2,
    and so on. A time step is kept when its day of year is <= 7 * n_weeks.

    Parameters
    ----------
    station : str
        Station label, used in the progress message.
    data : xarray.Dataset
        Dataset with a datetime ``time`` coordinate.
    n_weeks : int, optional
        Number of leading weeks to keep (default 15).

    Returns
    -------
    xarray.Dataset
        ``data`` restricted to the selected time steps.
    """
    # Day-of-year instead of the ISO week (.dt.week, deprecated in xarray):
    # ISO numbering puts the last days of December into week 1 (wrongly kept)
    # and the first days of January into week 52/53 (wrongly dropped).
    day_of_year = data.time.dt.dayofyear.values
    in_window = day_of_year <= 7 * n_weeks
    winter_data = data.isel(time=in_window)
    print(station + ' winter weeks successfully isolated')
    return winter_data

def plot_timeseries(station, data):
    """Plot one station's spatially averaged chlorophyll against date.

    Parameters
    ----------
    station : str
        Station label used at the start of the title.
    data : xarray.Dataset
        Dataset with a ``time`` coordinate and a ``chlora`` variable, such as
        the per-station means computed by ``station_mean``.

    Returns
    -------
    list
        The line objects returned by ``Axes.plot``.
    """
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
    lines = ax.plot(data.time, data.chlora)

    ax.set_ylabel('Chlorophyll [units?]')  # TODO: confirm units from the ERDDAP metadata
    ax.set_xlabel('Date')

    # leading space separates the station name from the rest of the title
    ax.set_title(station + ' Weekly Mean Chlorophyll Timeseries')

    plt.show()

    return lines
   
def plot_timeseries_weekly(station, data, season):
    """Overlay one chlorophyll line per year, with week of year on the x axis.

    Years are coloured along a reversed ``plasma`` colormap from the earliest
    to the latest year, and a legend lists the years.

    Parameters
    ----------
    station : str
        Station label used in the title.
    data : xarray.Dataset
        Dataset with a datetime ``time`` coordinate and a ``chlora`` variable
        holding one value per time step.
    season : str
        Label inserted into the title (e.g. 'allyear' or 'winter').
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    unique_years = np.unique(data.time.dt.year.values)

    # One colour per year, indexed by position. This avoids dividing by
    # (max_year - min_year), which is zero when only one year is present.
    colors = plt.cm.plasma(np.linspace(0, 1, len(unique_years)))[::-1]

    legend_elements = []
    for color, year in zip(colors, unique_years):
        year_data = data.sel(time=str(year))

        # Calendar week counted from 1 January. The ISO week (.dt.week,
        # deprecated in xarray) puts late-December dates at week 1 and
        # early-January dates at week 52/53, making each line jump across the plot.
        weeks = (year_data.time.dt.dayofyear.values - 1) // 7 + 1

        ax.plot(weeks, year_data.chlora.values.flatten(), color=color, alpha=0.7)
        legend_elements.append(Line2D([0], [0], color=color, lw=4, label=f'{year}'))

    ax.set_xlabel('Week')
    ax.set_ylabel('Chlorophyll [units?]')  # TODO: confirm units from the ERDDAP metadata
    ax.set_title('Weekly ' + str(season) + ' average chlorophyll by year at ' + str(station))  # later add tributary name and state

    ax.legend(handles=legend_elements, title='Years')

    plt.show()

#~~~~~~~~~~CODE~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if __name__ == "__main__":
    # Only runs when executed as a script (including Spyder's "Run file"),
    # so importing this module does not load data or open plots.
    sns.set()
    plt.rcParams['font.family'] = 'serif'

    # download_data()  # only needed once; afterwards the local .nc files are reused
    ds_chlora = create_ds()
    # drop any repeated time stamps (e.g. if the download chunks overlap), keeping the first
    ds_chlora = ds_chlora.drop_duplicates(dim='time', keep='first')

    examine_data(ds_chlora)
    # result: latitudes are in descending order (largest first)

    # create_basicmap(ds_chlora, 0)

    # sampling stations: name -> (latitude, longitude)
    STATIONS = {
        'ET5.2': (38.5807, -76.0587),
        'LE2.2': (38.1576, -76.598),
        'WT5.1': (39.21309, -76.52254),
        'WE4.2': (37.24181, -76.38634),
        'RET4.3': (37.50869, -76.78889),
    }

    # per-station results, keyed by station name (e.g. station_means['ET5.2'])
    station_boxes = {name: find_pixels(name, lat, lon, ds_chlora)
                     for name, (lat, lon) in STATIONS.items()}
    print('\n')

    station_means = {name: station_mean(name, box) for name, box in station_boxes.items()}
    print('\n')

    station_means_winter = {name: winter_weeks(name, means) for name, means in station_means.items()}

    for name, means in station_means.items():
        plot_timeseries(name, means)

    for name, means in station_means.items():
        plot_timeseries_weekly(name, means, 'allyear')

    for name, means in station_means_winter.items():
        plot_timeseries_weekly(name, means, 'winter')

