#!/usr/bin/env python3

# gmao_rad_subset2 - for PAR algorithm use, get a day's worth of
#  tau (TAUTOT) and cloud (CLDTOT)
#  Accepts either 1 MERRA2 file holding all 24 hourly times, or a list of
#  24 single-time GEOS-IT (or FP-IT) files, and writes 1 output file

from datetime import datetime, timedelta
import time
import numpy as np
from netCDF4 import Dataset, stringtochar
import argparse
import os
import sys

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,description='''\
  This program subsets out optical thickness (TAUTOT) and cloud amount (CLDTOT)
  from either a MERRA2 file or a set of 24 GEOS-IT (or FP-IT) files from GMAO.
  All the day's times are stored in the product ''',add_help=True)
parser.add_argument('-asm', nargs=1, type=str, required=True, help='''
 input ASM file, or a text file listing the input files (one per line)''')
parser.add_argument('-ofile', nargs=1, type=str, help='''
 output file name; if not provided an output filename will be autogenerated
 based on ASM date/time metadata''')

args=parser.parse_args()
# -asm / -ofile use nargs=1, so argparse stores each as a 1-element list
asm = args.asm[0]
if args.ofile: ofile = args.ofile[0]

# Variables subset out of the GMAO radiation files, and the per-variable
# attributes carried over to the output file
asm_vars = ['TAUTOT', 'CLDTOT']
var_attrs = [
    'long_name', 'units', '_FillValue',
    'missing_value', 'scale_factor', 'add_offset',
]

#
# designate types: integer codes for the supported GMAO products
merra2_typ, fpit_typ, it_typ = 0, 1, 2
# Output identification; filled in once the input product type is known
title = None
gmao_typ = -1
out_start_id = 'UNKNOWN'

nfiles = 1
file_list = 1
#
# '-asm' is either a netCDF file or a text file listing the input files.
# Tell them apart by the file signature rather than by whether the bytes
# happen to decode as text (which depends on the locale): netCDF classic
# files begin with b'CDF', netCDF-4 (HDF5) files with b'\x89HDF'.
try:
    with open(asm, 'rb') as f:
        is_netcdf = f.read(4).startswith((b'CDF', b'\x89HDF'))
    if not is_netcdf:
        with open(asm) as f:
            # strip whitespace (including '\r' from DOS-style lists) and
            # skip blank lines so a trailing newline is not counted as a file
            files = [line.strip() for line in f if line.strip()]
        nfiles = len(files)
except (OSError, UnicodeDecodeError):
    # unreadable path or non-text content: treat '-asm' as the single input
    # file and let Dataset() report any problem with it
    is_netcdf = True
if is_netcdf:
    files = [asm]
    nfiles = 1
    file_list = 0
print('processing ', nfiles, ' files')

#  some set up for 1 or 24 files
#
# Output-writing helpers shared by the MERRA2 (1 file) and the
# GEOS-IT / FP-IT (24 files) branches below.
#
def _require_var(fid, v_name, path):
    """Return variable *v_name* of the open Dataset *fid*, or exit if absent.

    *path* is the file name reported in the error message.
    """
    var = fid.variables.get(v_name)
    # test against None explicitly: truth-testing a netCDF4 Variable goes
    # through len(), which is 0 for an empty variable and fails for a scalar
    if var is None:
        sys.exit("{} not found in {}\n".format(v_name, path))
    return var


def _copy_dims(dsout, src_fid, ntime=None):
    """Recreate every dimension of *src_fid* in *dsout*.

    Unlimited dimensions stay unlimited.  When *ntime* is given and the
    source 'time' dimension has a fixed size, the output 'time' dimension
    is sized *ntime* instead, so that all stacked time steps fit.
    """
    for dname, the_dim in src_fid.dimensions.items():
        if the_dim.isunlimited():
            size = None
        elif dname == 'time' and ntime is not None:
            size = ntime
        else:
            size = len(the_dim)
        dsout.createDimension(dname, size)


def _write_lat_lon(dsout, src_fid, path):
    """Copy 'lat' and 'lon' from *src_fid* into *dsout* with CF attributes."""
    coords = (
        # name, long_name,  units,           standard_name, valid range
        ('lat', 'Latitude', 'degrees_north', 'latitude', -90., 90.),
        ('lon', 'Longitude', 'degrees_east', 'longitude', -180., 180.),
    )
    for v_name, long_name, units, std_name, vmin, vmax in coords:
        varin = _require_var(src_fid, v_name, path)
        outVar = dsout.createVariable(v_name, varin.datatype, dimensions=(v_name,),
          contiguous=False, zlib=True, complevel=4, shuffle=True)
        outVar.long_name = long_name.encode('ascii')
        outVar.units = units.encode('ascii')
        outVar.standard_name = std_name.encode('ascii')
        outVar.valid_min = float(vmin)
        outVar.valid_max = float(vmax)
        outVar[:] = varin[:]


def _write_time(dsout, src_var, units, values):
    """Create the output 'time' variable (typed like *src_var*) and fill it.

    *units* is copied verbatim from the input time variable; valid range
    is 0..1440 (the minutes of one day).
    """
    outVar = dsout.createVariable('time', src_var.datatype, dimensions=('time',),
        contiguous=False, zlib=True, complevel=4, shuffle=True)
    outVar.long_name = "time".encode('ascii')
    outVar.units = units
    outVar.standard_name = "time".encode('ascii')
    outVar.valid_min = int(0)
    outVar.valid_max = int(1440)
    outVar[:] = values


def _write_field(dsout, src_var, v_name, chunk, values):
    """Create a (time, lat, lon) variable patterned on *src_var*, fill with *values*."""
    outVar = dsout.createVariable(v_name, src_var.datatype,
      dimensions=('time','lat','lon'), contiguous=False,
      chunksizes=chunk, zlib=True, complevel=4, shuffle=True)
    # Copy whichever of the selected variable attributes the input carries
    present = src_var.ncattrs()
    outVar.setncatts({k: src_var.getncattr(k) for k in var_attrs if k in present})
    outVar[:] = values


if nfiles == 1:
    print('Branch for 1 MERRA2 file with 24 times')
    #  Do the MERRA2 with 24 times already in it
    # check that it is MERRA2
    asm_fid = Dataset(files[0], 'r')
    if 'MERRA' in asm_fid.getncattr('Title'):
        title = 'GMAO MERRA2 2D 1-Hour,  Time Averaged, Assimilation Radiation Diagnostics'
        gmao_typ = merra2_typ
        out_start_id = 'GMAO_MERRA2'
    else:
        print('For 1 file submission, the file must be MERRA2 type')
        print('Title was: ', asm_fid.getncattr('Title') )
        sys.exit(1)
    # Get time, start, end times
    sdate = asm_fid.getncattr('RangeBeginningDate')
    stime = asm_fid.getncattr('RangeBeginningTime')
    starttime = datetime.strptime(sdate + ' '+ stime, "%Y-%m-%d %H:%M:%S.%f")
    time_dim = asm_fid.dimensions.get('time')
    file_times = _require_var(asm_fid, 'time', files[0])
    nlat = len(asm_fid.dimensions.get('lat'))
    nlon = len(asm_fid.dimensions.get('lon'))
    chunk = (1,nlat,nlon)
    #
    ntime_steps = time_dim.size
    if ntime_steps != 24:
        print('Merra 2 file does not have 24 time periods')
        sys.exit(1)
    #
    # out file; the time values are minutes past the range beginning
    time_slice_start = starttime + timedelta(minutes=int(file_times[0]))
    time_slice_end = starttime + timedelta(minutes=int(file_times[ntime_steps-1]))
    if not args.ofile:
        ofile = out_start_id + '.' + time_slice_start.strftime("%Y%m%d")+'.RAD.nc'
    dsout = Dataset(ofile, "w")
    _copy_dims(dsout, asm_fid)
    _write_lat_lon(dsout, asm_fid, files[0])
    _write_time(dsout, file_times, file_times.units, file_times[:])
    # Copy asm variables
    for v_name in asm_vars:
        varin = _require_var(asm_fid, v_name, files[0])
        _write_field(dsout, varin, v_name, chunk, varin[:])
elif nfiles == 24:
    # Do the GEOS-IT (FP-IT) with 24 single-time data
    print('Branch for 24 separate files')
    # minutes of each file's start past the 1st file's start
    time_out = np.zeros(nfiles)
    #  loop all times
    for ifil, path in enumerate(files):
        print('Doing input file: ', path)
        asm_fid = Dataset(path, 'r')
        sdate = asm_fid.getncattr('RangeBeginningDate')
        stime = asm_fid.getncattr('RangeBeginningTime')
        gran_time = datetime.strptime(sdate + ' '+ stime, "%Y-%m-%d %H:%M:%S.%f")

        # get the lat, lon, time values and # of each
        lat = _require_var(asm_fid, 'lat', path)[:]
        nlat = len(lat)
        lon = _require_var(asm_fid, 'lon', path)[:]
        nlon = len(lon)
        chunk = (1,nlat,nlon)
        time_var = _require_var(asm_fid, 'time', path)
        ntime = len(time_var[:])
        # check for 1 time only
        if ntime != 1:
            sys.exit('All GEOS IT and FP input files need to have only 1 time. This file has {} times'.format(ntime) )
        #
        if ifil == 0:
            #  specific work for 1st file: identify the product type
            if 'GEOS-IT' in asm_fid.getncattr('Title'):
                title = 'GMAO GEOS-5 GEOS-IT 2D 1-Hour,  Time Averaged, Assimilation Radiation Diagnostics'
                out_start_id = 'GMAO_IT'
                gmao_typ = it_typ
            elif 'FP-IT' in asm_fid.getncattr('Title'):
                title = 'GMAO GEOS-5 FP-IT 2D 1-Hour,  Time Averaged, Assimilation Radiation Diagnostics'
                out_start_id = 'GMAO_FP'
                gmao_typ = fpit_typ
            else:
                print('Unknown GMAO type encountered from title: ', asm_fid.getncattr('Title') )
                sys.exit(1)
            # save the initial lat, lon, time information
            starttime = gran_time
            lat_out = lat
            lon_out = lon
            #  save the units attr of the 1st time
            first_t_units = time_var.units
            # set up the output arrays of the 2 variables
            out_tau = np.zeros((nfiles,nlat,nlon))
            out_cld = np.zeros((nfiles,nlat,nlon))
        else:
            #  for all subsequent files the lat, lon grid must match the 1st
            if not np.array_equal( lat, lat_out ):
                print('file # ', ifil, ' name: ', path )
                print('mismatch in one of the lat coords')
                sys.exit(1)
            if not np.array_equal( lon, lon_out ):
                print('file # ', ifil, ' name: ', path )
                print('mismatch in one of the lon coords')
                sys.exit(1)
        #  assume the incoming list of times is for 1 day, in hourly order
        time_out[ifil] = (gran_time-starttime).total_seconds()/60
        out_tau[ifil] = _require_var(asm_fid, 'TAUTOT', path)[0]
        out_cld[ifil] = _require_var(asm_fid, 'CLDTOT', path)[0]
        # Keep only the last file open: its variables and global attributes
        # are used to build the output file; close the rest (no handle leak)
        if ifil < nfiles - 1:
            asm_fid.close()

    time_slice_start = starttime + timedelta(minutes=int(time_out[0]))
    time_slice_end = starttime + timedelta(minutes=int(time_out[nfiles-1]))
    if not args.ofile:
        ofile = out_start_id + '.' + time_slice_start.strftime("%Y%m%d")+'.RAD.NRT.nc'
    dsout = Dataset(ofile, "w")
    # the output 'time' dimension must hold one entry per input file
    _copy_dims(dsout, asm_fid, ntime=nfiles)
    _write_lat_lon(dsout, asm_fid, files[-1])
    _write_time(dsout, asm_fid.variables['time'], first_t_units, time_out)
    #  write out the tau and cloud
    stacked = {'TAUTOT': out_tau, 'CLDTOT': out_cld}
    for v_name in asm_vars:
        _write_field(dsout, asm_fid.variables[v_name], v_name, chunk, stacked[v_name])

else:
    # error, we only do a MERRA2 with all 24 or a set of 24 single-time
    # GEOS-IT files
    print('Number of input files to process needs to be 1 for MERRA2 or 24 for GEOS-IT')
    print(nfiles, ' Entered' )
    sys.exit(1)
#
# Add OB.DAAC metadata
# date_created is stamped in UTC (ZULU).  time.gmtime() yields UTC directly,
# so there is no need to override the process-wide TZ setting (and
# time.tzset() is not available on every platform).
dsout.date_created = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()).encode('ascii')

# history: this command line, the input file list (when one was used), then
# the History attribute of the input file still open in asm_fid
history = "{} -asm {} ".format(parser.prog,asm)
if args.ofile:
    history += " -ofile {}".format(ofile)
if file_list == 1:
    history += " file list used: "
    history += "".join(" " + os.path.basename(name) for name in files[:nfiles])
history += "; {}".format(asm_fid.getncattr('History'))

dsout.title=title.encode("ascii")
dsout.product_name=ofile.encode("ascii")
dsout.history=history.encode("ascii")
dsout.time_coverage_start=(time_slice_start.strftime("%Y-%m-%dT%H:%M:%SZ")).encode("ascii")
# coverage ends at the end of the hour of the last time step
dsout.time_coverage_end=(time_slice_end.strftime("%Y-%m-%dT%H:59:59Z")).encode("ascii")
dsout.geospatial_lat_max=float(asm_fid.getncattr('NorthernmostLatitude'))
dsout.geospatial_lat_min=float(asm_fid.getncattr('SouthernmostLatitude'))
dsout.geospatial_lon_max=float(asm_fid.getncattr('EasternmostLongitude'))
dsout.geospatial_lon_min=float(asm_fid.getncattr('WesternmostLongitude'))
dsout.geospatial_lat_resolution=float(asm_fid.getncattr('LatitudeResolution'))
dsout.geospatial_lon_resolution=float(asm_fid.getncattr('LongitudeResolution'))
dsout.geospatial_lat_units='degrees_north'.encode("ascii")
dsout.geospatial_lon_units='degrees_east'.encode("ascii")
dsout.grid_mapping_name='latitude_longitude'.encode("ascii")
dsout.software_version=(asm_fid.getncattr('Source')).encode("ascii")
dsout.Conventions='CF-1.6'.encode("ascii")
dsout.Metadata_Conventions='Unidata Dataset Discovery v1.0'.encode("ascii")
dsout.license='https://science.nasa.gov/earth-science/earth-science-data/data-information-policy/'.encode("ascii")
dsout.naming_authority='gov.nasa.gsfc.sci.oceandata'.encode("ascii")
dsout.id=(asm_fid.getncattr('VersionID') + '/L4/' + ofile).encode("ascii")
dsout.keywords_vocabulary='NASA Global Change Master Directory (GCMD) Science Keywords'.encode("ascii")
# NOTE(review): these keywords describe geopotential height, but this product
# carries TAUTOT / CLDTOT (optical thickness / cloud amount) -- confirm the
# intended GCMD keywords
dsout.keywords='ATMOSPHERE>ALTITUDE>GEOPOTENTIAL HEIGHT '.encode("ascii")
dsout.standard_name_vocabulary='NetCDF Climate and Forecast (CF) Metadata Convention'.encode("ascii")
dsout.institution='NASA Goddard Space Flight Center, Ocean Ecology Laboratory, Ocean Biology Processing Group'.encode("ascii")
dsout.creator_name='NASA Global Modeling and Assimilation Office'.encode("ascii")
dsout.creator_email='data@gmao.gsfc.nasa.gov'.encode("ascii")
dsout.creator_url='https://gmao.gsfc.nasa.gov'.encode("ascii")
dsout.project='Ocean Biology Processing Group (NASA/GSFC/OBPG)'.encode("ascii")
dsout.publisher_name='NASA/GSFC/OBPG'.encode("ascii")
dsout.publisher_url='https://oceandata.sci.gsfc.nasa.gov'.encode("ascii")
dsout.publisher_email='data@oceancolor.gsfc.nasa.gov'.encode("ascii")
# only the MERRA2 input carries a product DOI to pass along
if gmao_typ == merra2_typ:
    dsout.identifier_product_doi_authority='https://dx.doi.org'.encode("ascii")
    dsout.identifier_product_doi= (asm_fid.getncattr('identifier_product_doi')).encode("ascii")
dsout.processing_level='L4'.encode("ascii")
dsout.cdm_data_type='grid'.encode("ascii")
# the resolution attributes are concatenated as text here
dsout.spatialResolution = (asm_fid.getncattr('LatitudeResolution')
                    + 'x' +asm_fid.getncattr('LongitudeResolution') + ' degrees').encode("ascii")
dsout.source=("{}".format(asm)).encode("ascii")

dsout.comment=(
  "This file contains a subset of variables generated using the GMAO GEOS model (v{}) for use as input to the OBPG Level 2 data processing stream".format(asm_fid.getncattr('VersionID')))

# close the output file, then the input file still held open
dsout.close()
print("Created:\t{}".format(ofile))

asm_fid.close()
