#!/usr/bin/env python3

from datetime import datetime, timedelta
import time
import numpy as np
from netCDF4 import Dataset, stringtochar
import argparse
import os
import sys

# Command-line interface: three required GMAO inputs (ASM, LND, OCN), an
# optional explicit output name and an optional listing file.
# The description no longer claims FP-IT only, because the product type is
# detected from the ASM file's Title attribute later in the script.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description='''\
  This program takes three GMAO files (ASM, LND and OCN) and merges them into
  a single file for use by l2gen ''',
    add_help=True)
parser.add_argument('-asm', nargs=1, type=str, required=True, help=' input ASM file ')
parser.add_argument('-lnd', nargs=1, type=str, required=True, help=' input LND file ')
parser.add_argument('-ocn', nargs=1, type=str, required=True, help=' input OCN file ')
parser.add_argument('-ofile', nargs=1, type=str, help='''
 output file name; if not provided an output filename will be autogenerated
 based on ASM date/time metadata''')
parser.add_argument('-outlist', nargs=1, type=str, help='''
 filename for an optional output file that contains the name(s) of the output
 file(s) along with the corresponding start/mid/stop times''')

args = parser.parse_args()

# nargs=1 stores every option as a one-element list; unwrap the required
# input paths once so the rest of the script can use plain strings.
asm = args.asm[0]
lnd = args.lnd[0]
ocn = args.ocn[0]

# Variables taken from each input product, and the per-variable attributes
# that are carried over to the merged output.
asm_vars = [
    'PS', 'QV10M', 'SLP', 'T10M',
    'TO3', 'TQV', 'U10M', 'V10M',
]
lnd_vars = ['FRSNO']
ocn_vars = ['FRSEAICE']
var_attrs = [
    'long_name', 'units', 'standard_name', '_FillValue',
    'missing_value', 'scale_factor', 'add_offset',
]

# Open the three inputs read-only, in ASM / LND / OCN order.
asm_fid, lnd_fid, ocn_fid = (Dataset(path, 'r') for path in (asm, lnd, ocn))

# Reference time of the ASM granule, built from its global attributes.
sdate = asm_fid.getncattr('RangeBeginningDate')
stime = asm_fid.getncattr('RangeBeginningTime')
starttime = datetime.strptime(' '.join((sdate, stime)), "%Y-%m-%d %H:%M:%S.%f")

# Product-type codes and the defaults used until the ASM title is recognised.
merra2_typ, fpit_typ, it_typ = 0, 1, 2
title = None
gmao_typ = -1
out_start_id = 'UNKNOWN'

# (title marker, type code, output prefix, output title); first match wins,
# so the order below is significant.
_gmao_products = (
    ('MERRA', merra2_typ, 'GMAO_MERRA2',
     'GMAO MERRA2 2D Hourly,  Instantaneous, Single-Level, Assimilation Data'),
    ('GEOS-IT', it_typ, 'GMAO_IT',
     'GMAO GEOS-5 GEOS-IT 2D 1-Hourly,  Instantaneous, Single-Level, Assimilation Data'),
    ('FP-IT', fpit_typ, 'GMAO_FP',
     'GMAO GEOS-5 FP-IT 2D 3-Hourly,  Instantaneous, Single-Level, Assimilation Data'),
)
_asm_title = asm_fid.getncattr('Title')
for _marker, _typ, _prefix, _out_title in _gmao_products:
    if _marker in _asm_title:
        title, gmao_typ, out_start_id = _out_title, _typ, _prefix
        break
else:
    # No known marker found in the ASM title: give up.
    print('Unknown GMAO type encountered from title: ', _asm_title)
    sys.exit(1)

ofile = None
# Optional listing of the generated files; stays None when -outlist is absent.
outlist = open(args.outlist[0], 'w') if args.outlist else None

# Dimensions and time axis of the ASM file drive the per-timestep output.
time_dim, lat_dim, lon_dim = (
    asm_fid.dimensions.get(dim_name) for dim_name in ('time', 'lat', 'lon')
)
file_times = asm_fid.variables.get('time')
ntime_steps = time_dim.size

def _require_variable(src_fid, v_name, src_path):
    """Return variable *v_name* from *src_fid*, or exit naming *src_path*.

    The lookup result is compared against None explicitly, so the check
    reports only a genuinely missing variable and does not depend on how the
    variable object evaluates in a boolean context.
    """
    varin = src_fid.variables.get(v_name)
    if varin is None:
        sys.exit("{} not found in {}\n".format(v_name, src_path))
    return varin


def _copy_coordinate(dsout, src_fid, src_path, name, long_name, units,
                     standard_name, valid_min, valid_max):
    """Copy the 1-D coordinate *name* into *dsout* with fixed CF metadata."""
    varin = _require_variable(src_fid, name, src_path)
    # dimensions is passed as a real one-element tuple (not a bare string).
    outVar = dsout.createVariable(name, varin.datatype, dimensions=(name,),
                                  contiguous=False, zlib=True,
                                  complevel=4, shuffle=True)
    outVar.long_name = long_name.encode('ascii')
    outVar.units = units.encode('ascii')
    outVar.standard_name = standard_name.encode('ascii')
    outVar.valid_min = float(valid_min)
    outVar.valid_max = float(valid_max)
    outVar[:] = varin[:]


def _copy_fields(dsout, src_fid, src_path, v_names, attr_names, timestep, chunk):
    """Copy one time slice of each variable in *v_names* into *dsout*.

    Every attribute listed in *attr_names* is copied from the source
    variable; a missing attribute raises, as it did before.
    """
    for v_name in v_names:
        varin = _require_variable(src_fid, v_name, src_path)
        outVar = dsout.createVariable(v_name, varin.datatype,
                                      dimensions=('lat', 'lon'),
                                      contiguous=False, chunksizes=chunk,
                                      zlib=True, complevel=4, shuffle=True)
        # Attributes are set before the data is written.
        outVar.setncatts({k: varin.getncattr(k) for k in attr_names})
        outVar[:] = varin[timestep]


# One chunk spans the whole lat/lon grid; this does not change per timestep.
chunk = [len(lat_dim), len(lon_dim)]

# Inputs in the order their variables are written to the output file.
# NOTE(review): the same time index is applied to all three files, which
# assumes LND and OCN share the ASM time axis -- confirm.
_sources = (
    (asm_fid, asm, asm_vars),
    (lnd_fid, lnd, lnd_vars),
    (ocn_fid, ocn, ocn_vars),
)

# Write one output file per time step of the ASM file.
for timestep in range(ntime_steps):
    # The time value is applied as an offset in minutes from starttime.
    minutes_of_day = int(file_times[timestep])
    time_slice_start = starttime + timedelta(minutes=minutes_of_day)

    # determine output file
    if not args.ofile:
        suffix = '.MET.nc' if gmao_typ == merra2_typ else '.MET.NRT.nc'
        ofile = out_start_id + '.' + time_slice_start.strftime("%Y%m%dT%H0000") + suffix
    elif ntime_steps > 1:
        # A single user-supplied name cannot serve several time steps.
        ofile = args.ofile[0] + '_time_step_' + str(timestep)
    else:
        ofile = args.ofile[0]

    # The context manager closes the output file even if a step below fails.
    with Dataset(ofile, "w") as dsout:

        if outlist:
            outlist.write('{} {}\n'.format(ofile, time_slice_start.strftime("%Y-%m-%dT%H:%M:%S")))

        # Copy dimensions from the asm file (the time dimension is dropped
        # because each output file holds a single time slice)
        for dname, the_dim in asm_fid.dimensions.items():
            if dname != 'time':
                dsout.createDimension(dname, len(the_dim) if not the_dim.isunlimited() else None)

        # Copy the lat and lon variables from the asm file
        _copy_coordinate(dsout, asm_fid, asm, 'lat', "Latitude",
                         "degrees_north", "latitude", -90., 90.)
        # Longitude is measured in degrees_east (was degrees_north).
        _copy_coordinate(dsout, asm_fid, asm, 'lon', "Longitude",
                         "degrees_east", "longitude", -180., 180.)

        # Copy asm, lnd and ocn variables for this time step
        for src_fid, src_path, v_names in _sources:
            _copy_fields(dsout, src_fid, src_path, v_names, var_attrs,
                         timestep, chunk)

        # Get variables with scale_factor attribute set to 1.0 and remove
        # the packing attributes if add_offset is 0
        for var in dsout.get_variables_by_attributes(scale_factor=1.0):
            if var.getncattr('add_offset') == 0.0:
                del var.scale_factor
                del var.add_offset

        # Add OB.DAAC metadata
        # date_created is formatted straight from UTC, without altering the
        # process time zone.
        dsout.date_created = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()).encode('ascii')

        history = "{} -asm {} -lnd {} -ocn {}".format(parser.prog, asm, lnd, ocn)
        if args.ofile:
            history += " -ofile {}".format(ofile)
        history += "; {}".format(asm_fid.getncattr('History'))

        # Start and end are identical: each file is one instantaneous slice.
        coverage = time_slice_start.strftime("%Y-%m-%dT%H:%M:%SZ").encode("ascii")
        version_id = asm_fid.getncattr('VersionID')
        lat_res = asm_fid.getncattr('LatitudeResolution')
        lon_res = asm_fid.getncattr('LongitudeResolution')

        dsout.title = title.encode("ascii")
        dsout.product_name = ofile.encode("ascii")
        dsout.history = history.encode("ascii")
        dsout.time_coverage_start = coverage
        dsout.time_coverage_end = coverage
        dsout.geospatial_lat_max = float(asm_fid.getncattr('NorthernmostLatitude'))
        dsout.geospatial_lat_min = float(asm_fid.getncattr('SouthernmostLatitude'))
        dsout.geospatial_lon_max = float(asm_fid.getncattr('EasternmostLongitude'))
        dsout.geospatial_lon_min = float(asm_fid.getncattr('WesternmostLongitude'))
        dsout.geospatial_lat_resolution = float(lat_res)
        dsout.geospatial_lon_resolution = float(lon_res)
        dsout.geospatial_lat_units = 'degrees_north'.encode("ascii")
        dsout.geospatial_lon_units = 'degrees_east'.encode("ascii")
        dsout.grid_mapping_name = 'latitude_longitude'.encode("ascii")
        dsout.software_version = (asm_fid.getncattr('Source')).encode("ascii")
        dsout.Conventions = 'CF-1.6'.encode("ascii")
        dsout.Metadata_Conventions = 'Unidata Dataset Discovery v1.0'.encode("ascii")
        dsout.license = 'https://science.nasa.gov/earth-science/earth-science-data/data-information-policy/'.encode("ascii")
        dsout.naming_authority = 'gov.nasa.gsfc.sci.oceandata'.encode("ascii")
        dsout.id = (version_id + '/L4/' + ofile).encode("ascii")
        dsout.keywords_vocabulary = 'NASA Global Change Master Directory (GCMD) Science Keywords'.encode("ascii")
        dsout.keywords = 'ATMOSPHERE>ATMOSPHERIC CHEMISTRY>OXYGEN COMPOUNDS>OZONE;ATMOSPHERE>ATMOSPHERIC PRESSURE>SURFACE PRESSURE'.encode("ascii")
        dsout.standard_name_vocabulary = 'NetCDF Climate and Forecast (CF) Metadata Convention'.encode("ascii")
        dsout.institution = 'NASA Goddard Space Flight Center, Ocean Ecology Laboratory, Ocean Biology Processing Group'.encode("ascii")
        dsout.creator_name = 'NASA Global Modeling and Assimilation Office'.encode("ascii")
        dsout.creator_email = 'data@gmao.gsfc.nasa.gov'.encode("ascii")
        dsout.creator_url = 'https://gmao.gsfc.nasa.gov'.encode("ascii")
        dsout.project = 'Ocean Biology Processing Group (NASA/GSFC/OBPG)'.encode("ascii")
        dsout.publisher_name = 'NASA/GSFC/OBPG'.encode("ascii")
        dsout.publisher_url = 'https://oceandata.sci.gsfc.nasa.gov'.encode("ascii")
        dsout.publisher_email = 'data@oceancolor.gsfc.nasa.gov'.encode("ascii")
        if gmao_typ == merra2_typ:
            # DOIs are only read for MERRA2 inputs, one per source file.
            dsout.identifier_product_doi_authority = 'https://dx.doi.org'.encode("ascii")
            dsout.identifier_product_doi = (
                asm_fid.getncattr('identifier_product_doi')
                + ';' + lnd_fid.getncattr('identifier_product_doi')
                + ';' + ocn_fid.getncattr('identifier_product_doi')).encode("ascii")
        dsout.processing_level = 'L4'.encode("ascii")
        dsout.cdm_data_type = 'grid'.encode("ascii")
        dsout.spatialResolution = (lat_res + 'x' + lon_res + ' degrees').encode("ascii")
        dsout.source = ("{},{},{}".format(asm, lnd, ocn)).encode("ascii")

        dsout.comment = (
            "This file contains a subset of variables generated using the GMAO GEOS model (v{}) for use as input to the OBPG Level 2 data processing stream".format(version_id))

    # The output file has been closed by the with-block at this point.
    print("Created:\t{}".format(ofile))

# Release the optional listing file first, then the three input datasets
# in the order they were opened.
if outlist:
    outlist.close()

for _input_fid in (asm_fid, lnd_fid, ocn_fid):
    _input_fid.close()
