##########################################################################
#
# PYTHON 3 FUNCTIONS FOR quality_controls
#
# CO-AUTHORS:   Katherine E. Lukens             NOAA/NESDIS/STAR, CISESS at U. of Maryland
#               Kevin Garrett                   NOAA/NWS/OSTI
#               Kayo Ide                        U. of Maryland
#               David Santek                    CIMSS at U. of Wisconsin-Madison
#               Brett Hoover                    NOAA/NWS/NCEP/EMC, Lynker Technologies
#               Ross N. Hoffman                 NOAA/NESDIS/STAR, CISESS at U. of Maryland
#               Hui Liu                         NOAA/NESDIS/STAR, CISESS at U. of Maryland
#
# Built with the following conda environment:
#
# name: bhoover-obs_match_3d
# channels:
#   - conda-forge
#   - defaults
# dependencies:
#   - python=3
#   - numpy
#   - pandas
#   - pynio
#   - matplotlib
#   - cartopy
#   - jupyter
#   - netCDF4
#   - scikit-learn
#   - dask
#   - geopy
#   - pip
#   - pip:
#
###########################################################################
#
# Import modules
#

import sys
import math
import numpy as np #....................................................... Array module
import datetime as dt #.................................................... Datetime module
import time #.............................................................. Time module
from datetime import datetime
import warnings

import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn

from statistics import mean

#
###########################################################################
#
# STATISTICAL ANALYSIS functions for collocation program
#

# Missing-data sentinel: values equal to this are replaced by NaN in to_nan()
fill = -999.0

# -------------------------------------------------------------------------
# Compute Horizontal Line-of-Sight (HLOS) wind using the Aeolus azimuth angle
#
#	INPUTS:
#		x_azm ............................... Aeolus azimuth angle (degrees)
#		x_spd ............................... Aeolus wind speed (used only for the missing-value check)
#		y_dir ............................... wind direction of the dataset to project (degrees)
#		y_spd ............................... wind speed of the dataset to project
#
#	OUTPUTS:
#		hlos ................................ wind of y projected onto the Aeolus HLOS direction (NaN if missing)
#
def compute_hlos(x_azm,x_spd,y_dir,y_spd):
    """Project the wind (y_dir, y_spd) onto the line of sight given by x_azm.

    Parameters
    ----------
    x_azm : float
        Azimuth angle of the line of sight, in degrees.
    x_spd : float
        Wind speed of dataset x. Only inspected for NaN; it does not enter
        the projection itself.
    y_dir : float
        Wind direction of dataset y, in degrees.
    y_spd : float
        Wind speed of dataset y.

    Returns
    -------
    float
        u*sin(azm) + v*cos(azm) with the sign conventions below, which
        reduces to y_spd * cos(y_dir - x_azm). NaN when either speed is NaN.
    """
    # NaN never compares equal to anything (including itself), so an
    # `== np.nan` test can never detect a missing value; use isnan instead.
    if np.isnan(x_spd) or np.isnan(y_spd):
        return np.nan

    deg2rad = math.pi/180.0

    # components of the y wind from its direction and speed
    sindir = math.sin(y_dir*deg2rad)
    cosdir = math.cos(y_dir*deg2rad)
    u = -1.0*y_spd*sindir
    v = -1.0*y_spd*cosdir

    # line-of-sight unit vector (same sign convention as the wind components)
    sinazm = -1.0*math.sin(x_azm*deg2rad)
    cosazm = -1.0*math.cos(x_azm*deg2rad)

    return u*sinazm + v*cosazm
    
# -------------------------------------------------------------------------
# Reassign all missing values to NaN
#	Applies to collocation pairs using any dataset
#
#       INPUTS:
#               sdrv_spd ............................ 1D array in which missing data are equal to the fill value
#
#       OUTPUTS:
#               drv_spd ............................. array with every fill value replaced by NaN
#
def to_nan(sdrv_spd):
    """Return sdrv_spd as a new array with fill-valued entries set to NaN.

    Elements are read one at a time by integer index, so the input is
    expected to be one-dimensional. An element is treated as missing only
    when it is exactly equal to the module-level `fill` value.
    """
    cleaned = []
    for pos in range(np.size(sdrv_spd)):
        value = sdrv_spd[pos]
        if value == fill:
            cleaned.append(np.nan)	# missing -> NaN
        else:
            cleaned.append(value)

    return np.asarray(cleaned)
    
# -------------------------------------------------------------------------
# Super-ob (average) matches
#
def superob_matches(idxD,Dyr,Dmm,Ddy,Dhr,Dmn,Dlat,Dlon,Dprs,Dhgt,Dspd,Ddir,tlat,tlon,tprs,thgt,tspd,tdir,Dwcm,twcm):
    """Collapse collocation matches to one super-ob per unique DRIVER index.

    All inputs are arrays with one entry per match. For every distinct value
    in idxD, the DEPENDENT fields (tlat, tlon, tprs, thgt, tspd, tdir, twcm)
    are averaged with np.mean over the matches sharing that value, and the
    DRIVER fields (Dyr ... Ddir, Dwcm) contribute the entry of the first such
    match.

    Returns a 19-tuple of arrays, in this order:
        driver yr, mm, dy, hr, mn, lat, lon, prs, hgt, spd, dir,
        dependent mean lat, lon, prs, hgt, spd, dir,
        driver wcm, dependent mean wcm.
    Output rows follow the iteration order of set(idxD), which is not
    necessarily sorted.
    """
    # Column order in these tuples fixes the order of the unpacking below
    driver_cols    = (Dyr, Dmm, Ddy, Dhr, Dmn, Dlat, Dlon, Dprs, Dhgt, Dspd, Ddir, Dwcm)
    dependent_cols = (tlat, tlon, tprs, thgt, tspd, tdir, twcm)

    driver_acc    = [[] for _ in driver_cols]
    dependent_acc = [[] for _ in dependent_cols]

    for driver_id in list(set(idxD)):
        members = np.where(idxD == driver_id)
        if np.size(members) == 0:
            continue	# nothing compares equal to this value (e.g. NaN)

        # DEPENDENT: average all obs matched to this DRIVER ob
        # NOTE(review): direction and longitude are averaged arithmetically,
        # with no wrap-around handling (e.g. 350 and 10 average to 180).
        for acc, col in zip(dependent_acc, dependent_cols):
            acc.append(np.mean(col[members]))

        # DRIVER: keep the values from the first match
        for acc, col in zip(driver_acc, driver_cols):
            acc.append(col[members][0])

    # convert lists to np.array
    (qdrv_yr, qdrv_mm, qdrv_dy, qdrv_hr, qdrv_mn,
     qdrv_lat, qdrv_lon, qdrv_prs, qdrv_hgt, qdrv_spd, qdrv_dir,
     qdrv_wcm) = [np.asarray(acc) for acc in driver_acc]

    (qt_lat, qt_lon, qt_prs, qt_hgt, qt_spd, qt_dir,
     qt_wcm) = [np.asarray(acc) for acc in dependent_acc]

    return qdrv_yr,qdrv_mm,qdrv_dy,qdrv_hr,qdrv_mn,qdrv_lat,qdrv_lon,qdrv_prs,qdrv_hgt,qdrv_spd,qdrv_dir,qt_lat,qt_lon,qt_prs,qt_hgt,qt_spd,qt_dir,qdrv_wcm,qt_wcm

# -------------------------------------------------------------------------
# Convert pressure to height
#	Convert pressure to pressure altitude (height) following NWS formulation (https://www.weather.gov/media/epz/wxcalc/pressureAltitude.pdf)
#
# Inputs	prs = pressure in hPa
# Returns 	hgt = height in km
#
def prs_to_hgt(prs):
    """Convert pressure to pressure altitude in km.

    Parameters
    ----------
    prs : array-like
        Pressure values. If any value exceeds 10000 the whole input is
        taken to be in Pa and divided by 100 to get hPa; otherwise it is
        used as hPa. The input is not modified.

    Returns
    -------
    numpy.ndarray
        Height in km, same shape as prs. NaN inputs give NaN outputs; an
        empty input gives an empty output.
    """
    p = np.asanyarray(prs)

    # check pressure units and convert to hPa
    # A comparison against NaN is False, so missing values cannot hide a
    # Pa-valued array the way an order-dependent builtin max() could, and
    # np.any on an empty array is simply False instead of raising.
    with np.errstate(invalid='ignore'):
        in_pascals = bool(np.any(p > 10000.))
    if in_pascals:
        p = p/100.

    # convert pressure to height
    hgt_ft = 145366.45 * (1.0 - (p/1013.25)**0.190284)	# hPa (mb) to feet
    hgt_m  = hgt_ft * 0.3048				# feet to meters
    return hgt_m / 1000.0				# meters to km

# -------------------------------------------------------------------------
# Conform variable to 2 dimensions
#
def var_to_2d(xstr,ystr,str2d,x,y,xvar,yvar,var):
    """Bin the values in var onto the 2D grid defined by x and y.

    Parameters:
        xstr  : name of the x coordinate. "Time" selects points whose xvar
                equals x[k] exactly; anything else uses [x[k], x[k+1]).
        ystr  : name of the y coordinate. "Pressure" or "Height" uses bins
                centred on y[j] with half-width |y[j+1]-y[j]|/2; anything
                else uses [y[j], y[j+1]).
        str2d : "Count" returns counts, "SD" returns standard deviations,
                any other value returns means.
        x, y  : grid coordinates. y is sorted IN PLACE.
        xvar, yvar, var : per-point x coordinate, y coordinate and value.

    Returns an array of shape (len(y), len(x)). The last row and the last
    column are never filled by the binning loops, so they come back as NaN,
    as do empty cells (and, for "SD", cells whose SD is exactly 0).

    Side effects: sorts y in place and installs a process-wide filter that
    ignores RuntimeWarning.
    """
    # sort y to be in ascending order
    y.sort()

    ny = len(y)
    nx = len(x)
    cell_sd    = np.zeros([ny, nx], dtype=float)
    cell_sum   = np.zeros([ny, nx], dtype=float)
    cell_count = np.zeros([ny, nx], dtype=float)

    centred_rows = (ystr=="Pressure" or ystr=="Height")
    exact_cols   = (xstr=="Time")
    want_sd      = (str2d=="SD")

    # accumulate sum / count (/ SD) per grid cell
    for row in range(ny-1):
        if centred_rows:
            half = abs(y[row+1] - y[row])/2.0	# half the spacing to the next level
            lo = y[row]-half
            hi = y[row]+half
        else:
            lo = y[row]
            hi = y[row+1]

        for col in range(nx-1):
            if exact_cols:
                members = np.where((xvar==x[col])*(yvar>=lo)*(yvar<hi))
            else:
                members = np.where((xvar>=x[col])*(xvar<x[col+1])*(yvar>=lo)*(yvar<hi))
            picked = var[members]
            cell_sum[row,col]   = np.sum(picked)
            cell_count[row,col] = np.size(members)
            if want_sd:
                cell_sd[row,col] = np.std(picked)

    # mean per grid cell (0/0 in empty cells gives NaN; silence the warning)
    warnings.filterwarnings('ignore', category=RuntimeWarning)
    cell_mean = cell_sum / cell_count

    if str2d=="Count":
        cell_count[cell_count==0] = np.nan
        return cell_count	# counts per grid cell

    if want_sd:
        cell_sd[cell_sd==0] = np.nan
        return cell_sd		# SD of var per grid cell

    return cell_mean		# mean var per grid cell
    
# -------------------------------------------------------------------------
# Interpolate to given pressure level
#	Assumes a log-linear relationship
#
#	Input:
#   		vcoord_data:   1D array of vertical level values (e.g., pressure from a radiosonde)
#    		interp_var:    1D array of the variable to be interpolated to all pressure levels
#    		interp_levels: 1D array containing vertical levels to interpolate to
#
#    	Return:
#    		interp_data:   1D array that contains the interpolated variable on the interp_levels
#
# Source: https://unidata.github.io/python-training/gallery/observational_data_cross_section/
#
def vert_interp(vcoord_data, interp_var, interp_levels):
    """Interpolate interp_var from vcoord_data onto interp_levels, linearly in log(level).

    All three arguments are 1D. The arrays are reversed before being handed
    to np.interp, so vcoord_data is expected to be in decreasing order
    (np.interp needs increasing sample points).

    Target levels greater than vcoord_data[0] receive interp_var[0], and
    target levels less than vcoord_data[-1] receive interp_var[-1].
    """
    # Work in log of the vertical coordinate for both data and target levels
    log_obs     = np.log(vcoord_data)
    log_targets = np.log(interp_levels)

    # Interpolate on the reversed (increasing) arrays, then flip back
    increasing  = np.interp(log_targets[::-1], log_obs[::-1], interp_var[::-1])
    interp_data = increasing[::-1]

    # Levels outside the observed range take the nearest end value
    beyond_first = interp_levels > vcoord_data[0]
    beyond_last  = interp_levels < vcoord_data[-1]
    interp_data[beyond_first] = interp_var[0]
    interp_data[beyond_last]  = interp_var[-1]

    return interp_data



#############################################################################
