# -*- coding: utf-8 -*-
"""
Created on Mon Jul  7 11:11:39 2025

@author: Charlotte.Kastoun
"""

#~~~~~~~~~~LIBRARIES~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import urllib.request
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats

import warnings
warnings.filterwarnings('ignore')

import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D


#~~~~~~~~~~EXPLANATIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# download_data: downloads the SST from ERDDAP using an ERDDAP-generated URL.
#                I had to download it in chunks of 5 years otherwise it would crash.

# download_clim: downloads the sst climatology from ERDDAP using an ERDDAP-generated URL

# create_ds: actually loads in the netcdf data into python as an xarray DataSet, 
#            then concatenates the DataSet for each chunk of 5 years into 1 ds.   

# create_clim_ds: same as create_ds but for the climatology data

# examine_data: prints the coordinates, data variables, shape of the SST variable, and the first/last latitude of a DataSet

# create_basicmap: creates a map of the sst data at a given timestep (preliminary visualization).
#                  adapted from ERDDAP tutorials on github 

# find_pixels: finds the nearest pixel given the lat and lon of a station. Then finds a 3x3
#              box around that pixel

# calc_station_mean: calculates the spatial mean sst for the 3x3 box of pixels at that station for each timestep

# isolate_winter_weeks: isolates the data for weeks 1-14 of the year (week number < 15)

# isolate_winter_weeks_clim: same as isolate_winter_weeks but for the climatology sst data

# find_anom: finds the anomaly for each timestep comparing the sst data and the sst climatology data

# plot_timeseries: plots a timeseries with week and year on x axis, SST on y axis. 
#                  Includes a best fit line for the timeseries. Prints out basic stats using scipy.
#                  Can use this to also plot a timeseries of the anomaly rather than the raw data.

# plot_timeseries_weekly: creates a timeseries with week of year on x and sst on y. 
#                         each line is a diff color and represents each year.  

# plot_winter_anom_by_year: plots the anomaly with one subplot per year (2016-2023) and saves the figure as a PNG

#~~~~~~~~~~FUNCTIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def download_data():
    """Download the SST data from ERDDAP into four local NetCDF files.

    Each ERDDAP-generated URL requests the ``sst`` variable for one time
    chunk of roughly five years (a single request for the whole record would
    crash). All four URLs share the same trailing index ranges; only the time
    range differs, and consecutive chunks share their boundary timestamp.

    The files are written to the current working directory as
    ``sst06-10.nc``, ``sst10-15.nc``, ``sst15-20.nc`` and ``sst20-25.nc``.
    """
    # (local file name, ERDDAP-generated request URL), oldest chunk first
    chunk_requests = (
        # 2006-2010
        ("sst06-10.nc",
         'https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2006-12-06T14:00:00Z):1:(2010-01-04T01:45:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D'),
        # 2010-2015
        ("sst10-15.nc",
         'https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2010-01-04T01:45:00Z):1:(2015-11-29T04:00:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D'),
        # 2015-2020
        ("sst15-20.nc",
         'https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2015-11-29T04:00:00Z):1:(2020-04-04T04:20:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D'),
        # 2020-2025
        ("sst20-25.nc",
         'https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2020-04-04T04:20:00Z):1:(2025-05-31T04:09:59Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D'),
    )

    for local_name, request_url in chunk_requests:
        # actually downloads the data
        urllib.request.urlretrieve(request_url, local_name)

def download_clim():
    """Download the SST climatology from ERDDAP into ``sst-clim.nc``.

    The ERDDAP-generated URL requests ``sst`` for the first index range
    (1.0)-(52.0); the file is written to the current working directory.
    """
    urllib.request.urlretrieve(
        'https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultiSSTeastcoast7DayClimatol.nc?sst%5B(1.0):1:(52.0)%5D%5B(0.0):1:(0.0)%5D%5B(40.002202068418484):1:(36.99809976542181)%5D%5B(-77):1:(-76)%5D',
        'sst-clim.nc',
    )
   
def create_ds():
    """Open the four downloaded SST NetCDF chunks and join them along ``time``.

    Returns
    -------
    The concatenation, in chronological file order, of ``sst06-10.nc``,
    ``sst10-15.nc``, ``sst15-20.nc`` and ``sst20-25.nc`` (read from the
    current working directory with CF decoding enabled).
    """
    chunk_files = ('sst06-10.nc', 'sst10-15.nc', 'sst15-20.nc', 'sst20-25.nc')

    # loads each netcdf chunk into python as its own dataset
    chunk_datasets = [xr.open_dataset(path, decode_cf=True) for path in chunk_files]

    return xr.concat(chunk_datasets, dim="time")

def create_clim_ds():
    """Open the downloaded climatology file ``sst-clim.nc`` with CF decoding and return it."""
    return xr.open_dataset('sst-clim.nc', decode_cf=True)

def examine_data(dataset):
    """Print a short structural summary of a dataset.

    Prints the coordinate names, the data-variable names, the shape of
    ``dataset.sst`` and the first and last values of ``dataset.latitude``.
    Returns nothing.
    """
    coord_names = list(dataset.coords)
    print('The coordinates variables:', coord_names, '\n')

    variable_names = list(dataset.data_vars)
    print('The data variables:', variable_names, '\n')

    print('Shape/Structure of SST variable:')
    print(dataset.sst.shape)
    # print('\n Dates for each time step: ')
    # print(dataset.time)

    # The first and last latitude reveal whether the coordinate is stored in
    # ascending or descending order, which matters when slicing (subsetting)
    # the dataset later.
    for label, position in (('First latitude value', 0), ('Last latitude value', -1)):
        print(label, dataset.latitude[position].item())
    
def create_basicmap(data, timestep):
    """Draw a filled-contour map of SST at one timestep with the stations marked.

    Preliminary visualization, adapted from the ERDDAP tutorials on GitHub.

    Parameters
    ----------
    data : dataset exposing ``sst``, ``longitude``, ``latitude`` and ``time``
        ``sst`` is indexed as ``[timestep, 0, :, :]``, i.e. the first axis is
        the timestep and index 0 is taken along the second axis.
    timestep : int
        Position along the first axis of ``sst`` (and along ``time``) to plot.

    The minimum and maximum SST of the selected timestep are printed, and the
    figure is displayed with ``plt.show()``.
    """
    sst_at_step = data.sst[timestep]
    min_sst = np.nanmin(sst_at_step)
    max_sst = np.nanmax(sst_at_step)

    print("minimum sst value: ", min_sst)
    print("maximum sst value: ", max_sst)

    # custom colormap based on the minimum and maximum SST value:
    # one colour per contour level, levels every 0.05 starting at floor(min)
    levs = np.arange(np.floor(min_sst), np.floor(max_sst)+1, 0.05)
    num_colors = len(levs)

    colors = plt.cm.plasma(np.linspace(0, 1, num_colors))[::-1]
    cm = LinearSegmentedColormap.from_list("sst_cmap", colors, N=num_colors)

    # plot the SST map
    plt.contourf(data.longitude,
                 data.latitude,
                 data.sst[timestep, 0, :, :],
                 levs,
                 cmap=cm)

    # Plot the colorbar
    plt.colorbar()

    # Annotation: add a point for each station, as (longitude, latitude)
    station_positions = {
        'ET5.2': (-76.0587, 38.5807),
        'LE2.2': (-76.598, 38.1576),
        'WT5.1': (-76.52254, 39.21309),
        'WE4.2': (-76.38634, 37.24181),
        'RET4.3': (-76.78889, 37.50869),
    }
    for station_lon, station_lat in station_positions.values():
        plt.scatter(station_lon, station_lat, c='black')

    # Annotation: example of how to add a contour line (did not help here)
    # plt.contour(data.longitude,
    #             data.latitude,
    #             data.sst[timestep, 0, :, :],
    #             levels=[14],
    #             linewidths=1)

    # Title names the timestep actually plotted (it used to say
    # "first timestep" whatever `timestep` was passed).
    plt.title("SST at timestep " + str(timestep) + ": "
              + data.time[timestep].dt.strftime('%b %Y').item())
    plt.show()

def find_pixels(station, lat, lon, data):
    """Select the 3x3 box of pixels centred on the grid cell nearest a station.

    The nearest cell is found by subtracting the station's latitude/longitude
    from the dataset's coordinate values and taking the position of the
    smallest absolute difference along each axis.

    Parameters
    ----------
    station : str
        Station name, used only in the status and error messages.
    lat, lon : float
        Latitude and longitude of the station.
    data : dataset exposing ``latitude``, ``longitude`` and ``sel``
        Gridded data to subset.

    Returns
    -------
    ``data`` restricted to the three latitude values and three longitude
    values surrounding (and including) the nearest pixel.

    Raises
    ------
    ValueError
        If the nearest pixel sits at the first or last position of either
        coordinate array, so that no complete 3x3 box exists around it.
        ("lower"/"upper" refer to positions in the coordinate arrays, not to
        geographic direction.)
    """
    lat_values = data.latitude.values
    lon_values = data.longitude.values

    lat_index = np.abs(lat_values - lat).argmin()  # index of the nearest existing latitude
    lon_index = np.abs(lon_values - lon).argmin()  # index of the nearest existing longitude

    # Fail loudly instead of falling through to the return with nothing selected.
    if lat_index == 0 or lon_index == 0:
        raise ValueError(station + ': lat or lon out of lower bounds')
    if lat_index >= len(lat_values) - 1 or lon_index >= len(lon_values) - 1:
        raise ValueError(station + ': lat or lon out of upper bounds')

    station_data = data.sel(latitude=lat_values[lat_index-1:lat_index+2],
                            longitude=lon_values[lon_index-1:lon_index+2])
    print(station+' data successfully selected')

    return station_data

def calc_station_mean(station, data):
    """Average ``data`` over its latitude and longitude dimensions, ignoring NaNs.

    ``station`` is only used in the printed status message. The result keeps
    whatever other dimensions ``data`` has, so there is one spatial mean per
    remaining coordinate value (e.g. per timestep).
    """
    spatial_dims = ['latitude', 'longitude']
    spatial_average = data.mean(dim=spatial_dims, skipna=True)
    print(station+' spatial mean successfully found')
    return spatial_average

def isolate_winter_weeks(station, data):
    """Keep only the timesteps that fall in the early-year ("winter") weeks.

    Parameters
    ----------
    station : str
        Station name, used only in the printed status message.
    data : dataset with a datetime ``time`` coordinate
        Data to subset.

    Returns
    -------
    ``data`` restricted to timesteps whose ISO week number is below 15
    (weeks 1-14; the threshold is exclusive).
    """
    # Week numbers come from dt.isocalendar().week, the same accessor
    # find_anom uses, instead of the deprecated dt.week.
    week_numbers = data.time.dt.isocalendar().week.values
    winter_data = data.sel(time=week_numbers < 15)
    print(station + ' winter weeks successfully isolated')
    return winter_data

def isolate_winter_weeks_clim(station, data):
    """Keep only the climatology entries whose ``sevenDayPeriodOfYear`` is below 15.

    Climatology counterpart of ``isolate_winter_weeks``; ``station`` is only
    used in the printed status message.
    """
    is_winter_period = data.sevenDayPeriodOfYear < 15
    winter_clim = data.sel(sevenDayPeriodOfYear=is_winter_period)
    print(station + ' climatology winter weeks successfully isolated')
    return winter_clim

def find_anom(sst_data, clim_data):
    """Subtract the matching climatology value from every SST timestep.

    Parameters
    ----------
    sst_data : dataset with a datetime ``time`` coordinate
        Observed SST.
    clim_data : dataset indexed by ``sevenDayPeriodOfYear``
        Climatological SST; its period number is matched against the ISO week
        number of each timestep (assumes the two numberings line up -
        TODO confirm against the ERDDAP dataset description).

    Returns
    -------
    ``sst_data`` minus the climatology entry selected for each timestep.
    """
    # week number for each time value in the original dataset ([1, 2, 3, ..., 52, 1, 2, ...])
    weeks = sst_data.time.dt.isocalendar().week

    # ISO calendars have a week 53 in some years, but the climatology request
    # in this file only covers periods 1-52. When the climatology has no
    # period 53, reuse period 52 for those timesteps rather than failing the
    # label lookup. Any other missing period still raises.
    if 53 not in clim_data.sevenDayPeriodOfYear.values:
        weeks = weeks.where(weeks < 53, 52)

    # for each timestep, pick the climatology value of its week and subtract it
    anom = sst_data - clim_data.sel(sevenDayPeriodOfYear=weeks)
    return anom

def plot_timeseries(station, data):
    """Plot an SST time series with a least-squares best-fit line and print the fit statistics.

    Works for raw SST as well as for anomalies: whatever is stored in
    ``data.sst`` is plotted against ``data.time``.

    Parameters
    ----------
    station : str
        Station name, used in the plot title.
    data : dataset exposing ``time``, ``sst`` and ``sst_time_index``
        ``sst_time_index`` is the numeric x variable of the regression;
        ``sst`` is flattened to one value per timestep.

    Returns
    -------
    The list of line objects returned by ``ax.plot`` for the raw series.

    Notes
    -----
    Non-finite values are excluded from the fit. If fewer than two finite
    points remain, the regression is skipped (with a printed notice) and only
    the raw series is shown.
    """
    fig, ax = plt.subplots(figsize=(12, 6), dpi=500)
    plot = ax.plot(data.time, data.sst, color='#d5b7d8', alpha=0.9)

    # Work on plain NumPy arrays so the boolean mask indexes all three
    # arrays the same way.
    x = data.sst_time_index.values
    y = data.sst.values.flatten()
    times = data.time.values

    mask = np.isfinite(x) & np.isfinite(y)
    x_clean = x[mask]
    y_clean = y[mask]
    times_clean = times[mask]

    if x_clean.size < 2:
        # a line cannot be fitted through fewer than two points
        print(station + ': fewer than two finite SST values, skipping best-fit line')
    else:
        slope, intercept, r_value, p_value, std_err = stats.linregress(
            x_clean, y_clean)
        bestfit = slope*x_clean + intercept
        ax.plot(times_clean, bestfit, color='#d85b95', linewidth=2)

        print('\nSTATS')
        print('Slope: ' + str(slope))
        print('Intercept: '+ str(intercept))
        print('r value: '+ str(r_value))
        print('p value: ' + str(p_value))
        # the fifth linregress value is the standard error of the slope
        print('Standard error of slope: '+ str(std_err))
        print('\n')
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('\n')

    ax.set_ylabel('SST (ºC)')
    ax.set_xlabel('Date')

    ax.set_title(station + ' Weekly Mean SST Timeseries')

    plt.show()

    return plot
   
def plot_timeseries_weekly(station, data, season):
    """Plot SST against week of year, one coloured line per calendar year.

    Parameters
    ----------
    station : str
        Station name, used in the plot title.
    data : dataset exposing a datetime ``time`` coordinate and ``sst``
        Series to plot (raw SST or anomaly).
    season : str
        Label inserted into the title (e.g. 'allyear', 'winter').

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # get unique years
    unique_years = np.unique(data.time.dt.year.values)

    # one colour per year, sampled from the reversed plasma colormap
    n_colors = len(unique_years)
    colors = plt.cm.plasma(np.linspace(0, 1, n_colors))[::-1]

    # Resolve each year's colour once so the lines and the legend cannot
    # drift apart.
    if n_colors > 1:
        cmap = LinearSegmentedColormap.from_list('years_cmap', colors, N=n_colors)
        first_year = unique_years.min()
        year_span = unique_years.max() - first_year
        year_colors = {year: cmap((year - first_year) / year_span)
                       for year in unique_years}
    else:
        # With a single year the span is zero and (year - min) / (max - min)
        # would be 0/0, so use the sampled colour directly.
        year_colors = {year: tuple(colors[0]) for year in unique_years}

    # plot a line for each year
    for year in unique_years:
        year_data = data.sel(time=str(year))
        ax.plot(year_data.time.dt.isocalendar().week,
                year_data.sst.values.flatten(),
                color=year_colors[year], alpha=0.7)

    ax.set_xlabel('Week')
    ax.set_ylabel('SST (ºC)')
    ax.set_title('Weekly ' + str(season) + ' average SST by year at ' + str(station)) #later add tributary name and state

    legend_elements = [Line2D([0], [0], color=year_colors[year], lw=4, label=f'{year}')
                       for year in unique_years]
    ax.legend(handles=legend_elements, title='Years')

    plt.show()
    
def plot_winter_anom_by_year(station, data):
    """Plot the SST anomaly against week of year, one subplot per year (2016-2023).

    Parameters
    ----------
    station : str
        Station (or region) name, used in each subplot title and in the name
        of the saved figure.
    data : dataset exposing a datetime ``time`` coordinate and ``sst``
        Values to plot; the y axis is labelled as an SST anomaly.

    Each subplot shows the year's series, a zero reference line and a shaded
    band between -2 and 2. The figure is saved as
    ``<station> SST Anomaly by year - other.png`` in the current working
    directory and then displayed.
    """
    years_list = np.array([[2016, 2017, 2018, 2019],
                           [2020, 2021, 2022, 2023]])
    n_rows, n_cols = years_list.shape

    #creating subplots
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(30, 13))

    # month tick marks: one every 4.3 weeks, identical for every subplot
    months_labels_list = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    months_ticks_list = [month*4.3 for month in range(12)]

    for row in range(n_rows):
        for col in range(n_cols):
            year = years_list[row][col]
            ax = axes[row][col]

            # order the year's timesteps by ISO week number before plotting
            year_data = data.sel(time=str(year))
            year_data = year_data.sortby(year_data.time.dt.isocalendar().week)
            weeks = year_data.time.dt.isocalendar().week
            print(weeks)

            print(row, col)

            ax.plot(weeks, year_data.sst, color='#43039e')

            # Zero reference line. axhline does not depend on how many
            # timesteps the year has (the previous 52-point line failed for
            # any year without exactly 52 timesteps).
            ax.axhline(0, color='slategrey')

            ax.axhspan(-2, 2, color='#bab8d1', alpha=0.6)

            ax.set_xlim([0, 53])
            ax.set_ylim([-4, 5])

            ax.set_xticks(months_ticks_list)
            ax.set_xticklabels(months_labels_list)

            ax.set_xlabel('Month')
            ax.set_ylabel('SST anomaly (ºC)')

            ax.set_title(station + ' ' + str(year))

    # Saved before plt.show() so the file is written even if the backend
    # discards the figure once it has been shown.
    fig.savefig(station+ ' SST Anomaly by year - other.png', dpi=300) #saves the figure into the same directory as your code

    plt.show()
    
    
#~~~~~~~~~~CODE~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Plot styling: seaborn defaults plus a serif font for every figure
sns.set()
plt.rcParams['font.family'] = 'serif'

# Uncomment on the first run to fetch the SST NetCDF chunks; create_ds() reads the local files
#download_data()
ds_sst = create_ds()
# consecutive download chunks share their boundary timestamp, so keep only the first copy of each time value
ds_sst = ds_sst.drop_duplicates(dim='time', keep='first')


# running integer position of each timestep (0, 1, 2, ...), attached as a coordinate along 'time';
# plot_timeseries reads it as the numeric x variable for its regression
sst_time_index = np.arange(len(ds_sst.time))
ds_sst = ds_sst.assign_coords(sst_time_index=('time',sst_time_index))

# Uncomment on the first run to fetch the climatology file; create_clim_ds() reads the local file
#download_clim()
ds_clim = create_clim_ds()

examine_data(ds_sst)
# result: latitudes are in descending order (largest first)

examine_data(ds_clim)
# result: latitudes are in descending order (largest first)

# optional preliminary map of the first timestep
#create_basicmap(ds_sst, 0)


# Station names with their latitudes and longitudes (parallel lists, same order)
# NOTE(review): ET5.2's latitude is 38.580 here but 38.5807 in create_basicmap - confirm which is intended
stations = ['ET5.2', 'LE2.2', 'WT5.1', 'WE4.2', 'RET4.3']
lats = [38.580, 38.1576, 39.21309, 37.24181, 37.50869]
lons = [-76.0587, -76.598, -76.52254, -76.38634, -76.78889]

# loop counter; only the disabled per-station loop further down uses it
i=0

#for whole bay
# spatial mean over every latitude/longitude in the downloaded domain, for the SST record and the climatology,
# then the anomaly between the two, plotted year by year
bay_mean = calc_station_mean('whole bay', ds_sst)
bay_clim_mean = calc_station_mean('whole bay clim', ds_clim)
bay_anom = find_anom(bay_mean, bay_clim_mean)
plot_winter_anom_by_year('Chesapeake Bay', bay_anom)

'''
while i<5:
    print('BEGINNING ANALYSIS FOR ' + str(stations[i]))
    
    print('Processing for regular sst data')
    station_data = find_pixels(stations[i], lats[i], lons[i], ds_sst)
    station_mean = calc_station_mean(stations[i], station_data)
    station_means_winter = isolate_winter_weeks(stations[i], station_mean)
    
    print('Processing for climatology sst data')
    station_clim_data = find_pixels(stations[i], lats[i], lons[i], ds_clim) # I checked, this should yield the same lat lon values as for the normal dataset
    station_clim_mean = calc_station_mean(stations[i], station_clim_data)
    station_clim_means_winter = isolate_winter_weeks_clim(stations[i], station_clim_mean)
    
    print('Calculating anomaly')
    station_anom = find_anom(station_mean, station_clim_mean) #i'm not sure along what dimension they're subtracting??? is it smart enough?
    station_anom_winter = find_anom(station_means_winter, station_clim_means_winter)
    
    #print(station_anom)
    print('Plotting anomaly')
    
    #plot_timeseries(stations[i], station_anom)
    
   # plot_timeseries_weekly(stations[i], station_anom, 'allyear')
   # plot_timeseries_weekly(stations[i], station_anom_winter, 'winter')
    
    plot_winter_anom_by_year(stations[i], station_anom)
    
    i+=1
'''

'''
plot_timeseries('ET5.2', ET52_mean)
plot_timeseries('LE2.2', LE22_mean)
plot_timeseries('WT5.1', WT51_mean)
plot_timeseries('WE4.2', WE42_mean)
plot_timeseries('RET4.3', RET43_mean)

plot_timeseries_weekly('ET5.2', ET52_mean, 'allyear')
plot_timeseries_weekly('LE2.2', LE22_mean, 'allyear')
plot_timeseries_weekly('WT5.1', WT51_mean, 'allyear')
plot_timeseries_weekly('WE4.2', WE42_mean, 'allyear')
plot_timeseries_weekly('RET4.3', RET43_mean, 'allyear')

plot_timeseries_weekly('ET5.2', ET52_means_winter, 'winter')
plot_timeseries_weekly('LE2.2', LE22_means_winter, 'winter')
plot_timeseries_weekly('WT5.1', WT51_means_winter, 'winter')
plot_timeseries_weekly('WE4.2', WE42_means_winter, 'winter')
plot_timeseries_weekly('RET4.3', RET43_means_winter, 'winter')

'''


