#!/usr/bin/env python3
#
#  gmao_cf_subset will subset the GMAO GEOS-CF chemistry profile file
#  (chm_inst_1hr_g1440x721_v72).  It also reads the total and tropospheric
#  column NO2 from the extra chemistry file (xgc_tavg_1hr_g1440x721_x1) and
#  creates the datasets TOTCOL_NO2, TROPCOL_NO2, and STRATCOL_NO2 (total minus
#  tropospheric).  The profiles are interpolated from the model's modified
#  sigma levels to the 42 MERRA2 pressure levels, using pressures derived from
#  the met profile file (met_inst_1hr_g1440x721_v72).  It uses the metpy
#  library, which can interpolate the entire array at once and is much faster
#  than looping with the numpy or scipy libraries.
#

from datetime import datetime, timedelta
import time
import numpy as np
from netCDF4 import Dataset, stringtochar
from metpy.interpolate import interpolate_1d
import argparse
import os
import sys

# Command-line interface: three required input files plus optional output
# names.  nargs=1 makes each parsed value a one-element list, which the rest
# of the script indexes with [0].
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
  description='''\
  This program extracts the NO2 and SO2 profiles and the total/tropospheric NO2
  columns from the GMAO GEOS-CF files, interpolates the profiles to the 42 MERRA2
  pressure levels, and puts them into a single file for use by l2gen''', add_help=True)
parser.add_argument('-no2_prof', nargs=1, type=str, required=True,
                    help=' input NO2 profile file (chm_inst_1hr_g1440x721_v72) ')
parser.add_argument('-col', nargs=1, type=str, required=True,
                    help=' input NO2 column value file (xgc_tavg_1hr_g1440x721_x1)')
parser.add_argument('-met_prof', nargs=1, type=str, required=True,
                    help=' input Met profile file (met_inst_1hr_g1440x721_v72); its DELP\n'
                         ' is used to derive the model pressure levels ')
parser.add_argument('-ofile', nargs=1, type=str, help='''
 output file name for NO2 and SO2 on the 42 MERRA2 P levels; if not provided an
 output filename will be autogenerated from the NO2 profile file date/time
 metadata.  If the input holds several time steps, '_time_step_N' is appended''')
parser.add_argument('-outlist', nargs=1, type=str, help='''
 filename for an optional output file that lists the name of each output
 file along with its start time''')

args = parser.parse_args()
no2_prof = args.no2_prof[0]
col = args.col[0]
met_prof = args.met_prof[0]

# Missing-data value; equals the 1.e15 fill used for out-of-range interpolation.
mis_val = 1.e+15
# Gas profiles to interpolate ('CH4' and 'O3' are also available in the file).
no2_prof_vars = ['NO2', 'SO2']
# Column variables copied from the extra chemistry (-col) file.
col_vars = ['TOTCOL_NO2', 'TROPCOL_NO2']
# Met-profile variable used to build the model pressure profile.
met_prof_vars = ['DELP']
# Attributes copied from each input variable to its output counterpart.
var_attrs = ['long_name', 'units', 'standard_name', '_FillValue',
             'missing_value', 'scale_factor', 'add_offset']

no2_prof_fid = Dataset(no2_prof, 'r')
col_fid = Dataset(col, 'r')
met_prof_fid = Dataset(met_prof, 'r')

# Start of the file's time range; per-step times are offsets (minutes) from it.
sdate = no2_prof_fid.getncattr('RangeBeginningDate')
stime = no2_prof_fid.getncattr('RangeBeginningTime')
starttime = datetime.strptime(sdate + ' ' + stime, "%Y-%m-%d %H:%M:%S.%f")

# Global 'title' attribute of every output file.
title = ("GMAO GEOS-5 GEOS-CF 3d instantaneous chemistry diagnostics on \n"
         "    tropospheric and lower stratospheric pressure levels")

ofile = None
outlist = None

time_dim = no2_prof_fid.dimensions.get('time')
lat_dim = no2_prof_fid.dimensions.get('lat')
lon_dim = no2_prof_fid.dimensions.get('lon')
file_times = no2_prof_fid.variables.get('time')
# Fail with a clear message (like the other input checks) instead of an
# AttributeError on None when the time axis is missing.
if time_dim is None or file_times is None:
    sys.exit("{} not found in {}\n".format('time', no2_prof))
ntime_steps = time_dim.size

# Open the optional output list only after the inputs passed validation.
if args.outlist:
    outlist = open(args.outlist[0], 'w')

def _require_var(fid, name, path):
    """Return variable *name* from the open dataset *fid*, or exit if absent.

    An explicit ``is None`` test is used because ``not var`` on a netCDF4
    Variable calls ``len()``, which is 0 for an empty leading dimension.
    """
    var = fid.variables.get(name)
    if var is None:
        sys.exit("{} not found in {}\n".format(name, path))
    return var


def _chunk_sizes(var, skip_leading=0):
    """Return *var*'s chunk sizes with the first *skip_leading* dims dropped.

    ``Variable.chunking()`` returns the string 'contiguous' for unchunked
    variables, which ``createVariable(chunksizes=...)`` cannot use; None is
    returned instead so the library chooses default chunk sizes.
    """
    chunks = var.chunking()
    if chunks == 'contiguous':
        return None
    return list(chunks)[skip_leading:]


def _copy_attrs(src, dst, names):
    """Copy the attributes listed in *names* from *src* to *dst*, skipping absent ones."""
    present = set(src.ncattrs())
    dst.setncatts({k: src.getncattr(k) for k in names if k in present})


def _pressure_profile(delp_var, timestep, nlat, nlon):
    """Build a 74-level, monotonically increasing pressure profile.

    Levels 0-71 are layer mid-points: the top edge starts at 0.01 and each
    layer adds DELP / 100 (so the result is in hPa if DELP is in Pa, as the
    comparison with the hPa MERRA2 levels implies).  Level 72 is the bottom
    edge of the last layer and level 73 is fixed at 1040, so MERRA2 levels
    below the model surface still fall inside the profile.
    Each DELP level is read from the file exactly once.
    """
    p_prof = np.empty([74, nlat, nlon], "f4")
    p_edge = np.full([nlat, nlon], 0.01, "f4")   # top edge of the current layer
    for ihgt in range(72):
        dp = np.asarray(delp_var[timestep, ihgt, :, :], dtype="f4") / 100.
        p_prof[ihgt] = p_edge + dp / 2.
        p_edge += dp
    p_prof[72] = p_edge
    p_prof[73] = 1040.
    return p_prof


def _gas_profile(gas_var, timestep, nlat, nlon):
    """Return the 72-level gas field for *timestep*, padded to 74 levels.

    The bottom model level is repeated twice so the gas profile lines up
    with the surface and 1040 levels added by _pressure_profile.
    """
    gas_prof = np.empty([74, nlat, nlon], "f4")
    gas_prof[0:72] = gas_var[timestep, :, :, :]
    gas_prof[72:74] = gas_prof[71]
    return gas_prof


# The 42 MERRA2 pressure levels (hPa), from 1000 upward; these become the
# output 'lev' coordinate and the interpolation targets.
plev_merra2 = np.array([1000., 975., 950., 925., 900., 875., 850., 825.,
    800., 775., 750., 725., 700., 650., 600., 550., 500., 450., 400., 350.,
    300., 250., 200., 150., 100., 70., 50., 40., 30., 20., 10., 7., 5., 4.,
    3., 2., 1., 0.7, 0.5, 0.4, 0.3, 0.1])

# Look up and validate every input variable once, before any output file is
# created, so a missing variable does not leave a partial product behind.
lat_in = _require_var(no2_prof_fid, 'lat', no2_prof)
lon_in = _require_var(no2_prof_fid, 'lon', no2_prof)
lev_in = _require_var(no2_prof_fid, 'lev', no2_prof)
if lev_in.shape[0] != 72:
    sys.exit("unusual dim size for lev (not 72): {}\n".format(lev_in.shape[0]))
delp_in = _require_var(met_prof_fid, met_prof_vars[0], met_prof)
gas_in = {v_name: _require_var(no2_prof_fid, v_name, no2_prof) for v_name in no2_prof_vars}
col_in = {v_name: _require_var(col_fid, v_name, col) for v_name in col_vars}
tot_col_in = col_in[col_vars[0]]    # TOTCOL_NO2
trop_col_in = col_in[col_vars[1]]   # TROPCOL_NO2
nlat = lat_in.shape[0]
nlon = lon_in.shape[0]

for timestep in range(ntime_steps):
    print("Doing time: ", timestep)
    minutes_of_day = int(file_times[timestep])
    time_slice_start = starttime + timedelta(minutes=minutes_of_day)

    # determine output file: autogenerated from the slice time, or -ofile
    # (suffixed with the time step when the input holds several steps)
    if not args.ofile:
        ofile = 'GMAO_GEOS-CF.' + time_slice_start.strftime("%Y%m%dT%H0000") + '.PROFILE.nc'
    elif ntime_steps > 1:
        ofile = args.ofile[0] + '_time_step_' + str(timestep)
    else:
        ofile = args.ofile[0]

    # 'with' closes the output file even if a later step raises or exits
    with Dataset(ofile, "w") as dsout:
        if outlist:
            outlist.write('{} {}\n'.format(ofile, time_slice_start.strftime("%Y-%m-%dT%H:%M:%S")))

        # copy every dimension except time and lev from the no2_prof file
        for dname, the_dim in no2_prof_fid.dimensions.items():
            if dname not in ('time', 'lev'):
                dsout.createDimension(dname, None if the_dim.isunlimited() else len(the_dim))

        # latitude coordinate
        outVar = dsout.createVariable('lat', lat_in.datatype, dimensions=('lat',),
                                      contiguous=False, chunksizes=_chunk_sizes(lat_in),
                                      zlib=True, complevel=4, shuffle=True)
        outVar.long_name = "Latitude".encode('ascii')
        outVar.units = "degrees_north".encode('ascii')
        outVar.standard_name = "latitude".encode('ascii')
        outVar.valid_min = float(-90.)
        outVar.valid_max = float(90.)
        outVar[:] = lat_in[:]

        # longitude coordinate
        outVar = dsout.createVariable('lon', lon_in.datatype, dimensions=('lon',),
                                      contiguous=False, chunksizes=_chunk_sizes(lon_in),
                                      zlib=True, complevel=4, shuffle=True)
        outVar.long_name = "Longitude".encode('ascii')
        outVar.units = "degrees_east".encode('ascii')
        outVar.standard_name = "longitude".encode('ascii')
        outVar.valid_min = float(-180.)
        outVar.valid_max = float(180.)
        outVar[:] = lon_in[:]

        # lev becomes the 42 MERRA2 pressure levels
        dsout.createDimension('lev', len(plev_merra2))
        outVar = dsout.createVariable('lev', "f4", dimensions=('lev',),
                                      contiguous=False, zlib=True,
                                      complevel=4, shuffle=True)
        outVar.long_name = "vertical level".encode('ascii')
        # NOTE(review): "pressure" is not a UDUNITS unit string; "hPa" may be
        # intended -- confirm with l2gen before changing.
        outVar.units = "pressure".encode('ascii')
        outVar.positive = "down".encode('ascii')
        outVar.coordinate = "PLE".encode('ascii')
        outVar.standard_name = "PLE_level".encode('ascii')
        outVar.valid_min = float(0.1)
        outVar.valid_max = float(1000.)
        outVar[:] = plev_merra2

        # model pressure profile for this step (mid-layers + surface + 1040)
        p_prof = _pressure_profile(delp_in, timestep, nlat, nlon)

        # interpolate each gas profile to the MERRA2 levels in one call
        for v_name in no2_prof_vars:
            print("Doing gas: ", v_name)
            varin = gas_in[v_name]
            gas_prof = _gas_profile(varin, timestep, nlat, nlon)
            outVar = dsout.createVariable(v_name, "f4", dimensions=('lev', 'lat', 'lon'),
                                          contiguous=False,
                                          chunksizes=[1, min(91, nlat), min(144, nlon)],
                                          zlib=True, complevel=4, shuffle=True)
            _copy_attrs(varin, outVar, var_attrs)
            outVar[:] = interpolate_1d(plev_merra2, p_prof, gas_prof, fill_value=mis_val)

        # copy the column variables for this time step
        for v_name in col_vars:
            varin = col_in[v_name]
            outVar = dsout.createVariable(v_name, varin.datatype, dimensions=('lat', 'lon'),
                                          contiguous=False, chunksizes=_chunk_sizes(varin, 1),
                                          zlib=True, complevel=4, shuffle=True)
            _copy_attrs(varin, outVar, var_attrs)
            outVar[:] = varin[timestep]

        # STRATCOL_NO2 = TOTCOL_NO2 - TROPCOL_NO2, reusing TROPCOL's metadata
        # with 'tropospheric' relabelled as 'stratospheric'
        outVar = dsout.createVariable('STRATCOL_NO2', trop_col_in.datatype,
                                      dimensions=('lat', 'lon'), contiguous=False,
                                      chunksizes=_chunk_sizes(trop_col_in, 1),
                                      zlib=True, complevel=4, shuffle=True)
        _copy_attrs(trop_col_in, outVar, var_attrs)
        for k in ('long_name', 'standard_name'):
            if k in trop_col_in.ncattrs():
                outVar.setncattr(k, trop_col_in.getncattr(k).replace("tropospheric", "stratospheric"))
        outVar[:] = tot_col_in[timestep] - trop_col_in[timestep]

        # drop identity packing attributes (scale_factor 1 with add_offset 0)
        for var in dsout.get_variables_by_attributes(scale_factor=1.0):
            if getattr(var, 'add_offset', 0.0) == 0.0:
                del var.scale_factor
                if 'add_offset' in var.ncattrs():
                    del var.add_offset

        # Add OB.DAAC metadata; date_created is the current UTC time
        # (formatted from gmtime rather than mutating the process TZ)
        dsout.date_created = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()).encode('ascii')

        # record the full command line so the product can be regenerated
        history = "{} -no2_prof {} -col {} -met_prof {}".format(parser.prog, no2_prof, col, met_prof)
        if args.ofile:
            history += " -ofile {}".format(args.ofile[0])
        if args.outlist:
            history += " -outlist {}".format(args.outlist[0])
        history += "; {}".format(no2_prof_fid.getncattr('History'))

        # instantaneous data: coverage start and end are the same instant
        time_coverage = time_slice_start.strftime("%Y-%m-%dT%H:%M:%SZ").encode("ascii")

        dsout.title = title.encode("ascii")
        dsout.product_name = ofile.encode("ascii")
        dsout.history = history.encode("ascii")
        dsout.time_coverage_start = time_coverage
        dsout.time_coverage_end = time_coverage
        dsout.geospatial_lat_max = float(no2_prof_fid.getncattr('NorthernmostLatitude'))
        dsout.geospatial_lat_min = float(no2_prof_fid.getncattr('SouthernmostLatitude'))
        dsout.geospatial_lon_max = float(no2_prof_fid.getncattr('EasternmostLongitude'))
        dsout.geospatial_lon_min = float(no2_prof_fid.getncattr('WesternmostLongitude'))
        dsout.geospatial_lat_resolution = float(no2_prof_fid.getncattr('LatitudeResolution'))
        dsout.geospatial_lon_resolution = float(no2_prof_fid.getncattr('LongitudeResolution'))
        dsout.geospatial_lat_units = 'degrees_north'.encode("ascii")
        dsout.geospatial_lon_units = 'degrees_east'.encode("ascii")
        dsout.grid_mapping_name = 'latitude_longitude'.encode("ascii")
        dsout.software_version = (no2_prof_fid.getncattr('Source')).encode("ascii")
        dsout.Conventions = 'CF-1.6'.encode("ascii")
        dsout.Metadata_Conventions = 'Unidata Dataset Discovery v1.0'.encode("ascii")
        dsout.license = 'https://science.nasa.gov/earth-science/earth-science-data/data-information-policy/'.encode("ascii")
        dsout.naming_authority = 'gov.nasa.gsfc.sci.oceandata'.encode("ascii")
        dsout.id = (no2_prof_fid.getncattr('VersionID') + '/L4/' + ofile).encode("ascii")
        dsout.keywords_vocabulary = 'NASA Global Change Master Directory (GCMD) Science Keywords'.encode("ascii")
        # NOTE(review): this keyword looks inherited from a met-profile product;
        # an atmospheric-chemistry (NO2) keyword may be more accurate -- confirm.
        dsout.keywords = 'ATMOSPHERE>ALTITUDE>GEOPOTENTIAL HEIGHT '.encode("ascii")
        dsout.standard_name_vocabulary = 'NetCDF Climate and Forecast (CF) Metadata Convention'.encode("ascii")
        dsout.institution = 'NASA Goddard Space Flight Center, Ocean Ecology Laboratory, Ocean Biology Processing Group'.encode("ascii")
        dsout.creator_name = 'NASA Global Modeling and Assimilation Office'.encode("ascii")
        dsout.creator_email = 'data@gmao.gsfc.nasa.gov'.encode("ascii")
        dsout.creator_url = 'https://gmao.gsfc.nasa.gov'.encode("ascii")
        dsout.project = 'Ocean Biology Processing Group (NASA/GSFC/OBPG)'.encode("ascii")
        dsout.publisher_name = 'NASA/GSFC/OBPG'.encode("ascii")
        dsout.publisher_url = 'https://oceandata.sci.gsfc.nasa.gov'.encode("ascii")
        dsout.publisher_email = 'data@oceancolor.gsfc.nasa.gov'.encode("ascii")
        dsout.processing_level = 'L4'.encode("ascii")
        dsout.cdm_data_type = 'grid'.encode("ascii")
        dsout.spatialResolution = (no2_prof_fid.getncattr('LatitudeResolution')
                                   + 'x' + no2_prof_fid.getncattr('LongitudeResolution')
                                   + ' degrees').encode("ascii")
        dsout.source = ("{},{}".format(no2_prof, col)).encode("ascii")
        dsout.comment = (
            "This file contains a subset of variables generated using the GMAO GEOS model (v{}) for use as input to the OBPG Level 2 data processing stream".format(no2_prof_fid.getncattr('VersionID')))

    print("Created:\t{}".format(ofile))

# Release the optional output list and all three input files
# (met_prof_fid was previously left open).
if outlist:
    outlist.close()

no2_prof_fid.close()
col_fid.close()
met_prof_fid.close()

