source: TOOLS/MOSAIX/nemo.py @ 6446

Last change on this file since 6446 was 6245, checked in by omamce, 21 months ago

O.M. : MOSAIX, new functionnalitie in nemo.py

  • Property svn:keywords set to Date Revision HeadURL Author Id
File size: 67.9 KB
Line 
1# -*- coding: utf-8 -*-
2## ===========================================================================
3##
4##  This software is governed by the CeCILL  license under French law and
5##  abiding by the rules of distribution of free software.  You can  use,
6##  modify and/ or redistribute the software under the terms of the CeCILL
7##  license as circulated by CEA, CNRS and INRIA at the following URL
8##  "http://www.cecill.info".
9##
10##  Warning, to install, configure, run, use any of Olivier Marti's
11##  software or to read the associated documentation you'll need at least
12##  one (1) brain in a reasonably working order. Lack of this implement
13##  will void any warranties (either express or implied).
14##  O. Marti assumes no responsability for errors, omissions,
15##  data loss, or any other consequences caused directly or indirectly by
16##  the usage of his software by incorrectly or partially configured
17##  personal.
18##
19## ===========================================================================
20'''
21Utilities to plot NEMO ORCA fields
22Periodicity and other stuff
23
24olivier.marti@lsce.ipsl.fr
25'''
26
27## SVN information
28__Author__   = "$Author$"
29__Date__     = "$Date$"
30__Revision__ = "$Revision$"
31__Id__       = "$Id$"
32__HeadURL    = "$HeadURL$"
33
34import numpy as np
35try    : import xarray as xr
36except ImportError : pass
37
38try    : import f90nml
39except : pass
40
41try : from sklearn.impute import SimpleImputer
42except : pass
43
44try    : import numba
45except : pass
46
47rpi = np.pi ; rad = np.deg2rad (1.0) ; dar = np.rad2deg (1.0)
48
49nperio_valid_range = [0, 1, 4, 4.2, 5, 6, 6.2]
50
51rday   = 24.*60.*60.     # Day length [s]
52rsiyea = 365.25 * rday * 2. * rpi / 6.283076 # Sideral year length [s]
53rsiday = rday / (1. + rday / rsiyea)
54raamo  =  12.        # Number of months in one year
55rjjhh  =  24.        # Number of hours in one day
56rhhmm  =  60.        # Number of minutes in one hour
57rmmss  =  60.        # Number of seconds in one minute
58omega  = 2. * rpi / rsiday # Earth rotation parameter [s-1]
59ra     = 6371229.    # Earth radius [m]
60grav   = 9.80665     # Gravity [m/s2]
61repsi  = np.finfo (1.0).eps
62
63xList = [ 'x', 'X', 'lon'   , 'longitude' ]
64yList = [ 'y', 'Y', 'lat'   , 'latitude'  ]
65zList = [ 'z', 'Z', 'depth' , ]
66tList = [ 't', 'T', 'time'  , ]
67
68## ===========================================================================
69def __mmath__ (tab, default=None) :
70    mmath = default
71    try    :
72        if type (tab) == xr.core.dataarray.DataArray : mmath = xr
73    except :
74        pass
75
76    try    :
77        if type (tab) == np.ndarray : mmath = np
78    except :
79        pass
80           
81    return mmath
82
83
84def __guessNperio__ (jpj, jpi, nperio=None, out='nperio') :
85    '''
86    Tries to guess the value of nperio (periodicity parameter. See NEMO documentation for details)
87   
88    Inputs
89    jpj    : number of latitudes
90    jpi    : number of longitudes
91    nperio : periodicity parameter
92    '''
93    if nperio == None :
94        nperio = __guessConfig__ (jpj, jpi, nperio=None, out='nperio')
95   
96    return nperio
97
98def __guessConfig__ (jpj, jpi, nperio=None, config=None, out='nperio') :
99    '''
100    Tries to guess the value of nperio (periodicity parameter. See NEMO documentation for details)
101   
102    Inputs
103    jpj    : number of latitudes
104    jpi    : number of longitudes
105    nperio : periodicity parameter
106    '''
107    if nperio == None :
108        ## Values for NEMO version < 4.2
109        if jpj ==  149 and jpi ==  182 :
110            config = 'ORCA2.3'
111            nperio = 4   # ORCA2. We choose legacy orca2.
112            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'T'
113        if jpj ==  332 and jpi ==  362 : # eORCA1.
114            config = 'eORCA1.2'
115            nperio = 6 
116            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
117        if jpi == 1442 :  # ORCA025.
118            config = 'ORCA025'
119            nperio = 6 
120            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
121        if jpj ==  294 : # ORCA1
122            config = 'ORCA1'
123            nperio = 6
124            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
125           
126        ## Values for NEMO version >= 4.2. No more halo points
127        if jpj == 148 and jpi ==  180 :
128            config = 'ORCA2.4'
129            nperio = 4.2 # ORCA2. We choose legacy orca2.
130            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
131        if jpj == 331  and jpi ==  360 : # eORCA1.
132            config = 'eORCA1.4'
133            nperio = 6.2
134            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
135        if jpi == 1440 : # ORCA025.
136            config = 'ORCA025'
137            nperio = 6.2
138            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
139           
140        if nperio == None :
141            raise Exception  ('in nemo module : nperio not found, and cannot by guessed')
142        else :
143            if nperio in nperio_valid_range :
144                print ('nperio set as {:} (deduced from jpj={:d} jpi={:d})'.format (nperio, jpj, jpi))
145            else : 
146                raise ValueError ('nperio set as {:} (deduced from jpi={:d}) : nemo.py is not ready for this value'.format (nperio, jpi))
147
148    if out == 'nperio' : return nperio
149    if out == 'config' : return config
150    if out == 'perio'  : return Iperio, Jperio, NFold, NFtype
151    if out in ['full', 'all'] : return {'nperio':nperio, 'Iperio':Iperio, 'Jperio':Jperio, 'NFold':NFold, 'NFtype':NFtype}
152       
153def __guessPoint__ (ptab) :
154    '''
155    Tries to guess the grid point (periodicity parameter. See NEMO documentation for details)
156   
157    For array conforments with xgcm requirements
158
159    Inputs
160         ptab : xarray array
161
162    Credits : who is the original author ?
163    '''
164    gP = None
165    mmath = __mmath__ (ptab)
166    if mmath == xr :
167        if 'x_c' in ptab.dims and 'y_c' in ptab.dims                        : gP = 'T'
168        if 'x_f' in ptab.dims and 'y_c' in ptab.dims                        : gP = 'U'
169        if 'x_c' in ptab.dims and 'y_f' in ptab.dims                        : gP = 'V'
170        if 'x_f' in ptab.dims and 'y_f' in ptab.dims                        : gP = 'F'
171        if 'x_c' in ptab.dims and 'y_c' in ptab.dims and 'z_c' in ptab.dims : gP = 'T'
172        if 'x_c' in ptab.dims and 'y_c' in ptab.dims and 'z_f' in ptab.dims : gP = 'W'
173        if 'x_f' in ptab.dims and 'y_c' in ptab.dims and 'z_f' in ptab.dims : gP = 'U'
174        if 'x_c' in ptab.dims and 'y_f' in ptab.dims and 'z_f' in ptab.dims : gP = 'V'
175        if 'x_f' in ptab.dims and 'y_f' in ptab.dims and 'z_f' in ptab.dims : gP = 'F'
176             
177        if gP == None :
178            raise Exception ('in nemo module : cd_type not found, and cannot by guessed')
179        else :
180            print ('Grid set as', gP, 'deduced from dims ', ptab.dims)
181            return gP
182    else :
183         raise Exception  ('in nemo module : cd_type not found, input is not an xarray data')
184
185def lbc_diag (nperio) :
186    lperio = nperio ; aperio = False
187    if nperio == 4.2 :
188        lperio = 4 ; aperio = True
189    if nperio == 6.2 :
190        lperio = 6 ; aperio = True
191       
192    return lperio, aperio
193
194def __findAxis__ (tab, axis='z') :
195    '''
196    Find number and name of the requested axis
197    '''
198    mmath = __mmath__ (tab)
199    ix = None ; ax = None
200
201    if axis in xList :
202        axList = [ 'x', 'X',
203                   'lon', 'nav_lon', 'nav_lon_T', 'nav_lon_U', 'nav_lon_V', 'nav_lon_F', 'nav_lon_W',
204                   'x_grid_T', 'x_grid_U', 'x_grid_V', 'x_grid_F', 'x_grid_W',
205                   'glam', 'glamt', 'glamu', 'glamv', 'glamf', 'glamw' ]
206        unList = [ 'degrees_east' ]
207    if axis in yList :
208        axList = [ 'y', 'Y', 'lat',
209                   'nav_lat', 'nav_lat_T', 'nav_lat_U', 'nav_lat_V', 'nav_lat_F', 'nav_lat_W',
210                   'y_grid_T', 'y_grid_U', 'y_grid_V', 'y_grid_F', 'y_grid_W',
211                   'gphi', 'gphi', 'gphiu', 'gphiv', 'gphif', 'gphiw']
212        unList = [ 'degrees_north' ]
213    if axis in zList :
214        axList = [ 'z', 'Z',
215                   'depth', 'deptht', 'depthu', 'depthv', 'depthf', 'depthw',
216                   'olevel' ]
217        unList = [ 'm', 'meter' ]
218    if axis in tList :
219        axList = [ 't', 'T', 'time', 'time_counter' ]
220        unList = [ 'second', 'minute', 'hour', 'day', 'month' ]
221   
222    if mmath == xr :
223        for Name in axList :
224            try    :
225                ix = tab.dims.index (Name)
226                ax = Name
227            except : pass
228
229        for i, dim in enumerate (tab.dims) :
230            if 'units' in tab.coords[dim].attrs.keys() :
231                for name in unList :
232                    if name in tab.coords[dim].attrs['units'] :
233                        ix = i
234                        ax = dim
235    else :
236        if axis in xList : ix=-1
237        if axis in yList :
238            if len(tab.shape) >= 2 : ix=-2
239        if axis in zList :
240            if len(tab.shape) >= 3 : ix=-3
241        if axis in tList :
242            if len(tab.shape) >=3  : ix=-3
243            if len(tab.shape) >=4  : ix=-4
244       
245    return ix, ax
246
247#@numba.jit(forceobj=True)
248def fixed_lon (lon, center_lon=0.0) :
249    '''
250    Returns corrected longitudes for nicer plots
251
252    lon        : longitudes of the grid. At least 2D.
253    center_lon : center longitude. Default=0.
254
255    Designed by Phil Pelson. See https://gist.github.com/pelson/79cf31ef324774c97ae7
256    '''
257    mmath = __mmath__ (lon)
258   
259    fixed_lon = lon.copy ()
260       
261    fixed_lon = mmath.where (fixed_lon > center_lon+180., fixed_lon-360.0, fixed_lon)
262    fixed_lon = mmath.where (fixed_lon < center_lon-180., fixed_lon+360.0, fixed_lon)
263   
264    for i, start in enumerate (np.argmax (np.abs (np.diff (fixed_lon, axis=-1)) > 180., axis=-1)) :
265        fixed_lon [..., i, start+1:] += 360.
266
267    # Special case for eORCA025
268    if fixed_lon.shape [-1] == 1442 : fixed_lon [..., -2, :] = fixed_lon [..., -3, :]
269    if fixed_lon.shape [-1] == 1440 : fixed_lon [..., -1, :] = fixed_lon [..., -2, :]
270
271    if fixed_lon.min () > center_lon : fixed_lon += -360.0
272    if fixed_lon.max () < center_lon : fixed_lon +=  360.0
273       
274    if fixed_lon.min () < center_lon-360.0 : fixed_lon +=  360.0
275    if fixed_lon.max () > center_lon+360.0 : fixed_lon += -360.0
276               
277    return fixed_lon
278
279#@numba.jit(forceobj=True)
280def fill_empty (ztab, sval=np.nan, transpose=False) :
281    '''
282    Fill values
283
284    Useful when NEMO has run with no wet points options :
285    some parts of the domain, with no ocean points, has no
286    lon/lat values
287    '''
288    mmath = __mmath__ (ztab)
289
290    imp = SimpleImputer (missing_values=sval, strategy='mean')
291    if transpose :
292        imp.fit (ztab.T)
293        ptab = imp.transform (ztab.T).T
294    else : 
295        imp.fit (ztab)
296        ptab = imp.transform (ztab)
297   
298    if mmath == xr :
299        ptab = xr.DataArray (ptab, dims=ztab.dims, coords=ztab.coords)
300        ptab.attrs = ztab.attrs
301       
302    return ptab
303
304#@numba.jit(forceobj=True)
305def fill_lonlat (lon, lat, sval=-1) :
306    '''
307    Fill longitude/latitude values
308
309    Useful when NEMO has run with no wet points options :
310    some parts of the domain, with no ocean points, as no
311    lon/lat values
312    '''
313    mmath = __mmath__ (lon)
314
315    imp = SimpleImputer (missing_values=sval, strategy='mean')
316    imp.fit (lon)
317    plon = imp.transform (lon)
318    imp.fit (lat.T)
319    plat = imp.transform (lat.T).T
320
321    if mmath == xr :
322        plon = xr.DataArray (plon, dims=lon.dims, coords=lon.coords)
323        plat = xr.DataArray (plat, dims=lat.dims, coords=lat.coords)
324        plon.attrs = lon.attrs ; plat.attrs = lat.attrs
325       
326    plon = fixed_lon (plon)
327   
328    return plon, plat
329
330#@numba.jit(forceobj=True)
331def jeq (lat) :
332    '''
333    Returns j index of equator in the grid
334   
335    lat : latitudes of the grid. At least 2D.
336    '''
337    mmath = __mmath__ (lat)
338    ix, ax = __findAxis__ (lat, 'x')
339    iy, ay = __findAxis__ (lat, 'y')
340
341    if mmath == xr :
342        jeq = int ( np.mean ( np.argmin (np.abs (np.float64 (lat)), axis=iy) ) )
343    else : 
344        jeq = np.argmin (np.abs (np.float64 (lat[...,:, 0])))
345    return jeq
346
347#@numba.jit(forceobj=True)
348def lon1D (lon, lat=None) :
349    '''
350    Returns 1D longitude for simple plots.
351   
352    lon : longitudes of the grid
353    lat (optionnal) : latitudes of the grid
354    '''
355    mmath = __mmath__ (lon)
356    if np.max (lat) != None :
357        je    = jeq (lat)
358        lon1D = lon.copy() [..., je, :]
359    else :
360        jpj, jpi = lon.shape [-2:]
361        lon1D    = lon.copy() [..., jpj//3, :]
362
363    start = np.argmax (np.abs (np.diff (lon1D, axis=-1)) > 180.0, axis=-1)
364    lon1D [..., start+1:] += 360
365
366    if mmath == xr :
367        lon1D.attrs = lon.attrs
368        lon1D = lon1D.assign_coords ( {'x':lon1D} )
369       
370    return lon1D
371
372#@numba.jit(forceobj=True)
373def latreg (lat, diff=0.1) :
374    '''
375    Returns maximum j index where gridlines are along latitudes in the northern hemisphere
376   
377    lat : latitudes of the grid (2D)
378    diff [optional] : tolerance
379    '''
380    mmath = __mmath__ (lat)
381    if diff == None :
382        dy   = np.float64 (np.mean (np.abs (lat - np.roll (lat,shift=1,axis=-2, roll_coords=False))))
383        diff = dy/100.
384   
385    je     = jeq (lat)
386    jreg   = np.where (lat[...,je:,:].max(axis=-1) - lat[...,je:,:].min(axis=-1)< diff)[-1][-1] + je
387    latreg = np.float64 (lat[...,jreg,:].mean(axis=-1))
388    JREG   = jreg
389
390    return jreg, latreg
391
392#@numba.jit(forceobj=True)
393def lat1D (lat) :
394    '''
395    Returns 1D latitudes for zonal means and simple plots.
396
397    lat : latitudes of the grid (2D)
398    '''
399    mmath = __mmath__ (lat)
400    jpj, jpi = lat.shape[-2:]
401
402    dy     = np.float64 (np.mean (np.abs (lat - np.roll (lat, shift=1,axis=-2))))
403    je     = jeq (lat)
404    lat_eq = np.float64 (lat[...,je,:].mean(axis=-1))
405     
406    jreg, lat_reg = latreg (lat)
407    lat_ave = np.mean (lat, axis=-1)
408
409    if (np.abs (lat_eq) < dy/100.) : # T, U or W grid
410        dys    = (90.-lat_reg) / (jpj-jreg-1)*0.5
411        yrange = 90.-dys-lat_reg
412    else                           :  # V or F grid
413        yrange = 90.    -lat_reg
414       
415    lat1D = mmath.where (lat_ave<lat_reg, lat_ave, lat_reg + yrange * (np.arange(jpj)-jreg)/(jpj-jreg-1))   
416       
417    if mmath == xr :
418        lat1D.attrs = lat.attrs
419        lat1D = lat1D.assign_coords ( {'y':lat1D} )
420
421    return lat1D
422
423#@numba.jit(forceobj=True)
424def latlon1D (lat, lon) :
425    '''
426    Returns simple latitude and longitude (1D) for simple plots.
427
428    lat, lon : latitudes and longitudes of the grid (2D)
429    '''
430    return lat1D (lat),  lon1D (lon, lat)
431
432##@numba.jit(forceobj=True)
433def mask_lonlat (ptab, x0, x1, y0, y1, lon, lat, sval=np.nan) :
434    mmath = __mmath__ (ptab)
435    try :
436        lon = lon.copy().to_masked_array()
437        lat = lat.copy().to_masked_array()
438    except : pass
439           
440    mask = np.logical_and (np.logical_and(lat>y0, lat<y1), 
441            np.logical_or (np.logical_or (np.logical_and(lon>x0, lon<x1), np.logical_and(lon+360>x0, lon+360<x1)),
442                                      np.logical_and(lon-360>x0, lon-360<x1)))
443    tab = mmath.where (mask, ptab, np.nan)
444   
445    return tab
446
447#@numba.jit(forceobj=True)     
448def extend (tab, Lon=False, jplus=25, jpi=None, nperio=4) :
449    '''
450    Returns extended field eastward to have better plots, and box average crossing the boundary
451    Works only for xarray and numpy data (?)
452
453    tab : field to extend.
454    Lon : (optional, default=False) : if True, add 360 in the extended parts of the field
455    jpi : normal longitude dimension of the field. exrtend does nothing it the actual
456        size of the field != jpi (avoid to extend several times)
457    jplus (optional, default=25) : number of points added on the east side of the field
458   
459    '''
460    mmath = __mmath__ (tab)
461   
462    if tab.shape[-1] == 1 : extend = tab
463
464    else :
465        if jpi == None : jpi = tab.shape[-1]
466
467        if Lon : xplus = -360.0
468        else   : xplus =    0.0
469
470        if tab.shape[-1] > jpi :
471            extend = tab
472        else :
473            if nperio == 0 or nperio == 4.2 :
474                istart = 0 ; le=jpi+1 ; la=0
475            if nperio == 1 :
476                istart = 0 ; le=jpi+1 ; la=0
477            if nperio == 4 or nperio == 6 : # OPA case with two halo points for periodicity
478                istart = 1 ; le=jpi-2 ; la=1  # Perfect, except at the pole that should be masked by lbc_plot
479           
480            if mmath == xr :
481                extend = np.concatenate ((tab.values[..., istart   :istart+le+1    ] + xplus,
482                                          tab.values[..., istart+la:istart+la+jplus]         ), axis=-1)
483                lon    = tab.dims[-1]
484                new_coords = []
485                for coord in tab.dims :
486                    if coord == lon : new_coords.append ( np.arange( extend.shape[-1]))
487                    else            : new_coords.append ( tab.coords[coord].values)
488                extend = xr.DataArray ( extend, dims=tab.dims, coords=new_coords )
489            else : 
490                extend = np.concatenate ((tab [..., istart   :istart+le+1    ] + xplus,
491                                          tab [..., istart+la:istart+la+jplus]          ), axis=-1)
492    return extend
493
494def orca2reg (ff, lat_name='nav_lat', lon_name='nav_lon', y_name='y', x_name='x') :
495    '''
496    Assign an ORCA dataset on a regular grid.
497    For use in the tropical region.
498   
499    Inputs :
500      ff : xarray dataset
501      lat_name, lon_name : name of latitude and longitude 2D field in ff
502      y_name, x_name     : namex of dimensions in ff
503     
504      Returns : xarray dataset with rectangular grid. Incorrect above 20°N
505    '''
506    # Compute 1D longitude and latitude
507    (lat, lon) = latlon1D (ff[lat_name], ff[lon_name])
508
509    # Assign lon and lat as dimensions of the dataset
510    if y_name in ff.dims : 
511        lat = xr.DataArray (lat, coords=[lat,], dims=['lat',])     
512        ff  = ff.rename_dims ({y_name: "lat",}).assign_coords (lat=lat)
513    if x_name in ff.dims :
514        lon = xr.DataArray (lon, coords=[lon,], dims=['lon',])
515        ff  = ff.rename_dims ({x_name: "lon",}).assign_coords (lon=lon)
516    # Force dimensions to be in the right order
517    coord_order = ['lat', 'lon']
518    for dim in [ 'depthw', 'depthv', 'depthu', 'deptht', 'depth', 'z',
519                 'time_counter', 'time', 'tbnds', 
520                 'bnds', 'axis_nbounds', 'two2', 'two1', 'two', 'four',] :
521        if dim in ff.dims : coord_order.insert (0, dim)
522       
523    ff = ff.transpose (*coord_order)
524    return ff
525
526def lbc_init (ptab, nperio=None) :
527    '''
528    Prepare for all lbc calls
529   
530    Set periodicity on input field
531    nperio    : Type of periodicity
532      0       : No periodicity
533      1, 4, 6 : Cyclic on i dimension (generaly longitudes) with 2 points halo
534      2       : Obsolete (was symmetric condition at southern boundary ?)
535      3, 4    : North fold T-point pivot (legacy ORCA2)
536      5, 6    : North fold F-point pivot (ORCA1, ORCA025, ORCA2 with new grid for paleo)
537    cd_type   : Grid specification : T, U, V or F
538
539    See NEMO documentation for further details
540    '''
541    jpj, jpi = ptab.shape[-2:]
542    if nperio == None : nperio = __guessNperio__ (jpj, jpi, nperio)
543   
544    if nperio not in nperio_valid_range :
545        raise Exception ('nperio=', nperio, ' is not in the valid range', nperio_valid_range)
546
547    return jpj, jpi, nperio
548       
549#@numba.jit(forceobj=True)
550def lbc (ptab, nperio=None, cd_type='T', psgn=1.0, nemo_4U_bug=False) :
551    '''
552    Set periodicity on input field
553    ptab      : Input array (works for rank 2 at least : ptab[...., lat, lon])
554    nperio    : Type of periodicity
555    cd_type   : Grid specification : T, U, V or F
556    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
557   
558    See NEMO documentation for further details
559    '''
560    jpj, jpi, nperio = lbc_init (ptab, nperio)
561    psgn   = ptab.dtype.type (psgn)
562    mmath = __mmath__ (ptab)
563   
564    if mmath == xr : ztab = ptab.values.copy ()
565    else           : ztab = ptab.copy ()
566       
567    #
568    #> East-West boundary conditions
569    # ------------------------------
570    if nperio in [1, 4, 6] :
571        # ... cyclic
572        ztab [..., :,  0] = ztab [..., :, -2]
573        ztab [..., :, -1] = ztab [..., :,  1]
574    #
575    #> North-South boundary conditions
576    # --------------------------------
577    if nperio in [3, 4] :  # North fold T-point pivot
578        if cd_type in [ 'T', 'W' ] : # T-, W-point
579            ztab [..., -1, 1:       ] = psgn * ztab [..., -3, -1:0:-1      ]
580            ztab [..., -1, 0        ] = psgn * ztab [..., -3, 2            ]
581            ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2:0:-1  ]
582               
583        if cd_type == 'U' :
584            ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]       
585            ztab [..., -1,  0       ] = psgn * ztab [..., -3,  1           ]
586            ztab [..., -1, -1       ] = psgn * ztab [..., -3, -2           ]
587               
588            if nemo_4U_bug :
589                ztab [..., -2, jpi//2+1:-1] = psgn * ztab [..., -2, jpi//2-2:0:-1]
590                ztab [..., -2, jpi//2-1   ] = psgn * ztab [..., -2, jpi//2       ]
591            else :
592                ztab [..., -2, jpi//2-1:-1] = psgn * ztab [..., -2, jpi//2:0:-1]
593               
594        if cd_type == 'V' : 
595            ztab [..., -2, 1:       ] = psgn * ztab [..., -3, jpi-1:0:-1   ]
596            ztab [..., -1, 1:       ] = psgn * ztab [..., -4, -1:0:-1      ]   
597            ztab [..., -1, 0        ] = psgn * ztab [..., -4, 2            ]
598               
599        if cd_type == 'F' :
600            ztab [..., -2, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
601            ztab [..., -1, 0:-1     ] = psgn * ztab [..., -4, -1:0:-1      ]
602            ztab [..., -1,  0       ] = psgn * ztab [..., -4,  1           ]
603            ztab [..., -1, -1       ] = psgn * ztab [..., -4, -2           ]
604
605    if nperio in [4.2] :  # North fold T-point pivot
606        if cd_type in [ 'T', 'W' ] : # T-, W-point
607            ztab [..., -1, jpi//2:  ] = psgn * ztab [..., -1, jpi//2:0:-1  ]
608               
609        if cd_type == 'U' :
610            ztab [..., -1, jpi//2-1:-1] = psgn * ztab [..., -1, jpi//2:0:-1]
611               
612        if cd_type == 'V' : 
613            ztab [..., -1, 1:       ] = psgn * ztab [..., -2, jpi-1:0:-1   ]
614               
615        if cd_type == 'F' :
616            ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -1:0:-1      ]
617
618    if nperio in [5, 6] :            #  North fold F-point pivot 
619        if cd_type in ['T', 'W']  :
620            ztab [..., -1, 0:       ] = psgn * ztab [..., -2, -1::-1       ]
621               
622        if cd_type == 'U' :
623            ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -2::-1       ]       
624            ztab [..., -1, -1       ] = psgn * ztab [..., -2, 0            ] # Bug ?
625               
626        if cd_type == 'V' :
627            ztab [..., -1, 0:       ] = psgn * ztab [..., -3, -1::-1       ]
628            ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2-1::-1 ]
629               
630        if cd_type == 'F' :
631            ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -2::-1       ]
632            ztab [..., -1, -1       ] = psgn * ztab [..., -3, 0            ]
633            ztab [..., -2, jpi//2:-1] = psgn * ztab [..., -2, jpi//2-2::-1 ]
634
635    #
636    #> East-West boundary conditions
637    # ------------------------------
638    if nperio in [1, 4, 6] :
639        # ... cyclic
640        ztab [..., :,  0] = ztab [..., :, -2]
641        ztab [..., :, -1] = ztab [..., :,  1]
642
643    if mmath == xr :
644        ztab = xr.DataArray ( ztab, dims=ptab.dims, coords=ptab.coords)
645        ztab.attrs = ptab.attrs
646       
647    return ztab
648
649#@numba.jit(forceobj=True)
650def lbc_mask (ptab, nperio=None, cd_type='T', sval=np.nan) :
651    #
652    '''
653    Mask fields on duplicated points
654    ptab      : Input array. Rank 2 at least : ptab [...., lat, lon]
655    nperio    : Type of periodicity
656    cd_type   : Grid specification : T, U, V or F
657   
658    See NEMO documentation for further details
659    '''
660    jpj, jpi, nperio = lbc_init (ptab, nperio)
661    ztab = ptab.copy ()
662
663    #
664    #> East-West boundary conditions
665    # ------------------------------
666    if nperio in [1, 4, 6] :
667        # ... cyclic
668        ztab [..., :,  0] = sval
669        ztab [..., :, -1] = sval
670
671    #
672    #> South (in which nperio cases ?)
673    # --------------------------------
674    if nperio in [1, 3, 4, 5, 6] :
675        ztab [..., 0, :] = sval
676       
677    #
678    #> North-South boundary conditions
679    # --------------------------------
680    if nperio in [3, 4] :  # North fold T-point pivot
681        if cd_type in [ 'T', 'W' ] : # T-, W-point
682            ztab [..., -1,  :         ] = sval
683            ztab [..., -2, :jpi//2  ] = sval
684               
685        if cd_type == 'U' :
686            ztab [..., -1,  :         ] = sval 
687            ztab [..., -2, jpi//2+1:  ] = sval
688               
689        if cd_type == 'V' :
690            ztab [..., -2, :       ] = sval
691            ztab [..., -1, :       ] = sval   
692               
693        if cd_type == 'F' :
694            ztab [..., -2, :       ] = sval
695            ztab [..., -1, :       ] = sval
696
697    if nperio in [4.2] :  # North fold T-point pivot
698        if cd_type in [ 'T', 'W' ] : # T-, W-point
699            ztab [..., -1, jpi//2  :  ] = sval
700               
701        if cd_type == 'U' :
702            ztab [..., -1, jpi//2-1:-1] = sval
703               
704        if cd_type == 'V' : 
705            ztab [..., -1, 1:       ] = sval
706               
707        if cd_type == 'F' :
708            ztab [..., -1, 0:-1     ] = sval
709   
710    if nperio in [5, 6] :            #  North fold F-point pivot
711        if cd_type in ['T', 'W']  :
712            ztab [..., -1, 0:       ] = sval
713               
714        if cd_type == 'U' :
715            ztab [..., -1, 0:-1     ] = sval       
716            ztab [..., -1, -1       ] = sval
717             
718        if cd_type == 'V' :
719            ztab [..., -1, 0:       ] = sval
720            ztab [..., -2, jpi//2:  ] = sval
721                             
722        if cd_type == 'F' :
723            ztab [..., -1, 0:-1       ] = sval
724            ztab [..., -1, -1         ] = sval
725            ztab [..., -2, jpi//2+1:-1] = sval
726
727    return ztab
728
729#@numba.jit(forceobj=True)
730def lbc_plot (ptab, nperio=None, cd_type='T', psgn=1.0, sval=np.nan) :
731    '''
732    Set periodicity on input field, adapted for plotting for any cartopy projection
733    ptab      : Input array. Rank 2 at least : ptab[...., lat, lon]
734    nperio    : Type of periodicity
735    cd_type   : Grid specification : T, U, V or F
736    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
737   
738    See NEMO documentation for further details
739    '''
740
741    jpj, jpi, nperio = lbc_init (ptab, nperio)
742    psgn   = ptab.dtype.type (psgn)
743    ztab   = ptab.copy ()
744    #
745    #> East-West boundary conditions
746    # ------------------------------
747    if nperio in [1, 4, 6] :
748        # ... cyclic
749        ztab [..., :,  0] = ztab [..., :, -2]
750        ztab [..., :, -1] = ztab [..., :,  1]
751
752    #> Masks south
753    # ------------
754    if nperio in [4, 6] : ztab [..., 0, : ] = sval
755       
756    #
757    #> North-South boundary conditions
758    # --------------------------------
759    if nperio in [3, 4] :  # North fold T-point pivot
760        if cd_type in [ 'T', 'W' ] : # T-, W-point
761            ztab [..., -1,  :      ] = sval
762            #ztab [..., -2, jpi//2: ] = sval
763            ztab [..., -2, :jpi//2 ] = sval # Give better plots than above
764        if cd_type == 'U' :
765            ztab [..., -1, : ] = sval
766
767        if cd_type == 'V' : 
768            ztab [..., -2, : ] = sval
769            ztab [..., -1, : ] = sval
770           
771        if cd_type == 'F' :
772            ztab [..., -2, : ] = sval
773            ztab [..., -1, : ] = sval
774
775    if nperio in [4.2] :  # North fold T-point pivot
776        if cd_type in [ 'T', 'W' ] : # T-, W-point
777            ztab [..., -1, jpi//2:  ] = sval
778               
779        if cd_type == 'U' :
780            ztab [..., -1, jpi//2-1:-1] = sval
781               
782        if cd_type == 'V' : 
783            ztab [..., -1, 1:       ] = sval
784               
785        if cd_type == 'F' :
786            ztab [..., -1, 0:-1     ] = sval
787     
788    if nperio in [5, 6] :            #  North fold F-point pivot 
789        if cd_type in ['T', 'W']  :
790            ztab [..., -1, : ] = sval
791               
792        if cd_type == 'U' :
793            ztab [..., -1, : ] = sval     
794             
795        if cd_type == 'V' :
796            ztab [..., -1, :        ] = sval
797            ztab [..., -2, jpi//2:  ] = sval
798                             
799        if cd_type == 'F' :
800            ztab [..., -1, :          ] = sval
801            ztab [..., -2, jpi//2+1:-1] = sval
802
803    return ztab
804
805#@numba.jit(forceobj=True)
806def lbc_add (ptab, nperio=None, cd_type=None, psgn=1, sval=None) :
807    '''
808    Handle NEMO domain changes between NEMO 4.0 to NEMO 4.2
809      Peridodicity halo has been removed
810    This routine adds the halos if needed
811
812    ptab      : Input array (works
813      rank 2 at least : ptab[...., lat, lon]
814    nperio    : Type of periodicity
815 
816    See NEMO documentation for further details
817    '''
818    mmath = __mmath__ (ptab) 
819    jpj, jpi, nperio = lbc_init (ptab, nperio)
820
821    t_shape = np.array (ptab.shape)
822
823    if nperio == 4.2 or nperio == 6.2 :
824     
825        ext_shape = t_shape
826        ext_shape[-1] = ext_shape[-1] + 2
827        ext_shape[-2] = ext_shape[-2] + 1
828
829        if mmath == xr :
830            ptab_ext = xr.DataArray (np.zeros (ext_shape), dims=ptab.dims) 
831            ptab_ext.values[..., :-1, 1:-1] = ptab.values.copy ()
832        else           :
833            ptab_ext =               np.zeros (ext_shape)
834            ptab_ext[..., :-1, 1:-1] = ptab.copy ()
835           
836        #if sval != None :  ptab_ext[..., 0, :] = sval
837       
838        if nperio == 4.2 : ptab_ext = lbc (ptab_ext, nperio=4, cd_type=cd_type, psgn=psgn)
839        if nperio == 6.2 : ptab_ext = lbc (ptab_ext, nperio=6, cd_type=cd_type, psgn=psgn)
840             
841        if mmath == xr :
842            ptab_ext.attrs = ptab.attrs
843
844    else : ptab_ext = lbc (ptab, nperio=nperio, cd_type=cd_type, psgn=psgn)
845       
846    return ptab_ext
847
848def lbc_del (ptab, nperio=None, cd_type='T', psgn=1) :
849    '''
850    Handle NEMO domain changes between NEMO 4.0 to NEMO 4.2
851      Periodicity halo has been removed
852    This routine removes the halos if needed
853
854    ptab      : Input array (works
855      rank 2 at least : ptab[...., lat, lon]
856    nperio    : Type of periodicity
857 
858    See NEMO documentation for further details
859    '''
860
861    jpj, jpi, nperio = lbc_init (ptab, nperio)
862
863    if nperio == 4.2 or nperio == 6.2 :
864        return lbc (ptab[..., :-1, 1:-1], nperio=nperio, cd_type=cd_type, psgn=psgn)
865    else :
866        return ptab
867
868#@numba.jit(forceobj=True)
869def lbc_index (jj, ii, jpj, jpi, nperio=None, cd_type='T') :
870    '''
871    For indexes of a NEMO point, give the corresponding point inside the util domain
872    jj, ii    : indexes
873    jpi, jpi  : size of domain
874    nperio    : type of periodicity
875    cd_type   : grid specification : T, U, V or F
876   
877    See NEMO documentation for further details
878    '''
879
880    if nperio == None : nperio = __guessNperio__ (jpj, jpi, nperio)
881   
882    ## For the sake of simplicity, switch to the convention of original lbc Fortran routine from NEMO
883    ## : starts indexes at 1
884    jy = jj + 1 ; ix = ii + 1
885
886    mmath = __mmath__ (jj)
887    if mmath == None : mmath=np
888
889    #
890    #> East-West boundary conditions
891    # ------------------------------
892    if nperio in [1, 4, 6] :
893        #... cyclic
894        ix = mmath.where (ix==jpi, 2   , ix)
895        ix = mmath.where (ix== 1 ,jpi-1, ix)
896
897    #
898    def modIJ (cond, jy_new, ix_new) :
899        jy_r = mmath.where (cond, jy_new, jy)
900        ix_r = mmath.where (cond, ix_new, ix)
901        return jy_r, ix_r
902    #
903    #> North-South boundary conditions
904    # --------------------------------
905    if nperio in [ 3 , 4 ]  :
906        if cd_type in  [ 'T' , 'W' ] :
907            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix>=2       ), jpj-2, jpi-ix+2)
908            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1       ), jpj-1, 3       )   
909            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+2) 
910
911        if cd_type in [ 'U' ] :
912            (jy, ix) = modIJ (np.logical_and (jy==jpj  , np.logical_and (ix>=1, ix <= jpi-1)   ), jy   , jpi-ix+1)
913            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1  )                               , jpj-2, 2       )
914            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi)                               , jpj-2, jpi-1   )
915            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, np.logical_and (ix>=jpi//2, ix<=jpi-1)), jy   , jpi-ix+1)
916         
917        if cd_type in [ 'V' ] :
918            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=2  ), jpj-2, jpi-ix+2)
919            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix>=2  ), jpj-3, jpi-ix+2)
920            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1  ), jpj-3,  3      )
921           
922        if cd_type in [ 'F' ] :
923            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix<=jpi-1), jpj-2, jpi-ix+1)
924            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1), jpj-3, jpi-ix+1)
925            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1    ), jpj-3, 2       )
926            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi  ), jpj-3, jpi-1   )
927
928    if nperio in [ 5 , 6 ] :
929        if cd_type in [ 'T' , 'W' ] :                        # T-, W-point
930             (jy, ix) = modIJ (jy==jpj, jpj-1, jpi-ix+1)
931 
932        if cd_type in [ 'U' ] :                              # U-point
933            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-1, jpi-ix  )
934            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi     ), jpi-1, 1       )
935           
936        if cd_type in [ 'V' ] :    # V-point
937            (jy, ix) = modIJ (jy==jpj                                 , jy   , jpi-ix+1)
938            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+1)
939           
940        if cd_type in [ 'F' ] :                              # F-point
941            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-2, jpi-ix  )
942            (jy, ix) = modIJ (np.logical_and (ix==jpj  , ix==jpi     ), jpj-2, 1       )
943            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix  )
944
945    ## Restore convention to Python/C : indexes start at 0
946    jy += -1 ; ix += -1
947
948    if isinstance (jj, int) : jy = jy.item ()
949    if isinstance (ii, int) : ix = ix.item ()
950
951    return jy, ix
952   
953def geo2en (pxx, pyy, pzz, glam, gphi) : 
954    '''
955    Change vector from geocentric to east/north
956
957    Inputs :
958        pxx, pyy, pzz : components on the geocentric system
959        glam, gphi : longitude and latitude of the points
960    '''
961
962    gsinlon = np.sin (rad * glam)
963    gcoslon = np.cos (rad * glam)
964    gsinlat = np.sin (rad * gphi)
965    gcoslat = np.cos (rad * gphi)
966         
967    pte = - pxx * gsinlon            + pyy * gcoslon
968    ptn = - pxx * gcoslon * gsinlat  - pyy * gsinlon * gsinlat + pzz * gcoslat
969
970    return pte, ptn
971
972def en2geo (pte, ptn, glam, gphi) :
973    '''
974    Change vector from east/north to geocentric
975
976    Inputs :
977        pte, ptn   : eastward/northward components
978        glam, gphi : longitude and latitude of the points
979    '''
980   
981    gsinlon = np.sin (rad * glam)
982    gcoslon = np.cos (rad * glam)
983    gsinlat = np.sin (rad * gphi)
984    gcoslat = np.cos (rad * gphi)
985
986    pxx = - pte * gsinlon - ptn * gcoslon * gsinlat
987    pyy =   pte * gcoslon - ptn * gsinlon * gsinlat
988    pzz =   ptn * gcoslat
989   
990    return pxx, pyy, pzz
991
992def findJI (lat_data, lon_data, lat_grid, lon_grid, mask=1.0, verbose=False) :
993    '''
994    Description: seeks J,I indices of the grid point which is the closest of a given point
995    Usage: go FindJI  <data latitude> <data longitude> <grid latitudes> <grid longitudes> [mask]
996    <longitude fields> <latitude field> are 2D fields on J/I (Y/X) dimensions
997    mask : if given, seek only non masked grid points (i.e with mask=1)
998   
999    Example : findIJ (40, -20, nav_lat, nav_lon, mask=1.0)
1000
1001    Note : all longitudes and latitudes in degrees
1002       
1003    Note : may work with 1D lon/lat (?)
1004    '''
1005    # Get grid dimensions
1006    if len (lon_grid.shape) == 2 : (jpj, jpi) = lon_grid.shape
1007    else                         : jpj = len(lat_grid) ; jpi=len(lon_grid)
1008
1009    mmath = __mmath__ (lat_grid)
1010       
1011    # Compute distance from the point to all grid points (in radian)
1012    arg      = np.sin (rad*lat_data) * np.sin (rad*lat_grid) \
1013             + np.cos (rad*lat_data) * np.cos (rad*lat_grid) * np.cos(rad*(lon_data-lon_grid))
1014    distance = np.arccos (arg) + 4.0*rpi*(1.0-mask) # Send masked points to 'infinite'
1015
1016    # Truncates to alleviate some precision problem with some grids
1017    prec = int (1E7)
1018    distance = (distance*prec).astype(int) / prec
1019
1020    # Compute minimum of distance, and index of minimum
1021    #
1022    distance_min = distance.min    ()
1023    jimin        = int (distance.argmin ())
1024   
1025    # Compute 2D indices
1026    jmin = jimin // jpi ; imin = jimin - jmin*jpi
1027
1028    # Compute distance achieved
1029    mindist = distance[jmin, imin]
1030   
1031    # Compute azimuth
1032    dlon = lon_data-lon_grid[jmin,imin]
1033    arg  = np.sin (rad*dlon) /  (np.cos(rad*lat_data)*np.tan(rad*lat_grid[jmin,imin]) - np.sin(rad*lat_data)*np.cos(rad*dlon))
1034    azimuth = dar*np.arctan (arg)
1035   
1036    # Result
1037    if verbose : 
1038        print ('I={:d} J={:d} - Data:{:5.1f}°N {:5.1f}°E - Grid:{:4.1f}°N {:4.1f}°E - Dist: {:6.1f}km {:5.2f}° - Azimuth: {:3.2f}rad - {:5.1f}°'
1039            .format (imin, jmin, lat_data, lon_data, lat_grid[jmin,imin], lon_grid[jmin,imin], ra*distance[jmin,imin], dar*distance[jmin,imin], rad*azimuth, azimuth))
1040
1041    return jmin, imin
1042
1043def clo_lon (lon, lon0) :
1044    '''Choose closest to lon0 longitude, adding or substacting 360° if needed'''
1045    mmath = __mmath__ (lon, np)
1046       
1047    clo_lon = lon
1048    clo_lon = mmath.where (clo_lon > lon0 + 180., clo_lon-360., clo_lon)
1049    clo_lon = mmath.where (clo_lon < lon0 - 180., clo_lon+360., clo_lon)
1050    clo_lon = mmath.where (clo_lon > lon0 + 180., clo_lon-360., clo_lon)
1051    clo_lon = mmath.where (clo_lon < lon0 - 180., clo_lon+360., clo_lon)
1052    if clo_lon.shape == () : clo_lon = clo_lon.item ()
1053    return clo_lon
1054
1055def angle_full (glamt, gphit, glamu, gphiu, glamv, gphiv, glamf, gphif, nperio=None) :
1056    '''Compute sinus and cosinus of model line direction with respect to east'''
1057    mmath = __mmath__ (glamt)
1058
1059    zlamt = lbc_add (glamt, nperio, 'T', 1.)
1060    zphit = lbc_add (gphit, nperio, 'T', 1.)
1061    zlamu = lbc_add (glamu, nperio, 'U', 1.)
1062    zphiu = lbc_add (gphiu, nperio, 'U', 1.)
1063    zlamv = lbc_add (glamv, nperio, 'V', 1.)
1064    zphiv = lbc_add (gphiv, nperio, 'V', 1.)
1065    zlamf = lbc_add (glamf, nperio, 'F', 1.)
1066    zphif = lbc_add (gphif, nperio, 'F', 1.)
1067   
1068    # north pole direction & modulous (at T-point)
1069    zxnpt = 0. - 2.0 * np.cos (rad*zlamt) * np.tan (rpi/4.0 - rad*zphit/2.0)
1070    zynpt = 0. - 2.0 * np.sin (rad*zlamt) * np.tan (rpi/4.0 - rad*zphit/2.0)
1071    znnpt = zxnpt*zxnpt + zynpt*zynpt
1072   
1073    # north pole direction & modulous (at U-point)
1074    zxnpu = 0. - 2.0 * np.cos (rad*zlamu) * np.tan (rpi/4.0 - rad*zphiu/2.0)
1075    zynpu = 0. - 2.0 * np.sin (rad*zlamu) * np.tan (rpi/4.0 - rad*zphiu/2.0)
1076    znnpu = zxnpu*zxnpu + zynpu*zynpu
1077   
1078    # north pole direction & modulous (at V-point)
1079    zxnpv = 0. - 2.0 * np.cos (rad*zlamv) * np.tan (rpi/4.0 - rad*zphiv/2.0)
1080    zynpv = 0. - 2.0 * np.sin (rad*zlamv) * np.tan (rpi/4.0 - rad*zphiv/2.0)
1081    znnpv = zxnpv*zxnpv + zynpv*zynpv
1082
1083    # north pole direction & modulous (at F-point)
1084    zxnpf = 0. - 2.0 * np.cos( rad*zlamf ) * np.tan ( rpi/4. - rad*zphif/2. )
1085    zynpf = 0. - 2.0 * np.sin( rad*zlamf ) * np.tan ( rpi/4. - rad*zphif/2. )
1086    znnpf = zxnpf*zxnpf + zynpf*zynpf
1087
1088    # j-direction: v-point segment direction (around T-point)
1089    zlam = zlamv 
1090    zphi = zphiv
1091    zlan = np.roll ( zlamv, axis=-2, shift=1)  # glamv (ji,jj-1)
1092    zphh = np.roll ( zphiv, axis=-2, shift=1)  # gphiv (ji,jj-1)
1093    zxvvt =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1094          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1095    zyvvt =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1096          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1097    znvvt = np.sqrt ( znnpt * ( zxvvt*zxvvt + zyvvt*zyvvt )  )
1098
1099    # j-direction: f-point segment direction (around u-point)
1100    zlam = zlamf
1101    zphi = zphif
1102    zlan = np.roll (zlamf, axis=-2, shift=1) # glamf (ji,jj-1)
1103    zphh = np.roll (zphif, axis=-2, shift=1) # gphif (ji,jj-1)
1104    zxffu =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1105          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1106    zyffu =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1107          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1108    znffu = np.sqrt ( znnpu * ( zxffu*zxffu + zyffu*zyffu )  )
1109
1110    # i-direction: f-point segment direction (around v-point)
1111    zlam = zlamf 
1112    zphi = zphif
1113    zlan = np.roll (zlamf, axis=-1, shift=1) # glamf (ji-1,jj)
1114    zphh = np.roll (zphif, axis=-1, shift=1) # gphif (ji-1,jj)
1115    zxffv =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1116          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1117    zyffv =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1118          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1119    znffv = np.sqrt ( znnpv * ( zxffv*zxffv + zyffv*zyffv )  )
1120
1121    # j-direction: u-point segment direction (around f-point)
1122    zlam = np.roll (zlamu, axis=-2, shift=-1) # glamu (ji,jj+1)
1123    zphi = np.roll (zphiu, axis=-2, shift=-1) # gphiu (ji,jj+1)
1124    zlan = zlamu
1125    zphh = zphiu
1126    zxuuf =  2. * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1127          -  2. * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1128    zyuuf =  2. * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1129          -  2. * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1130    znuuf = np.sqrt ( znnpf * ( zxuuf*zxuuf + zyuuf*zyuuf )  )
1131
1132   
1133    # cosinus and sinus using scalar and vectorial products
1134    gsint = ( zxnpt*zyvvt - zynpt*zxvvt ) / znvvt
1135    gcost = ( zxnpt*zxvvt + zynpt*zyvvt ) / znvvt
1136   
1137    gsinu = ( zxnpu*zyffu - zynpu*zxffu ) / znffu
1138    gcosu = ( zxnpu*zxffu + zynpu*zyffu ) / znffu
1139   
1140    gsinf = ( zxnpf*zyuuf - zynpf*zxuuf ) / znuuf
1141    gcosf = ( zxnpf*zxuuf + zynpf*zyuuf ) / znuuf
1142   
1143    gsinv = ( zxnpv*zxffv + zynpv*zyffv ) / znffv
1144    gcosv =-( zxnpv*zyffv - zynpv*zxffv ) / znffv  # (caution, rotation of 90 degres)
1145   
1146    #gsint = lbc (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1147    #gcost = lbc (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1148    #gsinu = lbc (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1149    #gcosu = lbc (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1150    #gsinv = lbc (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1151    #gcosv = lbc (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1152    #gsinf = lbc (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1153    #gcosf = lbc (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1154
1155    gsint = lbc_del (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1156    gcost = lbc_del (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1157    gsinu = lbc_del (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1158    gcosu = lbc_del (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1159    gsinv = lbc_del (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1160    gcosv = lbc_del (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1161    gsinf = lbc_del (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1162    gcosf = lbc_del (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1163
1164    if mmath == xr :
1165        gsint = gsint.assign_coords ( glamt.coords )
1166        gcost = gcost.assign_coords ( glamt.coords )
1167        gsinu = gsinu.assign_coords ( glamu.coords )
1168        gcosu = gcosu.assign_coords ( glamu.coords )
1169        gsinv = gsinv.assign_coords ( glamv.coords )
1170        gcosv = gcosv.assign_coords ( glamv.coords )
1171        gsinf = gsinf.assign_coords ( glamf.coords )
1172        gcosf = gcosf.assign_coords ( glamf.coords )
1173
1174    return gsint, gcost, gsinu, gcosu, gsinv, gcosv, gsinf, gcosf
1175
1176def angle (glam, gphi, nperio, cd_type='T') :
1177    '''Compute sinus and cosinus of model line direction with respect to east'''
1178    mmath = __mmath__ (glam)
1179
1180    zlam = lbc_add (glam, nperio, cd_type, 1.)
1181    zphi = lbc_add (gphi, nperio, cd_type, 1.)
1182   
1183    # north pole direction & modulous
1184    zxnp = 0. - 2.0 * np.cos (rad*zlam) * np.tan (rpi/4.0 - rad*zphi/2.0)
1185    zynp = 0. - 2.0 * np.sin (rad*zlam) * np.tan (rpi/4.0 - rad*zphi/2.0)
1186    znnp = zxnp*zxnp + zynp*zynp
1187
1188    # j-direction: segment direction (around point)
1189    zlan_n = np.roll (zlam, axis=-2, shift=-1) # glam [jj+1, ji]
1190    zphh_n = np.roll (zphi, axis=-2, shift=-1) # gphi [jj+1, ji]
1191    zlan_s = np.roll (zlam, axis=-2, shift= 1) # glam [jj-1, ji]
1192    zphh_s = np.roll (zphi, axis=-2, shift= 1) # gphi [jj-1, ji]
1193   
1194    zxff = 2.0 * np.cos (rad*zlan_n) * np.tan (rpi/4.0 - rad*zphh_n/2.0) \
1195        -  2.0 * np.cos (rad*zlan_s) * np.tan (rpi/4.0 - rad*zphh_s/2.0)
1196    zyff = 2.0 * np.sin (rad*zlan_n) * np.tan (rpi/4.0 - rad*zphh_n/2.0) \
1197        -  2.0 * np.sin (rad*zlan_s) * np.tan (rpi/4.0 - rad*zphh_s/2.0)
1198    znff = np.sqrt (znnp * (zxff*zxff + zyff*zyff) )
1199 
1200    gsin = (zxnp*zyff - zynp*zxff) / znff
1201    gcos = (zxnp*zxff + zynp*zyff) / znff
1202
1203    gsin = lbc_del (gsin, cd_type=cd_type, nperio=nperio, psgn=-1.)
1204    gcos = lbc_del (gcos, cd_type=cd_type, nperio=nperio, psgn=-1.)
1205
1206    if mmath == xr :
1207        gsin = gsin.assign_coords ( glam.coords )
1208        gcos = gcos.assign_coords ( glam.coords )
1209       
1210    return gsin, gcos
1211
1212def rot_en2ij ( u_e, v_n, gsin, gcos, nperio, cd_type ) :
1213    '''
1214    ** Purpose :   Rotate the Repere: Change vector componantes between
1215    geographic grid --> stretched coordinates grid.
1216    All components are on the same grid (T, U, V or F)
1217    '''
1218
1219    u_i = + u_e * gcos + v_n * gsin
1220    v_j = - u_e * gsin + v_n * gcos
1221   
1222    u_i = lbc (u_i, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1223    v_j = lbc (v_j, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1224   
1225    return u_i, v_j
1226
1227def rot_ij2en ( u_i, v_j, gsin, gcos, nperio, cd_type='T' ) :
1228    '''
1229    ** Purpose :   Rotate the Repere: Change vector componantes from
1230    stretched coordinates grid --> geographic grid
1231    All components are on the same grid (T, U, V or F)
1232    '''
1233    u_e = + u_i * gcos - v_j * gsin
1234    v_n = + u_i * gsin + v_j * gcos
1235   
1236    u_e = lbc (u_e, nperio=nperio, cd_type=cd_type, psgn= 1.0)
1237    v_n = lbc (v_n, nperio=nperio, cd_type=cd_type, psgn= 1.0)
1238   
1239    return u_e, v_n
1240
1241def rot_uv2en ( uo, vo, gsint, gcost, nperio, zdim='deptht' ) :
1242    '''
1243    ** Purpose :   Rotate the Repere: Change vector componantes from
1244    stretched coordinates grid --> geographic grid
1245    uo is on the U grid point, vo is on the V grid point
1246    east-north components on the T grid point   
1247    '''
1248    mmath = __mmath__ (uo)
1249
1250    ut = U2T (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1251    vt = V2T (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1252   
1253    u_e = + ut * gcost - vt * gsint
1254    v_n = + ut * gsint + vt * gcost
1255
1256    u_e = lbc (u_e, nperio=nperio, cd_type='T', psgn=1.0)
1257    v_n = lbc (v_n, nperio=nperio, cd_type='T', psgn=1.0)
1258   
1259    return u_e, v_n
1260
1261def rot_uv2enF ( uo, vo, gsinf, gcosf, nperio, zdim='deptht' ) :
1262    '''
1263    ** Purpose : Rotate the Repere: Change vector componantes from
1264    stretched coordinates grid --> geographic grid
1265    uo is on the U grid point, vo is on the V grid point
1266    east-north components on the T grid point   
1267    '''
1268    mmath = __mmath__ (uo)
1269
1270    uf = U2F (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1271    vf = V2F (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1272   
1273    u_e = + uf * gcosf - vf * gsinf
1274    v_n = + uf * gsinf + vf * gcosf
1275
1276    u_e = lbc (u_e, nperio=nperio, cd_type='F', psgn= 1.0)
1277    v_n = lbc (v_n, nperio=nperio, cd_type='F', psgn= 1.0)
1278   
1279    return u_e, v_n
1280
1281#@numba.jit(forceobj=True)
1282def U2T (utab, nperio=None, psgn=-1.0, zdim='deptht', action='ave') :
1283    '''Interpolate an array from U grid to T grid i-mean)'''
1284    mmath = __mmath__ (utab)
1285    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
1286    lperio, aperio = lbc_diag (nperio)
1287    utab_0 = lbc_add (utab_0, nperio=nperio, cd_type='U', psgn=psgn)
1288    ix, ax = __findAxis__ (utab_0, 'x')
1289    iz, az = __findAxis__ (utab_0, 'z')
1290    if action == 'ave'  : ttab = 0.5 *      (utab_0 + np.roll (utab_0, axis=ix, shift=1))
1291    if action == 'min'  : ttab = np.minimum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1292    if action == 'max'  : ttab = np.maximum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1293    if action == 'mult' : ttab =             utab_0 * np.roll (utab_0, axis=ix, shift=1)
1294    ttab = lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1295   
1296    if mmath == xr :
1297        if ax != None :
1298            ttab = ttab.assign_coords({ax:np.arange (ttab.shape[ix])+1.})
1299        if zdim != None and iz != None  and az != 'olevel' : 
1300            ttab = ttab.rename( {az:zdim}) 
1301    return ttab
1302
1303#@numba.jit(forceobj=True)
1304def V2T (vtab, nperio=None, psgn=-1.0, zdim='deptht', action='ave') :
1305    '''Interpolate an array from V grid to T grid (j-mean)'''
1306    mmath = __mmath__ (vtab)
1307    lperio, aperio = lbc_diag (nperio)
1308    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1309    vtab_0 = lbc_add (vtab_0, nperio=nperio, cd_type='V', psgn=psgn)
1310    iy, ay = __findAxis__ (vtab_0, 'y')
1311    iz, az = __findAxis__ (vtab_0, 'z')
1312    if action == 'ave'  : ttab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=iy, shift=1))
1313    if action == 'min'  : ttab = np.minimum (vtab_0 , np.roll (vtab_0, axis=iy, shift=1))
1314    if action == 'max'  : ttab = np.maximum (vtab_0 , np.roll (vtab_0, axis=iy, shift=1))
1315    if action == 'mult' : ttab =             vtab_0 * np.roll (vtab_0, axis=iy, shift=1)
1316    ttab = lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1317    if mmath == xr :
1318        if ay !=None : 
1319            ttab = ttab.assign_coords({ay:np.arange(ttab.shape[iy])+1.})
1320        if zdim != None and iz != None  and az != 'olevel' :
1321            ttab = ttab.rename( {az:zdim}) 
1322    return ttab
1323
1324#@numba.jit(forceobj=True)
1325def F2T (ftab, nperio=None, psgn=1.0, zdim='depthf', action='ave') :
1326    '''Interpolate an array from F grid to T grid (i- and j- means)'''
1327    mmath = __mmath__ (ftab)
1328    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1329    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1330    ttab = V2T(F2V(ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1331    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1332
1333#@numba.jit(forceobj=True)
1334def T2U (ttab, nperio=None, psgn=1.0, zdim='depthu', action='ave') :
1335    '''Interpolate an array from T grid to U grid (i-mean)'''
1336    mmath = __mmath__ (ttab)
1337    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1338    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1339    ix, ax = __findAxis__ (ttab_0, 'x')
1340    iz, az = __findAxis__ (ttab_0, 'z')
1341    if action == 'ave'  : utab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=ix, shift=-1))
1342    if action == 'min'  : utab = np.minimum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1343    if action == 'max'  : utab = np.maximum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1344    if action == 'mult' : utab =             ttab_0 * np.roll (ttab_0, axis=ix, shift=-1)
1345    utab = lbc_del (utab, nperio=nperio, cd_type='U', psgn=psgn)
1346
1347    if mmath == xr :   
1348        if ax != None : 
1349            utab = ttab.assign_coords({ax:np.arange(utab.shape[ix])+1.})
1350        if zdim != None  and iz != None  and az != 'olevel' :
1351            utab = utab.rename( {az:zdim}) 
1352    return utab
1353
1354#@numba.jit(forceobj=True)
1355def T2V (ttab, nperio=None, psgn=1.0, zdim='depthv', action='ave') :
1356    '''Interpolate an array from T grid to V grid (j-mean)'''
1357    mmath = __mmath__ (ttab)
1358    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1359    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1360    iy, ay = __findAxis__ (ttab_0, 'y')
1361    iz, az = __findAxis__ (ttab_0, 'z')
1362    if action == 'ave'  : vtab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=iy, shift=-1))
1363    if action == 'min'  : vtab = np.minimum (ttab_0 , np.roll (ttab_0, axis=iy, shift=-1))
1364    if action == 'max'  : vtab = np.maximum (ttab_0 , np.roll (ttab_0, axis=iy, shift=-1))
1365    if action == 'mult' : vtab =             ttab_0 * np.roll (ttab_0, axis=iy, shift=-1)
1366
1367    vtab = lbc_del (vtab, nperio=nperio, cd_type='V', psgn=psgn)
1368    if mmath == xr :
1369        if ay != None : 
1370            vtab = vtab.assign_coords({ay:np.arange(vtab.shape[iy])+1.})
1371        if zdim != None  and iz != None and az != 'olevel' :
1372            vtab = vtab.rename( {az:zdim}) 
1373    return vtab
1374
1375#@numba.jit(forceobj=True)
1376def V2F (vtab, nperio=None, psgn=-1.0, zdim='depthf', action='ave') :
1377    '''Interpolate an array from V grid to F grid (i-mean)'''
1378    mmath = __mmath__ (vtab)
1379    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1380    vtab_0 = lbc_add (vtab_0 , nperio=nperio, cd_type='V', psgn=psgn)
1381    ix, ax = __findAxis__ (vtab_0, 'x')
1382    iz, az = __findAxis__ (vtab_0, 'z')
1383    if action == 'ave'  : 0.5 *      (vtab_0 + np.roll (vtab_0, axis=ix, shift=-1))
1384    if action == 'min'  : np.minimum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
1385    if action == 'max'  : np.maximum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
1386    if action == 'mult' :             vtab_0 * np.roll (vtab_0, axis=ix, shift=-1)
1387    ftab = lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1388   
1389    if mmath == xr :
1390        if ax != None : 
1391            ftab = ftab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
1392        if zdim != None and iz != None and az != 'olevel' :
1393            ftab = ftab.rename( {az:zdim}) 
1394    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1395
1396#@numba.jit(forceobj=True)
1397def U2F (utab, nperio=None, psgn=-1.0, zdim='depthf', action='ave') :
1398    '''Interpolate an array from U grid to F grid i-mean)'''
1399    mmath = __mmath__ (utab)
1400    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
1401    utab_0 = lbc_add (utab_0 , nperio=nperio, cd_type='U', psgn=psgn)
1402    iy, ay = __findAxis__ (utab_0, 'y')
1403    iz, az = __findAxis__ (utab_0, 'z')
1404    if action == 'ave'  :    ftab = 0.5 *      (utab_0 + np.roll (utab_0, axis=iy, shift=-1))
1405    if action == 'min'  :    ftab = np.minimum (utab_0 , np.roll (utab_0, axis=iy, shift=-1))
1406    if action == 'max'  :    ftab = np.maximum (utab_0 , np.roll (utab_0, axis=iy, shift=-1))
1407    if action == 'mult' :    ftab =             utab_0 * np.roll (utab_0, axis=iy, shift=-1)
1408    ftab = lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1409
1410    if mmath == xr :
1411        if ay != None : 
1412            ftab = ftab.assign_coords({'y':np.arange(ftab.shape[iy])+1.})
1413        if zdim != None and iz != None and az != 'olevel' :
1414            ftab = ftab.rename( {az:zdim}) 
1415    return ftab
1416
1417#@numba.jit(forceobj=True)
1418def F2T (ftab, nperio=None, psgn=1.0, zdim='deptht', action='ave') :
1419    '''Interpolate an array on F grid to T grid (i- and j- means)'''
1420    mmath = __mmath__ (ftab)
1421    ftab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1422    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1423    ttab = U2T(F2U(ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1424    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1425
1426#@numba.jit(forceobj=True)
1427def T2F (ttab, nperio=None, psgn=1.0, zdim='deptht', action='mean') :
1428    '''Interpolate an array on T grid to F grid (i- and j- means)'''
1429    mmath = __mmath__ (ttab)
1430    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1431    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1432    ftab = T2U(U2F(ttab, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1433   
1434    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1435
1436#@numba.jit(forceobj=True)
1437def F2U (ftab, nperio=None, psgn=1.0, zdim='depthu', action='ave') :
1438    '''Interpolate an array on F grid to FUgrid (i-mean)'''
1439    mmath = __mmath__ (ftab)
1440    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1441    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1442    iy, ay = __findAxis__ (ftab_0, 'y')
1443    iz, az = __findAxis__ (ftab_0, 'z')
1444    if action == 'ave'  : utab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=iy, shift=-1))
1445    if action == 'min'  : utab = np.minimum (ftab_0 , np.roll (ftab_0, axis=iy, shift=-1))
1446    if action == 'max'  : utab = np.maximum (ftab_0 , np.roll (ftab_0, axis=iy, shift=-1))
1447    if action == 'mult' : utab =             ftab_0 * np.roll (ftab_0, axis=iy, shift=-1)
1448
1449    utab = lbc_del (utab, nperio=nperio, cd_type='U', psgn=psgn)
1450   
1451    if mmath == xr :
1452        utab = utab.assign_coords({ay:np.arange(ftab.shape[iy])+1.})
1453        if zdim != None and iz != None and az != 'olevel' :
1454            utab = utab.rename( {az:zdim}) 
1455    return utab
1456
1457#@numba.jit(forceobj=True)
1458def F2V (ftab, nperio=None, psgn=1.0, zdim='depthv', action='ave') :
1459    '''Interpolate an array from F grid to V grid (i-mean)'''
1460    mmath = __mmath__ (ftab)
1461    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1462    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1463    ix, ax = __findAxis__ (ftab_0, 'x')
1464    iz, az = __findAxis__ (ftab_0, 'z')
1465    if action == 'ave'  : vtab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=ix, shift=-1))
1466    if action == 'min'  : vtab = np.minimum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
1467    if action == 'max'  : vtab = np.maximum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
1468    if action == 'mult' : vtab =             ftab_0 * np.roll (ftab_0, axis=ix, shift=-1)
1469
1470    vtab = lbc_del (vtab, nperio=nperio, cd_type='V', psgn=psgn)
1471    if mmath == xr :
1472        vtab = vtab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
1473        if zdim != None and iz != None and az != 'olevel' :
1474            vtab = vtab.rename( {az:zdim}) 
1475    return vtab
1476
1477#@numba.jit(forceobj=True)
1478def W2T (wtab, zcoord=None, zdim='deptht', sval=np.nan) :
1479    '''
1480    Interpolate an array on W grid to T grid (k-mean)
1481    sval is the bottom value
1482    '''
1483    mmath = __mmath__ (wtab)
1484    wtab_0 = mmath.where ( np.isnan(wtab), 0., wtab)
1485
1486    iz, az = __findAxis__ (wtab_0, 'z')
1487       
1488    ttab = 0.5 * ( wtab_0 + np.roll (wtab_0, axis=iz, shift=-1) )
1489   
1490    if mmath == xr :
1491        ttab[{az:iz}] = sval
1492        if zdim != None and iz != None and az != 'olevel' :
1493            ttab = ttab.rename ( {az:zdim} )
1494        try    : ttab = ttab.assign_coords ( {zdim:zcoord} )
1495        except : pass
1496    else :
1497        ttab[..., -1, :, :] = sval
1498
1499    return ttab
1500
1501#@numba.jit(forceobj=True)
1502def T2W (ttab, zcoord=None, zdim='depthw', sval=np.nan, extrap_surf=False) :
1503    '''Interpolate an array from T grid to W grid (k-mean)
1504    sval is the surface value
1505    if extrap_surf==True, surface value is taken from 1st level value.
1506    '''
1507    mmath = __mmath__ (ttab)
1508    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1509    iz, az = __findAxis__ (ttab_0, 'z')
1510    wtab = 0.5 * ( ttab_0 + np.roll (ttab_0, axis=iz, shift=1) )
1511
1512    if mmath == xr :
1513        if extrap_surf : wtab[{az:0}] = ttabb[{az:0}]
1514        else           : wtab[{az:0}] = sval
1515    else : 
1516        if extrap_surf : wtab[..., 0, :, :] = ttab[..., 0, :, :]
1517        else           : wtab[..., 0, :, :] = sval
1518
1519    if mmath == xr :
1520        if zdim != None and iz != None and az != 'olevel' :
1521                wtab = wtab.rename ( {az:zdim})
1522        if zcoord != None : wtab = wtab.assign_coords ( {zdim:zcoord})
1523        else              : ztab = wtab.assign_coords ( {zdim:np.arange(ttab.shape[iz])+1.} )
1524    return wtab
1525
1526#@numba.jit(forceobj=True)
1527def fill (ptab, nperio, cd_type='T', npass=1, sval=0.) :
1528    '''
1529    Fill sval values with mean of neighbours
1530   
1531    Inputs :
1532       ptab : input field to fill
1533       nperio, cd_type : periodicity characteristics
1534    '''       
1535
1536    mmath = __mmath__ (ptab)
1537
1538    DoPerio = False ; lperio = nperio
1539    if nperio == 4.2 :
1540        DoPerio = True ; lperio = 4
1541    if nperio == 6.2 :
1542        DoPerio = True ; lperio = 6
1543       
1544    if DoPerio :
1545        ztab = lbc_add (ptab, nperio=nperio, sval=sval)
1546    else : 
1547        ztab = ptab
1548       
1549    if np.isnan (sval) : 
1550        ztab   = mmath.where (np.isnan(ztab), np.nan, ztab)
1551    else :
1552        ztab   = mmath.where (ztab==sval    , np.nan, ztab)
1553   
1554    for nn in np.arange (npass) : 
1555        zmask = mmath.where ( np.isnan(ztab), 0., 1.   )
1556        ztab0 = mmath.where ( np.isnan(ztab), 0., ztab )
1557        # Compte du nombre de voisins
1558        zcount = 1./6. * ( zmask \
1559          + np.roll(zmask, shift=1, axis=-1) + np.roll(zmask, shift=-1, axis=-1) \
1560          + np.roll(zmask, shift=1, axis=-2) + np.roll(zmask, shift=-1, axis=-2) \
1561          + 0.5 * ( \
1562                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift= 1, axis=-1) \
1563                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift= 1, axis=-1) \
1564                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift=-1, axis=-1) \
1565                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift=-1, axis=-1) ) )
1566
1567        znew =1./6. * ( ztab0 \
1568           + np.roll(ztab0, shift=1, axis=-1) + np.roll(ztab0, shift=-1, axis=-1) \
1569           + np.roll(ztab0, shift=1, axis=-2) + np.roll(ztab0, shift=-1, axis=-2) \
1570           + 0.5 * ( \
1571                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift= 1, axis=-1) \
1572                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift= 1, axis=-1) \
1573                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift=-1, axis=-1) \
1574                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift=-1, axis=-1) ) )
1575
1576        zcount = lbc (zcount, nperio=lperio, cd_type=cd_type)
1577        znew   = lbc (znew  , nperio=lperio, cd_type=cd_type)
1578       
1579        ztab = mmath.where (np.logical_and (zmask==0., zcount>0), znew/zcount, ztab)
1580
1581    ztab = mmath.where (zcount==0, sval, ztab)
1582    if DoPerio : ztab = lbc_del (ztab, nperio=lperio)
1583
1584    return ztab
1585
1586#@numba.jit(forceobj=True)
1587def correct_uv (u, v, lat) :
1588    '''
1589    Correct a Cartopy bug in Orthographic projection
1590
1591    See https://github.com/SciTools/cartopy/issues/1179
1592
1593    The correction is needed with cartopy <= 0.20
1594    It seems that version 0.21 will correct the bug (https://github.com/SciTools/cartopy/pull/1926)
1595
1596    Inputs :
1597       u, v : eastward/nothward components
1598       lat  : latitude of the point (degrees north)
1599
1600    Outputs :
1601       modified eastward/nothward components to have correct polar projections in cartopy
1602    '''
1603    uv = np.sqrt (u*u + v*v)           # Original modulus
1604    zu = u
1605    zv = v * np.cos (rad*lat)
1606    zz = np.sqrt ( zu*zu + zv*zv )     # Corrected modulus
1607    uc = zu*uv/zz ; vc = zv*uv/zz      # Final corrected values
1608    return uc, vc
1609
1610def msf (v_e1v_e3v, lat1d, depthw) :
1611    '''
1612    Computes the meridonal stream function
1613    First input is meridional_velocity*e1v*e3v
1614    '''
1615    @numba.jit(forceobj=True)
1616    def iin (tab, dim) :
1617        '''
1618        Integrate from the bottom
1619        '''
1620        result = tab * 0.0
1621        nlen = len(tab.coords[dim])
1622        for jn in np.arange (nlen-2, 0, -1) :
1623            result [{dim:jn}] = result [{dim:jn+1}] - tab [{dim:jn}]
1624        result = result.where (result !=0, np.nan)
1625        return result
1626   
1627    zomsf = iin ((v_e1v_e3v).sum (dim='x', keep_attrs=True)*1E-6, dim='depthv')
1628    zomsf = zomsf.assign_coords ( {'depthv':depthw.values, 'y':lat1d})
1629    zomsf = zomsf.rename ( {'depthv':'depthw', 'y':'lat'})
1630    zomsf.attrs['long_name'] = 'Meridional stream function'
1631
1632    zomsf.attrs['units'] = 'Sv'
1633    zomsf.depthw.attrs=depthw.attrs
1634    zomsf.lat.attrs=lat1d.attrs
1635       
1636    return zomsf
1637
1638def bsf (u_e2u_e3u, mask, nperio=None, bsf0=None ) :
1639    '''
1640    Computes the barotropic stream function
1641    First input is zonal_velocity*e2u*e3u
1642    bsf0 is the point with bsf=0 (ex: bsf0={'x':5, 'y':120} )
1643    '''
1644    @numba.jit(forceobj=True)
1645    def iin (tab, dim) :
1646        '''
1647        Integrate from the south
1648        '''
1649        result = tab * 0.0
1650        nlen = len(tab.coords[dim])
1651        for jn in np.arange (3, nlen) :
1652            result [{dim:jn}] = result [{dim:jn-1}] + tab [{dim:jn}]
1653        return result
1654   
1655    bsf = iin ((u_e2u_e3u).sum(dim='depthu', keep_attrs=True)*1E-6, dim='y')
1656    bsf.attrs = u_e2u_e3u.attrs
1657    if bsf0 != None :
1658        bsf = bsf - bsf.isel (bsf0)
1659       
1660    bsf = bsf.where (mask !=0, np.nan)
1661    bsf.attrs['long_name'] = 'Barotropic stream function'
1662    bsf.attrs['units'] = 'Sv'
1663    bsf = lbc (bsf, nperio=nperio, cd_type='F')
1664       
1665    return bsf
1666
1667def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
1668    '''
1669    Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
1670
1671    ref : file with reference namelist, or a f90nml.namelist.Namelist object
1672    cfg : file with config namelist, or a f90nml.namelist.Namelist object
1673    At least one namelist neaded
1674
1675    out:
1676        'dict' to return a dictonnary
1677        'xr'   to return an xarray dataset
1678    flat : only for dict output. Output a flat dictionnary with all values.
1679   
1680    '''
1681
1682    if ref != None :
1683        if isinstance (ref, str) : nml_ref = f90nml.read (ref)
1684        if isinstance (ref, f90nml.namelist.Namelist) : nml_ref = ref
1685       
1686    if cfg != None :
1687        if isinstance (cfg, str) : nml_cfg = f90nml.read (cfg)
1688        if isinstance (cfg, f90nml.namelist.Namelist) : nml_cfg = cfg
1689   
1690    if out == 'dict' : dict_namelist = {}
1691    if out == 'xr'   : xr_namelist = xr.Dataset ()
1692
1693    list_nml = [] ; list_comment = []
1694
1695    if ref != None :
1696        list_nml.append (nml_ref) ; list_comment.append ('ref')
1697    if cfg != None :
1698        list_nml.append (nml_cfg) ; list_comment.append ('cfg')
1699
1700    for nml, comment in zip (list_nml, list_comment) :
1701        if verbose : print (comment)
1702        if flat and out =='dict' :
1703            for nam in nml.keys () :
1704                if verbose : print (nam)
1705                for value in nml[nam] :
1706                     if out == 'dict' : dict_namelist[value] = nml[nam][value]
1707                     if verbose : print (nam, ':', value, ':', nml[nam][value])
1708        else :
1709            for nam in nml.keys () :
1710                if verbose : print (nam)
1711                if out == 'dict' :
1712                    if nam not in dict_namelist.keys () : dict_namelist[nam] = {}
1713                for value in nml[nam] :
1714                    if out == 'dict' : dict_namelist[nam][value] = nml[nam][value]
1715                    if out == 'xr'   : xr_namelist[value] = nml[nam][value]
1716                    if verbose : print (nam, ':', value, ':', nml[nam][value])
1717
1718    if out == 'dict' : return dict_namelist
1719    if out == 'xr'   : return xr_namelist
1720
1721
1722def fill_closed_seas (imask, nperio=None,  cd_type='T') :
1723    '''Fill closed seas with image processing library
1724    imask : mask, 1 on ocean, 0 on land
1725    '''
1726    from scipy import ndimage
1727
1728    imask_filled = ndimage.binary_fill_holes ( lbc (imask, nperio=nperio, cd_type=cd_type))
1729    imask_filled = lbc ( imask_filled, nperio=nperio, cd_type=cd_type)
1730
1731    return imask_filled
1732
1733## ===========================================================================
1734##
1735##                               That's all folk's !!!
1736##
1737## ===========================================================================
1738
1739def __is_orca_north_fold__ ( Xtest, cname_long='T' ) :
1740    '''
1741    Ported (pirated !!?) from Sosie
1742
1743    Tell if there is a 2/point band overlaping folding at the north pole typical of the ORCA grid
1744
1745    0 => not an orca grid (or unknown one)
1746    4 => North fold T-point pivot (ex: ORCA2)
1747    6 => North fold F-point pivot (ex: ORCA1)
1748
1749    We need all this 'cname_long' stuff because with our method, there is a
1750    confusion between "Grid_U with T-fold" and "Grid_V with F-fold"
1751    => so knowing the name of the longitude array (as in namelist, and hence as
1752    in netcdf file) might help taking the righ decision !!! UGLY!!!
1753    => not implemented yet
1754    '''
1755   
1756    ifld_nord =  0 ; cgrd_type = 'X'
1757    ny, nx = Xtest.shape[-2:]
1758
1759    if ny > 3 : # (case if called with a 1D array, ignoring...)
1760        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
1761          ifld_nord = 4 ; cgrd_type = 'T' # T-pivot, grid_T     
1762
1763        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
1764            if cnlon == 'U' : ifld_nord = 4 ;  cgrd_type = 'U' # T-pivot, grid_T
1765                ## LOLO: PROBLEM == 6, V !!!
1766
1767        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
1768            ifld_nord = 4 ; cgrd_type = 'V' # T-pivot, grid_V
1769
1770        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-2, nx-1-1:nx-nx//2:-1] ).sum() == 0. :
1771            ifld_nord = 6 ; cgrd_type = 'T'# F-pivot, grid_T
1772
1773        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-1, nx-1:nx-nx//2-1:-1] ).sum() == 0. :
1774            ifld_nord = 6 ;  cgrd_type = 'U' # F-pivot, grid_U
1775
1776        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
1777            if cnlon == 'V' : ifld_nord = 6 ; cgrd_type = 'V' # F-pivot, grid_V
1778                ## LOLO: PROBLEM == 4, U !!!
1779
1780    return ifld_nord, cgrd_type
Note: See TracBrowser for help on using the repository browser.