source: TOOLS/MOSAIX/nemo.py

Last change on this file was 6666, checked in by omamce, 5 months ago

O.M. : MOSAIX

Improved code with pylint analysis

  • Property svn:keywords set to Date Revision HeadURL Author Id
File size: 93.0 KB
Line 
1# -*- coding: utf-8 -*-
2## ===========================================================================
3##
4##  This software is governed by the CeCILL  license under French law and
5##  abiding by the rules of distribution of free software.  You can  use,
6##  modify and/ or redistribute the software under the terms of the CeCILL
7##  license as circulated by CEA, CNRS and INRIA at the following URL
8##  "http://www.cecill.info".
9##
10##  Warning, to install, configure, run, use any of Olivier Marti's
11##  software or to read the associated documentation you'll need at least
12##  one (1) brain in a reasonably working order. Lack of this implement
13##  will void any warranties (either express or implied).
14##  O. Marti assumes no responsability for errors, omissions,
15##  data loss, or any other consequences caused directly or indirectly by
16##  the usage of his software by incorrectly or partially configured
17##  personal.
18##
19## ===========================================================================
20'''Utilities to plot NEMO ORCA fields,
21
22Handles periodicity and other stuff
23
24- Lots of tests for xarray object
25- Not much tested for numpy objects
26
27Author: olivier.marti@lsce.ipsl.fr
28
29## SVN information
30Author   = "$Author$"
31Date     = "$Date$"
32Revision = "$Revision$"
33Id       = "$Id$"
34HeadURL  = "$HeadURL$"
35'''
36
37import numpy as np
38import xarray as xr
39
40# Tries to import some moldules that are available in all Python installations
41#  and used in only a few specific functions
42try :
43    from sklearn.impute import SimpleImputer
44except ImportError as err :
45    print ( f'===> Warning : Module nemo : Import error of sklearn.impute.SimpleImputer : {err}' )
46    SimpleImputer = None
47
48try :
49    import f90nml
50except ImportError as err :
51    print ( f'===> Warning : Module nemo : Import error of f90nml : {err}' )
52    f90nml = None
53
54# SVN information
55__SVN__ = ({
56    'Author'   : '$Author$',
57    'Date'     : '$Date$',
58    'Revision' : '$Revision$',
59    'Id'       : '$Id$',
60    'HeadURL'  : '$HeadURL$',
61    'SVN_Date' : '$SVN_Date: $',
62    })
63##
64   
65RPI   = np.pi
66RAD   = np.deg2rad (1.0)
67DAR   = np.rad2deg (1.0)
68REPSI = np.finfo (1.0).eps
69
70NPERIO_VALID_RANGE = [0, 1, 4, 4.2, 5, 6, 6.2]
71
72RAAMO    =      12          # Number of months in one year
73RJJHH    =      24          # Number of hours in one day
74RHHMM    =      60          # Number of minutes in one hour
75RMMSS    =      60          # Number of seconds in one minute
76RA       = 6371229.0        # Earth radius                                  [m]
77GRAV     =       9.80665    # Gravity                                       [m/s2]
78RT0      =     273.15       # Freezing point of fresh water                 [Kelvin]
79RAU0     =    1026.0        # Volumic mass of sea water                     [kg/m3]
80SICE     =       6.0        # Salinity of ice (for pisces)                  [psu]
81SOCE     =      34.7        # Salinity of sea (for pisces and isf)          [psu]
82RLEVAP   =       2.5e+6     # Latent heat of evaporation (water)            [J/K]
83VKARMN   =       0.4        # Von Karman constant
84STEFAN   =       5.67e-8    # Stefan-Boltzmann constant                     [W/m2/K4]
85RHOS     =     330.         # Volumic mass of snow                          [kg/m3]
86RHOI     =     917.         # Volumic mass of sea ice                       [kg/m3]
87RHOW     =    1000.         # Volumic mass of freshwater in melt ponds      [kg/m3]
88RCND_I   =       2.034396   # Thermal conductivity of fresh ice             [W/m/K]
89RCPI     =    2067.0        # Specific heat of fresh ice                    [J/kg/K]
90RLSUB    =       2.834e+6   # Pure ice latent heat of sublimation           [J/kg]
91RLFUS    =       0.334e+6   # Latent heat of fusion of fresh ice            [J/kg]
92RTMLT    =       0.054      # Decrease of seawater meltpoint with salinity
93
94RDAY     = RJJHH * RHHMM * RMMSS                # Day length               [s]
95RSIYEA   = 365.25 * RDAY * 2. * RPI / 6.283076  # Sideral year length      [s]
96RSIDAY   = RDAY / (1. + RDAY / RSIYEA)          # Sideral day length       [s]
97OMEGA    = 2. * RPI / RSIDAY                    # Earth rotation parameter [s-1]
98
99## Default names of dimensions
100UDIMS = {'x':'x', 'y':'y', 'z':'olevel', 't':'time_counter'}
101
102## All possibles name of dimensions in Nemo files
103XNAME = [ 'x', 'X', 'X1', 'xx', 'XX',
104              'x_grid_T', 'x_grid_U', 'x_grid_V', 'x_grid_F', 'x_grid_W',
105              'lon', 'nav_lon', 'longitude', 'X1', 'x_c', 'x_f', ]
106YNAME = [ 'y', 'Y', 'Y1', 'yy', 'YY',
107              'y_grid_T', 'y_grid_U', 'y_grid_V', 'y_grid_F', 'y_grid_W',
108              'lat', 'nav_lat', 'latitude' , 'Y1', 'y_c', 'y_f', ]
109ZNAME = [ 'z', 'Z', 'Z1', 'zz', 'ZZ', 'depth', 'tdepth', 'udepth',
110              'vdepth', 'wdepth', 'fdepth', 'deptht', 'depthu',
111              'depthv', 'depthw', 'depthf', 'olevel', 'z_c', 'z_f', ]
112TNAME = [ 't', 'T', 'tt', 'TT', 'time', 'time_counter', 'time_centered', ]
113
114## All possibles name of units of dimensions in Nemo files
115XUNIT = [ 'degrees_east', ]
116YUNIT = [ 'degrees_north', ]
117ZUNIT = [ 'm', 'meter', ]
118TUNIT = [ 'second', 'minute', 'hour', 'day', 'month', 'year', ]
119
120## All possibles size of dimensions in Orca files
121XLENGTH = [ 180, 182, 360, 362, 1440 ]
122YLENGTH = [ 148, 149, 331, 332 ]
123ZLENGTH = [ 31, 75]
124
125## ===========================================================================
126def __mmath__ (ptab, default=None) :
127    '''Determines the type of tab : xarray, numpy or numpy.ma object ?
128
129    Returns type
130    '''
131    mmath = default
132    if isinstance (ptab, xr.core.dataarray.DataArray) :
133        mmath = xr
134    if isinstance (ptab, np.ndarray) :
135        mmath = np
136    if isinstance (ptab, np.ma.MaskType) :
137        mmath = np.ma
138
139    return mmath
140
141def __guess_nperio__ (jpj, jpi, nperio=None, out='nperio') :
142    '''Tries to guess the value of nperio (periodicity parameter.
143
144    See NEMO documentation for details)
145    Inputs
146    jpj    : number of latitudes
147    jpi    : number of longitudes
148    nperio : periodicity parameter
149    '''
150    if nperio is None :
151        nperio = __guess_config__ (jpj, jpi, nperio=None, out=out)
152    return nperio
153
154def __guess_config__ (jpj, jpi, nperio=None, config=None, out='nperio') :
155    '''Tries to guess the value of nperio (periodicity parameter).
156
157    See NEMO documentation for details)
158    Inputs
159    jpj    : number of latitudes
160    jpi    : number of longitudes
161    nperio : periodicity parameter
162    '''
163    print ( jpi, jpj)
164    if nperio is None :
165        ## Values for NEMO version < 4.2
166        if ( (jpj == 149 and jpi == 182) or (jpj is None and jpi == 182) or
167             (jpj == 149 or jpi is None) ) :
168            # ORCA2. We choose legacy orca2.
169            config, nperio, iperio, jperio, nfold, nftype = 'ORCA2.3' , 4, 1, 0, 1, 'T'
170        if ((jpj == 332 and jpi == 362) or (jpj is None and jpi == 362) or
171            (jpj ==  332 and jpi is None) ) : # eORCA1.
172            config, nperio, iperio, jperio, nfold, nftype = 'eORCA1.2', 6, 1, 0, 1, 'F'
173        if jpi == 1442 :  # ORCA025.
174            config, nperio, iperio, jperio, nfold, nftype = 'ORCA025' , 6, 1, 0, 1, 'F'
175        if jpj ==  294 : # ORCA1
176            config, nperio, iperio, jperio, nfold, nftype = 'ORCA1'   , 6, 1, 0, 1, 'F'
177
178        ## Values for NEMO version >= 4.2. No more halo points
179        if  (jpj == 148 and jpi == 180) or (jpj is None and jpi == 180) or \
180            (jpj == 148 and jpi is None) :  # ORCA2. We choose legacy orca2.
181            config, nperio, iperio, jperio, nfold, nftype = 'ORCA2.4' , 4.2, 1, 0, 1, 'F'
182        if  (jpj == 331 and jpi == 360) or (jpj is None and jpi == 360) or \
183            (jpj == 331 and jpi is None) : # eORCA1.
184            config, nperio, iperio, jperio, nfold, nftype = 'eORCA1.4', 6.2, 1, 0, 1, 'F'
185        if jpi == 1440 : # ORCA025.
186            config, nperio, iperio, jperio, nfold, nftype = 'ORCA025' , 6.2, 1, 0, 1, 'F'
187
188        if nperio is None :
189            raise ValueError ('in nemo module : nperio not found, and cannot by guessed')
190
191        if nperio in NPERIO_VALID_RANGE :
192            print ( f'nperio set as {nperio} (deduced from {jpj=} and {jpi=})' )
193        else :
194            raise ValueError ( f'nperio set as {nperio} (deduced from {jpi=} and {jpj=}) : \n'+
195                                'nemo.py is not ready for this value' )
196
197    if out == 'nperio' :
198        return nperio
199    if out == 'config' :
200        return config
201    if out == 'perio'  :
202        return iperio, jperio, nfold, nftype
203    if out in ['full', 'all'] :
204        return {'nperio':nperio, 'iperio':iperio, 'jperio':jperio, 'nfold':nfold, 'nftype':nftype}
205
206def __guess_point__ (ptab) :
207    '''Tries to guess the grid point (periodicity parameter.
208
209    See NEMO documentation for details)
210    For array conforments with xgcm requirements
211
212    Inputs
213         ptab : xarray array
214
215    Credits : who is the original author ?
216    '''
217
218    gp = None
219    mmath = __mmath__ (ptab)
220    if mmath == xr :
221        if ('x_c' in ptab.dims and 'y_c' in ptab.dims ) :
222            gp = 'T'
223        if ('x_f' in ptab.dims and 'y_c' in ptab.dims ) :
224            gp = 'U'
225        if ('x_c' in ptab.dims and 'y_f' in ptab.dims ) :
226            gp = 'V'
227        if ('x_f' in ptab.dims and 'y_f' in ptab.dims ) :
228            gp = 'F'
229        if ('x_c' in ptab.dims and 'y_c' in ptab.dims
230              and 'z_c' in ptab.dims )                :
231            gp = 'T'
232        if ('x_c' in ptab.dims and 'y_c' in ptab.dims
233                and 'z_f' in ptab.dims )                :
234            gp = 'W'
235        if ('x_f' in ptab.dims and 'y_c' in ptab.dims
236                and 'z_f' in ptab.dims )                :
237            gp = 'U'
238        if ('x_c' in ptab.dims and 'y_f' in ptab.dims
239                and 'z_f' in ptab.dims ) :
240            gp = 'V'
241        if ('x_f' in ptab.dims and 'y_f' in ptab.dims
242                and 'z_f' in ptab.dims ) :
243            gp = 'F'
244
245        if gp is None :
246            raise AttributeError ('in nemo module : cd_type not found, and cannot by guessed')
247        print ( f'Grid set as {gp} deduced from dims {ptab.dims}' )
248        return gp
249    else :
250        raise AttributeError  ('in nemo module : cd_type not found, input is not an xarray data')
251
252def get_shape ( ptab ) :
253    '''Get shape of ptab return a string with axes names
254
255    shape may contain X, Y, Z or T
256    Y is missing for a latitudinal slice
257    X is missing for on longitudinal slice
258    etc ...
259    '''
260
261    g_shape = ''
262    if __find_axis__ (ptab, 'x')[0] :
263        g_shape = 'X'
264    if __find_axis__ (ptab, 'y')[0] :
265        g_shape = 'Y' + g_shape
266    if __find_axis__ (ptab, 'z')[0] :
267        g_shape = 'Z' + g_shape
268    if __find_axis__ (ptab, 't')[0] :
269        g_shape = 'T' + g_shape
270    return g_shape
271
272def lbc_diag (nperio) :
273    '''Useful to switch between field with and without halo'''
274    lperio, aperio = nperio, False
275    if nperio == 4.2 :
276        lperio, aperio = 4, True
277    if nperio == 6.2 :
278        lperio, aperio = 6, True
279    return lperio, aperio
280
281def __find_axis__ (ptab, axis='z', back=True) :
282    '''Returns name and name of the requested axis'''
283    mmath = __mmath__ (ptab)
284    ax, ix = None, None
285
286    if axis in XNAME :
287        ax_name, unit_list, length = XNAME, XUNIT, XLENGTH
288    if axis in YNAME :
289        ax_name, unit_list, length = YNAME, YUNIT, YLENGTH
290    if axis in ZNAME :
291        ax_name, unit_list, length = ZNAME, ZUNIT, ZLENGTH
292    if axis in TNAME :
293        ax_name, unit_list, length = TNAME, TUNIT, None
294
295    if mmath == xr :
296        # Try by name
297        for dim in ax_name :
298            if dim in ptab.dims :
299                ix, ax = ptab.dims.index (dim), dim
300
301        # If not found, try by axis attributes
302        if not ix :
303            for i, dim in enumerate (ptab.dims) :
304                if 'axis' in ptab.coords[dim].attrs.keys() :
305                    l_axis = ptab.coords[dim].attrs['axis']
306                    if axis in ax_name and l_axis == 'X' :
307                        ix, ax = (i, dim)
308                    if axis in ax_name and l_axis == 'Y' :
309                        ix, ax = (i, dim)
310                    if axis in ax_name and l_axis == 'Z' :
311                        ix, ax = (i, dim)
312                    if axis in ax_name and l_axis == 'T' :
313                        ix, ax = (i, dim)
314
315        # If not found, try by units
316        if not ix :
317            for i, dim in enumerate (ptab.dims) :
318                if 'units' in ptab.coords[dim].attrs.keys() :
319                    for name in unit_list :
320                        if name in ptab.coords[dim].attrs['units'] :
321                            ix, ax = i, dim
322
323    # If numpy array or dimension not found, try by length
324    if mmath != xr or not ix :
325        if length :
326            l_shape = ptab.shape
327            for nn in np.arange ( len(l_shape) ) :
328                if l_shape[nn] in length :
329                    ix = nn
330
331    if ix and back :
332        ix -= len(ptab.shape)
333
334    return ax, ix
335
336def find_axis ( ptab, axis='z', back=True ) :
337    '''Version of find_axis with no __'''
338    ix, xx = __find_axis__ (ptab, axis, back)
339    return xx, ix
340
341def fixed_lon (plon, center_lon=0.0) :
342    '''Returns corrected longitudes for nicer plots
343
344    lon        : longitudes of the grid. At least 2D.
345    center_lon : center longitude. Default=0.
346
347    Designed by Phil Pelson.
348    See https://gist.github.com/pelson/79cf31ef324774c97ae7
349    '''
350    mmath = __mmath__ (plon)
351
352    f_lon = plon.copy ()
353
354    f_lon = mmath.where (f_lon > center_lon+180., f_lon-360.0, f_lon)
355    f_lon = mmath.where (f_lon < center_lon-180., f_lon+360.0, f_lon)
356
357    for i, start in enumerate (np.argmax (np.abs (np.diff (f_lon, axis=-1)) > 180., axis=-1)) :
358        f_lon [..., i, start+1:] += 360.
359
360    # Special case for eORCA025
361    if f_lon.shape [-1] == 1442 :
362        f_lon [..., -2, :] = f_lon [..., -3, :]
363    if f_lon.shape [-1] == 1440 :
364        f_lon [..., -1, :] = f_lon [..., -2, :]
365
366    if f_lon.min () > center_lon :
367        f_lon += -360.0
368    if f_lon.max () < center_lon :
369        f_lon +=  360.0
370
371    if f_lon.min () < center_lon-360.0 :
372        f_lon +=  360.0
373    if f_lon.max () > center_lon+360.0 :
374        f_lon += -360.0
375
376    return f_lon
377
378def bounds_clolon ( pbounds_lon, plon, rad=False, deg=True) :
379    '''Choose closest to lon0 longitude, adding/substacting 360° if needed
380    '''
381
382    if rad :
383        lon_range = 2.0*np.pi
384    if deg :
385        lon_range = 360.0
386    b_clolon = pbounds_lon.copy ()
387
388    b_clolon = xr.where ( b_clolon < plon-lon_range/2.,
389                          b_clolon+lon_range,
390                          b_clolon )
391    b_clolon = xr.where ( b_clolon > plon+lon_range/2.,
392                          b_clolon-lon_range,
393                          b_clolon )
394    return b_clolon
395
396def unify_dims ( dd, x='x', y='y', z='olevel', t='time_counter', verbose=False ) :
397    '''Rename dimensions to unify them between NEMO versions
398    '''
399    for xx in XNAME :
400        if xx in dd.dims and xx != x :
401            if verbose :
402                print ( f"{xx} renamed to {x}" )
403            dd = dd.rename ( {xx:x})
404
405    for yy in YNAME :
406        if yy in dd.dims and yy != y  :
407            if verbose :
408                print ( f"{yy} renamed to {y}" )
409            dd = dd.rename ( {yy:y} )
410
411    for zz in ZNAME :
412        if zz in dd.dims and zz != z :
413            if verbose :
414                print ( f"{zz} renamed to {z}" )
415            dd = dd.rename ( {zz:z} )
416
417    for tt in TNAME  :
418        if tt in dd.dims and tt != t :
419            if verbose :
420                print ( f"{tt} renamed to {t}" )
421            dd = dd.rename ( {tt:t} )
422
423    return dd
424
425
426if SimpleImputer : 
427  def fill_empty (ptab, sval=np.nan, transpose=False) :
428        '''Fill empty values
429
430        Useful when NEMO has run with no wet points options :
431        some parts of the domain, with no ocean points, have no
432        values
433        '''
434        mmath = __mmath__ (ptab)
435
436        imp = SimpleImputer (missing_values=sval, strategy='mean')
437        if transpose :
438            imp.fit (ptab.T)
439            ztab = imp.transform (ptab.T).T
440        else :
441            imp.fit (ptab)
442            ztab = imp.transform (ptab)
443
444        if mmath == xr :
445            ztab = xr.DataArray (ztab, dims=ztab.dims, coords=ztab.coords)
446            ztab.attrs.update (ptab.attrs)
447
448        return ztab
449
450   
451else : 
452    print ("Import error of sklearn.impute.SimpleImputer")
453    def fill_empty (ptab, sval=np.nan, transpose=False) :
454        '''Void version of fill_empy, because module sklearn.impute.SimpleImputer is not available
455
456        fill_empty :
457          Fill values
458
459          Useful when NEMO has run with no wet points options :
460          some parts of the domain, with no ocean points, have no
461          values
462        '''
463        print ( 'Error : module sklearn.impute.SimpleImputer not found' )
464        print ( 'Can not call fill_empty' )
465        print ( 'Call arguments where : ' )
466        print ( f'{ptab.shape=} {sval=} {transpose=}' )
467 
468def fill_lonlat (plon, plat, sval=-1) :
469    '''Fill longitude/latitude values
470
471    Useful when NEMO has run with no wet points options :
472    some parts of the domain, with no ocean points, have no
473    lon/lat values
474    '''
475    from sklearn.impute import SimpleImputer
476    mmath = __mmath__ (plon)
477
478    imp = SimpleImputer (missing_values=sval, strategy='mean')
479    imp.fit (plon)
480    zlon = imp.transform (plon)
481    imp.fit (plat.T)
482    zlat = imp.transform (plat.T).T
483
484    if mmath == xr :
485        zlon = xr.DataArray (zlon, dims=plon.dims, coords=plon.coords)
486        zlat = xr.DataArray (zlat, dims=plat.dims, coords=plat.coords)
487        zlon.attrs.update (plon.attrs)
488        zlat.attrs.update (plat.attrs)
489
490    zlon = fixed_lon (zlon)
491
492    return zlon, zlat
493
494def fill_bounds_lonlat (pbounds_lon, pbounds_lat, sval=-1) :
495    '''Fill longitude/latitude bounds values
496
497    Useful when NEMO has run with no wet points options :
498    some parts of the domain, with no ocean points, as no
499    lon/lat values
500    '''
501    mmath = __mmath__ (pbounds_lon)
502
503    z_bounds_lon = np.empty ( pbounds_lon.shape )
504    z_bounds_lat = np.empty ( pbounds_lat.shape )
505
506    imp = SimpleImputer (missing_values=sval, strategy='mean')
507
508    for n in np.arange (4) :
509        imp.fit (pbounds_lon[:,:,n])
510        z_bounds_lon[:,:,n] = imp.transform (pbounds_lon[:,:,n])
511        imp.fit (pbounds_lat[:,:,n].T)
512        z_bounds_lat[:,:,n] = imp.transform (pbounds_lat[:,:,n].T).T
513
514    if mmath == xr :
515        z_bounds_lon = xr.DataArray (pbounds_lon, dims=pbounds_lon.dims,
516                                         coords=pbounds_lon.coords)
517        z_bounds_lat = xr.DataArray (pbounds_lat, dims=pbounds_lat.dims,
518                                         coords=pbounds_lat.coords)
519        z_bounds_lon.attrs.update (pbounds_lat.attrs)
520        z_bounds_lat.attrs.update (pbounds_lat.attrs)
521
522    return z_bounds_lon, z_bounds_lat
523
524def jeq (plat) :
525    '''Returns j index of equator in the grid
526
527    lat : latitudes of the grid. At least 2D.
528    '''
529    mmath = __mmath__ (plat)
530    jy = __find_axis__ (plat, 'y')[-1]
531
532    if mmath == xr :
533        jj = int ( np.mean ( np.argmin (np.abs (np.float64 (plat)),
534                                                axis=jy) ) )
535    else :
536        jj = np.argmin (np.abs (np.float64 (plat[...,:, 0])))
537
538    return jj
539
540def lon1d (plon, plat=None) :
541    '''Returns 1D longitude for simple plots.
542
543    plon : longitudes of the grid
544    plat (optionnal) : latitudes of the grid
545    '''
546    mmath = __mmath__ (plon)
547    jpj, jpi  = plon.shape [-2:]
548    if np.max (plat) :
549        je    = jeq (plat)
550        lon0 = plon [..., je, 0].copy()
551        dlon = plon [..., je, 1].copy() - plon [..., je, 0].copy()
552        lon_1d = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
553    else :
554        lon0 = plon [..., jpj//3, 0].copy()
555        dlon = plon [..., jpj//3, 1].copy() - plon [..., jpj//3, 0].copy()
556        lon_1d = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
557
558    #start = np.argmax (np.abs (np.diff (lon1D, axis=-1)) > 180.0, axis=-1)
559    #lon1D [..., start+1:] += 360
560
561    if mmath == xr :
562        lon_1d = xr.DataArray( lon_1d, dims=('lon',), coords=(lon_1d,))
563        lon_1d.attrs.update (plon.attrs)
564        lon_1d.attrs['units']         = 'degrees_east'
565        lon_1d.attrs['standard_name'] = 'longitude'
566        lon_1d.attrs['long_name :']   = 'Longitude'
567
568    return lon_1d
569
570def latreg (plat, diff=0.1) :
571    '''Returns maximum j index where gridlines are along latitudes
572    in the northern hemisphere
573
574    lat : latitudes of the grid (2D)
575    diff [optional] : tolerance
576    '''
577    #mmath = __mmath__ (plat)
578    if diff is None :
579        dy = np.float64 (np.mean (np.abs (plat -
580                        np.roll (plat,shift=1,axis=-2, roll_coords=False))))
581        print ( f'{dy=}' )
582        diff = dy/100.
583
584    je     = jeq (plat)
585    jreg   = np.where (plat[...,je:,:].max(axis=-1) -
586                       plat[...,je:,:].min(axis=-1)< diff)[-1][-1] + je
587    lareg  = np.float64 (plat[...,jreg,:].mean(axis=-1))
588
589    return jreg, lareg
590
591def lat1d (plat) :
592    '''Returns 1D latitudes for zonal means and simple plots.
593
594    plat : latitudes of the grid (2D)
595    '''
596    mmath = __mmath__ (plat)
597    iy = __find_axis__ (plat, 'y')[-1]
598    jpj = plat.shape[iy]
599
600    dy     = np.float64 (np.mean (np.abs (plat - np.roll (plat, shift=1,axis=-2))))
601    je     = jeq (plat)
602    lat_eq = np.float64 (plat[...,je,:].mean(axis=-1))
603
604    jreg, lat_reg = latreg (plat)
605    lat_ave = np.mean (plat, axis=-1)
606
607    if np.abs (lat_eq) < dy/100. : # T, U or W grid
608        if jpj-1 > jreg :
609            dys = (90.-lat_reg) / (jpj-jreg-1)*0.5
610        else            :
611            dys = (90.-lat_reg) / 2.0
612        yrange = 90.-dys-lat_reg
613    else                           :  # V or F grid
614        yrange = 90.-lat_reg
615
616    if jpj-1 > jreg :
617        lat_1d = mmath.where (lat_ave<lat_reg,
618                 lat_ave,
619                 lat_reg + yrange * (np.arange(jpj)-jreg)/(jpj-jreg-1) )
620    else :
621        lat_1d = lat_ave
622    lat_1d[-1] = 90.0
623
624    if mmath == xr :
625        lat_1d = xr.DataArray( lat_1d.values, dims=('lat',), coords=(lat_1d,))
626        lat_1d.attrs.update (plat.attrs)
627        lat_1d.attrs ['units']         = 'degrees_north'
628        lat_1d.attrs ['standard_name'] = 'latitude'
629        lat_1d.attrs ['long_name :']   = 'Latitude'
630
631    return lat_1d
632
633def latlon1d (plat, plon) :
634    '''Returns simple latitude and longitude (1D) for simple plots.
635
636    plat, plon : latitudes and longitudes of the grid (2D)
637    '''
638    return lat1d (plat),  lon1d (plon, plat)
639
640def ff (plat) :
641    '''Returns Coriolis factor
642    '''
643    zff   = np.sin (RAD * plat) * OMEGA
644    return zff
645
646def beta (plat) :
647    '''Return Beta factor (derivative of Coriolis factor)
648    '''
649    zbeta = np.cos (RAD * plat) * OMEGA / RA
650    return zbeta
651
652def mask_lonlat (ptab, x0, x1, y0, y1, lon, lat, sval=np.nan) :
653    '''Returns masked values outside a lat/lon box
654    '''
655    mmath = __mmath__ (ptab)
656    if mmath == xr :
657        lon = lon.copy().to_masked_array()
658        lat = lat.copy().to_masked_array()
659
660    mask = np.logical_and (np.logical_and(lat>y0, lat<y1),
661            np.logical_or (np.logical_or (
662                np.logical_and(lon>x0, lon<x1),
663                np.logical_and(lon+360>x0, lon+360<x1)),
664                np.logical_and(lon-360>x0, lon-360<x1)))
665    tab = mmath.where (mask, ptab, sval)
666
667    return tab
668
669def extend (ptab, blon=False, jplus=25, jpi=None, nperio=4) :
670    '''Returns extended field eastward to have better plots,
671    and box average crossing the boundary
672
673    Works only for xarray and numpy data (?)
674    Useful for plotting vertical sections in OCE and ATM.
675
676    ptab : field to extend.
677    blon  : (optional, default=False) : if True, add 360 in the extended
678          parts of the field
679    jpi   : normal longitude dimension of the field. extend does nothing
680          if the actual size of the field != jpi
681          (avoid to extend several times in notebooks)
682    jplus (optional, default=25) : number of points added on
683          the east side of the field
684
685    '''
686    mmath = __mmath__ (ptab)
687
688    if ptab.shape[-1] == 1 :
689        tabex = ptab
690
691    else :
692        if jpi is None :
693            jpi = ptab.shape[-1]
694
695        if blon :
696            xplus = -360.0
697        else   :
698            xplus =    0.0
699
700        if ptab.shape[-1] > jpi :
701            tabex = ptab
702        else :
703            if nperio in [ 0, 4.2 ] :
704                istart, le, la = 0, jpi+1, 0
705            if nperio == 1 :
706                istart, le, la = 0, jpi+1, 0
707            if nperio in [4, 6] : # OPA case with two halo points for periodicity
708                # Perfect, except at the pole that should be masked by lbc_plot
709                istart, le, la = 1, jpi-2, 1
710            if mmath == xr :
711                tabex = np.concatenate (
712                     (ptab.values[..., istart   :istart+le+1    ] + xplus,
713                      ptab.values[..., istart+la:istart+la+jplus]         ),
714                      axis=-1)
715                lon    = ptab.dims[-1]
716                new_coords = []
717                for coord in ptab.dims :
718                    if coord == lon :
719                        new_coords.append ( np.arange( tabex.shape[-1]))
720                    else            :
721                        new_coords.append ( ptab.coords[coord].values)
722                tabex = xr.DataArray ( tabex, dims=ptab.dims,
723                                           coords=new_coords )
724            else :
725                tabex = np.concatenate (
726                    (ptab [..., istart   :istart+le+1    ] + xplus,
727                     ptab [..., istart+la:istart+la+jplus]          ),
728                     axis=-1)
729    return tabex
730
731def orca2reg (dd, lat_name='nav_lat', lon_name='nav_lon',
732                  y_name='y', x_name='x') :
733    '''Assign an ORCA dataset on a regular grid.
734
735    For use in the tropical region.
736    Inputs :
737      ff : xarray dataset
738      lat_name, lon_name : name of latitude and longitude 2D field in ff
739      y_name, x_name     : namex of dimensions in ff
740
741      Returns : xarray dataset with rectangular grid. Incorrect above 20°N
742    '''
743    # Compute 1D longitude and latitude
744    (zlat, zlon) = latlon1d ( dd[lat_name], dd[lon_name])
745
746    zdd = dd
747    # Assign lon and lat as dimensions of the dataset
748    if y_name in zdd.dims :
749        zlat = xr.DataArray (zlat, coords=[zlat,], dims=['lat',])
750        zdd  = zdd.rename_dims ({y_name: "lat",}).assign_coords (lat=zlat)
751    if x_name in zdd.dims :
752        zlon = xr.DataArray (zlon, coords=[zlon,], dims=['lon',])
753        zdd  = zdd.rename_dims ({x_name: "lon",}).assign_coords (lon=zlon)
754    # Force dimensions to be in the right order
755    coord_order = ['lat', 'lon']
756    for dim in [ 'depthw', 'depthv', 'depthu', 'deptht', 'depth', 'z',
757                 'time_counter', 'time', 'tbnds',
758                 'bnds', 'axis_nbounds', 'two2', 'two1', 'two', 'four',] :
759        if dim in zdd.dims :
760            coord_order.insert (0, dim)
761
762    zdd = zdd.transpose (*coord_order)
763    return zdd
764
765def lbc_init (ptab, nperio=None) :
766    '''Prepare for all lbc calls
767
768    Set periodicity on input field
769    nperio    : Type of periodicity
770      0       : No periodicity
771      1, 4, 6 : Cyclic on i dimension (generaly longitudes) with 2 points halo
772      2       : Obsolete (was symmetric condition at southern boundary ?)
773      3, 4    : North fold T-point pivot (legacy ORCA2)
774      5, 6    : North fold F-point pivot (ORCA1, ORCA025, ORCA2 with new grid for paleo)
775    cd_type   : Grid specification : T, U, V or F
776
777    See NEMO documentation for further details
778    '''
779    jpi, jpj = None, None
780    ax, ix = __find_axis__ (ptab, 'x')
781    ay, jy = __find_axis__ (ptab, 'y')
782    if ax :
783        jpi = ptab.shape[ix]
784    if ay :
785        jpj = ptab.shape[jy]
786
787    if nperio is None :
788        nperio = __guess_nperio__ (jpj, jpi, nperio)
789
790    if nperio not in NPERIO_VALID_RANGE :
791        raise AttributeError ( f'{nperio=} is not in the valid range {NPERIO_VALID_RANGE}' )
792
793    return jpj, jpi, nperio
794
795def lbc (ptab, nperio=None, cd_type='T', psgn=1.0, nemo_4u_bug=False) :
796    '''Set periodicity on input field
797
798    ptab      : Input array (works for rank 2 at least : ptab[...., lat, lon])
799    nperio    : Type of periodicity
800    cd_type   : Grid specification : T, U, V or F
801    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
802
803    See NEMO documentation for further details
804    '''
805    jpi, nperio = lbc_init (ptab, nperio)[1:]
806    ax = __find_axis__ (ptab, 'x')[0]
807    ay = __find_axis__ (ptab, 'y')[0]
808    psgn   = ptab.dtype.type (psgn)
809    mmath  = __mmath__ (ptab)
810
811    if mmath == xr :
812        ztab = ptab.values.copy ()
813    else           :
814        ztab = ptab.copy ()
815
816    if ax :
817        #
818        #> East-West boundary conditions
819        # ------------------------------
820        if nperio in [1, 4, 6] :
821        # ... cyclic
822            ztab [...,  0] = ztab [..., -2]
823            ztab [..., -1] = ztab [...,  1]
824
825        if ay :
826            #
827            #> North-South boundary conditions
828            # --------------------------------
829            if nperio in [3, 4] :  # North fold T-point pivot
830                if cd_type in [ 'T', 'W' ] : # T-, W-point
831                    ztab [..., -1, 1:       ] = psgn * ztab [..., -3, -1:0:-1      ]
832                    ztab [..., -1, 0        ] = psgn * ztab [..., -3, 2            ]
833                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2:0:-1  ]
834
835                if cd_type == 'U' :
836                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
837                    ztab [..., -1,  0       ] = psgn * ztab [..., -3,  1           ]
838                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, -2           ]
839
840                    if nemo_4u_bug :
841                        ztab [..., -2, jpi//2+1:-1] = psgn * ztab [..., -2, jpi//2-2:0:-1]
842                        ztab [..., -2, jpi//2-1   ] = psgn * ztab [..., -2, jpi//2       ]
843                    else :
844                        ztab [..., -2, jpi//2-1:-1] = psgn * ztab [..., -2, jpi//2:0:-1]
845
846                if cd_type == 'V' :
847                    ztab [..., -2, 1:       ] = psgn * ztab [..., -3, jpi-1:0:-1   ]
848                    ztab [..., -1, 1:       ] = psgn * ztab [..., -4, -1:0:-1      ]
849                    ztab [..., -1, 0        ] = psgn * ztab [..., -4, 2            ]
850
851                if cd_type == 'F' :
852                    ztab [..., -2, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
853                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -4, -1:0:-1      ]
854                    ztab [..., -1,  0       ] = psgn * ztab [..., -4,  1           ]
855                    ztab [..., -1, -1       ] = psgn * ztab [..., -4, -2           ]
856
857            if nperio in [4.2] :  # North fold T-point pivot
858                if cd_type in [ 'T', 'W' ] : # T-, W-point
859                    ztab [..., -1, jpi//2:  ] = psgn * ztab [..., -1, jpi//2:0:-1  ]
860
861                if cd_type == 'U' :
862                    ztab [..., -1, jpi//2-1:-1] = psgn * ztab [..., -1, jpi//2:0:-1]
863
864                if cd_type == 'V' :
865                    ztab [..., -1, 1:       ] = psgn * ztab [..., -2, jpi-1:0:-1   ]
866
867                if cd_type == 'F' :
868                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -1:0:-1      ]
869
870            if nperio in [5, 6] :            #  North fold F-point pivot
871                if cd_type in ['T', 'W']  :
872                    ztab [..., -1, 0:       ] = psgn * ztab [..., -2, -1::-1       ]
873
874                if cd_type == 'U' :
875                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -2::-1       ]
876                    ztab [..., -1, -1       ] = psgn * ztab [..., -2, 0            ] # Bug ?
877
878                if cd_type == 'V' :
879                    ztab [..., -1, 0:       ] = psgn * ztab [..., -3, -1::-1       ]
880                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2-1::-1 ]
881
882                if cd_type == 'F' :
883                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -2::-1       ]
884                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, 0            ]
885                    ztab [..., -2, jpi//2:-1] = psgn * ztab [..., -2, jpi//2-2::-1 ]
886
887            #
888            #> East-West boundary conditions
889            # ------------------------------
890            if nperio in [1, 4, 6] :
891                # ... cyclic
892                ztab [...,  0] = ztab [..., -2]
893                ztab [..., -1] = ztab [...,  1]
894
895    if mmath == xr :
896        ztab = xr.DataArray ( ztab, dims=ptab.dims, coords=ptab.coords )
897        ztab.attrs = ptab.attrs
898
899    return ztab
900
901def lbc_mask (ptab, nperio=None, cd_type='T', sval=np.nan) :
902    '''Mask fields on duplicated points
903
904    ptab      : Input array. Rank 2 at least : ptab [...., lat, lon]
905    nperio    : Type of periodicity
906    cd_type   : Grid specification : T, U, V or F
907
908    See NEMO documentation for further details
909    '''
910    jpi, nperio = lbc_init (ptab, nperio)[1:]
911    ax = __find_axis__ (ptab, 'x')[0]
912    ay = __find_axis__ (ptab, 'y')[0]
913    ztab = ptab.copy ()
914
915    if ax :
916        #
917        #> East-West boundary conditions
918        # ------------------------------
919        if nperio in [1, 4, 6] :
920        # ... cyclic
921            ztab [...,  0] = sval
922            ztab [..., -1] = sval
923
924        if ay :
925            #
926            #> South (in which nperio cases ?)
927            # --------------------------------
928            if nperio in [1, 3, 4, 5, 6] :
929                ztab [..., 0, :] = sval
930
931            #
932            #> North-South boundary conditions
933            # --------------------------------
934            if nperio in [3, 4] :  # North fold T-point pivot
935                if cd_type in [ 'T', 'W' ] : # T-, W-point
936                    ztab [..., -1,  :         ] = sval
937                    ztab [..., -2, :jpi//2  ] = sval
938
939                if cd_type == 'U' :
940                    ztab [..., -1,  :         ] = sval
941                    ztab [..., -2, jpi//2+1:  ] = sval
942
943                if cd_type == 'V' :
944                    ztab [..., -2, :       ] = sval
945                    ztab [..., -1, :       ] = sval
946
947                if cd_type == 'F' :
948                    ztab [..., -2, :       ] = sval
949                    ztab [..., -1, :       ] = sval
950
951            if nperio in [4.2] :  # North fold T-point pivot
952                if cd_type in [ 'T', 'W' ] : # T-, W-point
953                    ztab [..., -1, jpi//2  :  ] = sval
954
955                if cd_type == 'U' :
956                    ztab [..., -1, jpi//2-1:-1] = sval
957
958                if cd_type == 'V' :
959                    ztab [..., -1, 1:       ] = sval
960
961                if cd_type == 'F' :
962                    ztab [..., -1, 0:-1     ] = sval
963
964            if nperio in [5, 6] :            #  North fold F-point pivot
965                if cd_type in ['T', 'W']  :
966                    ztab [..., -1, 0:       ] = sval
967
968                if cd_type == 'U' :
969                    ztab [..., -1, 0:-1     ] = sval
970                    ztab [..., -1, -1       ] = sval
971
972                if cd_type == 'V' :
973                    ztab [..., -1, 0:       ] = sval
974                    ztab [..., -2, jpi//2:  ] = sval
975
976                if cd_type == 'F' :
977                    ztab [..., -1, 0:-1       ] = sval
978                    ztab [..., -1, -1         ] = sval
979                    ztab [..., -2, jpi//2+1:-1] = sval
980
981    return ztab
982
983def lbc_plot (ptab, nperio=None, cd_type='T', psgn=1.0, sval=np.nan) :
984    '''Set periodicity on input field, for plotting for any cartopy projection
985
986      Points at the north fold are masked
987      Points for zonal periodicity are kept
988    ptab      : Input array. Rank 2 at least : ptab[...., lat, lon]
989    nperio    : Type of periodicity
990    cd_type   : Grid specification : T, U, V or F
991    psgn      : For change of sign for vector components
992           (1 for scalars, -1 for vector components)
993
994    See NEMO documentation for further details
995    '''
996    jpi, nperio = lbc_init (ptab, nperio)[1:]
997    ax = __find_axis__ (ptab, 'x')[0]
998    ay = __find_axis__ (ptab, 'y')[0]
999    psgn   = ptab.dtype.type (psgn)
1000    ztab   = ptab.copy ()
1001
1002    if ax :
1003        #
1004        #> East-West boundary conditions
1005        # ------------------------------
1006        if nperio in [1, 4, 6] :
1007            # ... cyclic
1008            ztab [..., :,  0] = ztab [..., :, -2]
1009            ztab [..., :, -1] = ztab [..., :,  1]
1010
1011        if ay :
1012            #> Masks south
1013            # ------------
1014            if nperio in [4, 6] :
1015                ztab [..., 0, : ] = sval
1016
1017            #
1018            #> North-South boundary conditions
1019            # --------------------------------
1020            if nperio in [3, 4] :  # North fold T-point pivot
1021                if cd_type in [ 'T', 'W' ] : # T-, W-point
1022                    ztab [..., -1,  :      ] = sval
1023                    #ztab [..., -2, jpi//2: ] = sval
1024                    ztab [..., -2, :jpi//2 ] = sval # Give better plots than above
1025                if cd_type == 'U' :
1026                    ztab [..., -1, : ] = sval
1027
1028                if cd_type == 'V' :
1029                    ztab [..., -2, : ] = sval
1030                    ztab [..., -1, : ] = sval
1031
1032                if cd_type == 'F' :
1033                    ztab [..., -2, : ] = sval
1034                    ztab [..., -1, : ] = sval
1035
1036            if nperio in [4.2] :  # North fold T-point pivot
1037                if cd_type in [ 'T', 'W' ] : # T-, W-point
1038                    ztab [..., -1, jpi//2:  ] = sval
1039
1040                if cd_type == 'U' :
1041                    ztab [..., -1, jpi//2-1:-1] = sval
1042
1043                if cd_type == 'V' :
1044                    ztab [..., -1, 1:       ] = sval
1045
1046                if cd_type == 'F' :
1047                    ztab [..., -1, 0:-1     ] = sval
1048
1049            if nperio in [5, 6] :            #  North fold F-point pivot
1050                if cd_type in ['T', 'W']  :
1051                    ztab [..., -1, : ] = sval
1052
1053                if cd_type == 'U' :
1054                    ztab [..., -1, : ] = sval
1055
1056                if cd_type == 'V' :
1057                    ztab [..., -1, :        ] = sval
1058                    ztab [..., -2, jpi//2:  ] = sval
1059
1060                if cd_type == 'F' :
1061                    ztab [..., -1, :          ] = sval
1062                    ztab [..., -2, jpi//2+1:-1] = sval
1063
1064    return ztab
1065
1066def lbc_add (ptab, nperio=None, cd_type=None, psgn=1) :
1067    '''Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
1068
1069    Periodicity and north fold halos has been removed in NEMO 4.2
1070    This routine adds the halos if needed
1071
1072    ptab      : Input array (works
1073      rank 2 at least : ptab[...., lat, lon]
1074    nperio    : Type of periodicity
1075
1076    See NEMO documentation for further details
1077    '''
1078    mmath = __mmath__ (ptab)
1079    nperio = lbc_init (ptab, nperio)[-1]
1080    lshape = get_shape (ptab)
1081    ix = __find_axis__ (ptab, 'x')[-1]
1082    jy = __find_axis__ (ptab, 'y')[-1]
1083
1084    t_shape = np.array (ptab.shape)
1085
1086    if nperio in [4.2, 6.2] :
1087
1088        ext_shape = t_shape.copy()
1089        if 'X' in lshape :
1090            ext_shape[ix] = ext_shape[ix] + 2
1091        if 'Y' in lshape :
1092            ext_shape[jy] = ext_shape[jy] + 1
1093
1094        if mmath == xr :
1095            ptab_ext = xr.DataArray (np.zeros (ext_shape), dims=ptab.dims)
1096            if 'X' in lshape and 'Y' in lshape :
1097                ptab_ext.values[..., :-1, 1:-1] = ptab.values.copy ()
1098            else :
1099                if 'X' in lshape     :
1100                    ptab_ext.values[...,      1:-1] = ptab.values.copy ()
1101                if 'Y' in lshape     :
1102                    ptab_ext.values[..., :-1      ] = ptab.values.copy ()
1103        else           :
1104            ptab_ext =               np.zeros (ext_shape)
1105            if 'X' in lshape and 'Y' in lshape :
1106                ptab_ext       [..., :-1, 1:-1] = ptab.copy ()
1107            else :
1108                if 'X' in lshape     :
1109                    ptab_ext       [...,      1:-1] = ptab.copy ()
1110                if 'Y' in lshape     :
1111                    ptab_ext       [..., :-1      ] = ptab.copy ()
1112
1113        if nperio == 4.2 :
1114            ptab_ext = lbc (ptab_ext, nperio=4, cd_type=cd_type, psgn=psgn)
1115        if nperio == 6.2 :
1116            ptab_ext = lbc (ptab_ext, nperio=6, cd_type=cd_type, psgn=psgn)
1117
1118        if mmath == xr :
1119            ptab_ext.attrs = ptab.attrs
1120            az = __find_axis__ (ptab, 'z')[0]
1121            at = __find_axis__ (ptab, 't')[0]
1122            if az :
1123                ptab_ext = ptab_ext.assign_coords ( {az:ptab.coords[az]} )
1124            if at :
1125                ptab_ext = ptab_ext.assign_coords ( {at:ptab.coords[at]} )
1126
1127    else : ptab_ext = lbc (ptab, nperio=nperio, cd_type=cd_type, psgn=psgn)
1128
1129    return ptab_ext
1130
1131def lbc_del (ptab, nperio=None, cd_type='T', psgn=1) :
1132    '''Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
1133
1134    Periodicity and north fold halos has been removed in NEMO 4.2
1135    This routine removes the halos if needed
1136
1137    ptab      : Input array (works
1138      rank 2 at least : ptab[...., lat, lon]
1139    nperio    : Type of periodicity
1140
1141    See NEMO documentation for further details
1142    '''
1143    nperio = lbc_init (ptab, nperio)[-1]
1144    #lshape = get_shape (ptab)
1145    ax = __find_axis__ (ptab, 'x')[0]
1146    ay = __find_axis__ (ptab, 'y')[0]
1147
1148    if nperio in [4.2, 6.2] :
1149        if ax or ay :
1150            if ax and ay :
1151                ztab = lbc (ptab[..., :-1, 1:-1],
1152                            nperio=nperio, cd_type=cd_type, psgn=psgn)
1153            else :
1154                if ax :
1155                    ztab = lbc (ptab[...,      1:-1],
1156                                nperio=nperio, cd_type=cd_type, psgn=psgn)
1157                if ay :
1158                    ztab = lbc (ptab[..., -1],
1159                                nperio=nperio, cd_type=cd_type, psgn=psgn)
1160        else :
1161            ztab = ptab
1162    else :
1163        ztab = ptab
1164
1165    return ztab
1166
1167def lbc_index (jj, ii, jpj, jpi, nperio=None, cd_type='T') :
1168    '''For indexes of a NEMO point, give the corresponding point
1169        inside the domain (i.e. not in the halo)
1170
1171    jj, ii    : indexes
1172    jpi, jpi  : size of domain
1173    nperio    : type of periodicity
1174    cd_type   : grid specification : T, U, V or F
1175
1176    See NEMO documentation for further details
1177    '''
1178
1179    if nperio is None :
1180        nperio = __guess_nperio__ (jpj, jpi, nperio)
1181
1182    ## For the sake of simplicity, switch to the convention of original
1183    ## lbc Fortran routine from NEMO : starts indexes at 1
1184    jy = jj + 1
1185    ix = ii + 1
1186
1187    mmath = __mmath__ (jj)
1188    if mmath is None :
1189        mmath=np
1190
1191    #
1192    #> East-West boundary conditions
1193    # ------------------------------
1194    if nperio in [1, 4, 6] :
1195        #... cyclic
1196        ix = mmath.where (ix==jpi, 2   , ix)
1197        ix = mmath.where (ix== 1 ,jpi-1, ix)
1198
1199    #
1200    def mod_ij (cond, jy_new, ix_new) :
1201        jy_r = mmath.where (cond, jy_new, jy)
1202        ix_r = mmath.where (cond, ix_new, ix)
1203        return jy_r, ix_r
1204    #
1205    #> North-South boundary conditions
1206    # --------------------------------
1207    if nperio in [ 3 , 4 ]  :
1208        if cd_type in  [ 'T' , 'W' ] :
1209            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix>=2       ), jpj-2, jpi-ix+2)
1210            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1       ), jpj-1, 3       )
1211            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1),
1212                                  jy , jpi-ix+2)
1213
1214        if cd_type in [ 'U' ] :
1215            jy, ix = mod_ij (np.logical_and (
1216                      jy==jpj  ,
1217                      np.logical_and (ix>=1, ix <= jpi-1)   ),
1218                                         jy   , jpi-ix+1)
1219            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1  ) , jpj-2, 2       )
1220            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi) , jpj-2, jpi-1   )
1221            jy, ix = mod_ij (np.logical_and (jy==jpj-1,
1222                            np.logical_and (ix>=jpi//2, ix<=jpi-1)), jy   , jpi-ix+1)
1223
1224        if cd_type in [ 'V' ] :
1225            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=2  ), jpj-2, jpi-ix+2)
1226            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix>=2  ), jpj-3, jpi-ix+2)
1227            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1  ), jpj-3,  3      )
1228
1229        if cd_type in [ 'F' ] :
1230            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix<=jpi-1), jpj-2, jpi-ix+1)
1231            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1), jpj-3, jpi-ix+1)
1232            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1    ), jpj-3, 2       )
1233            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi  ), jpj-3, jpi-1   )
1234
1235    if nperio in [ 5 , 6 ] :
1236        if cd_type in [ 'T' , 'W' ] :                        # T-, W-point
1237            jy, ix = mod_ij (jy==jpj, jpj-1, jpi-ix+1)
1238
1239        if cd_type in [ 'U' ] :                              # U-point
1240            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-1, jpi-ix  )
1241            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi     ), jpi-1, 1       )
1242
1243        if cd_type in [ 'V' ] :    # V-point
1244            jy, ix = mod_ij (jy==jpj                                 , jy   , jpi-ix+1)
1245            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+1)
1246
1247        if cd_type in [ 'F' ] :                              # F-point
1248            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-2, jpi-ix  )
1249            jy, ix = mod_ij (np.logical_and (ix==jpj  , ix==jpi     ), jpj-2, 1       )
1250            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix  )
1251
1252    ## Restore convention to Python/C : indexes start at 0
1253    jy += -1
1254    ix += -1
1255
1256    if isinstance (jj, int) :
1257        jy = jy.item ()
1258    if isinstance (ii, int) :
1259        ix = ix.item ()
1260
1261    return jy, ix
1262
1263def find_ji (lat_data, lon_data, lat_grid, lon_grid, mask=1.0, verbose=False, out=None) :
1264    '''
1265    Description: seeks J,I indices of the grid point which is the closest
1266       of a given point
1267
1268    Usage: go FindJI  <data latitude> <data longitude> <grid latitudes> <grid longitudes> [mask]
1269    <grid latitudes><grid longitudes> are 2D fields on J/I (Y/X) dimensions
1270    mask : if given, seek only non masked grid points (i.e with mask=1)
1271
1272    Example : findIJ (40, -20, nav_lat, nav_lon, mask=1.0)
1273
1274    Note : all longitudes and latitudes in degrees
1275
1276    Note : may work with 1D lon/lat (?)
1277    '''
1278    # Get grid dimensions
1279    if len (lon_grid.shape) == 2 :
1280        jpi = lon_grid.shape[-1]
1281    else                         :
1282        jpi = len(lon_grid)
1283
1284    #mmath = __mmath__ (lat_grid)
1285
1286    # Compute distance from the point to all grid points (in RADian)
1287    arg      = ( np.sin (RAD*lat_data) * np.sin (RAD*lat_grid)
1288               + np.cos (RAD*lat_data) * np.cos (RAD*lat_grid) *
1289                 np.cos(RAD*(lon_data-lon_grid)) )
1290    # Send masked points to 'infinite'
1291    distance = np.arccos (arg) + 4.0*RPI*(1.0-mask)
1292
1293    # Truncates to alleviate some precision problem with some grids
1294    prec = int (1E7)
1295    distance = (distance*prec).astype(int) / prec
1296
1297    # Compute minimum of distance, and index of minimum
1298    #
1299    #distance_min = distance.min    ()
1300    jimin        = int (distance.argmin ())
1301
1302    # Compute 2D indices (Python/C flavor : starting at 0)
1303    jmin = jimin // jpi
1304    imin = jimin - jmin*jpi
1305
1306    # Result
1307    if verbose :
1308        # Compute distance achieved
1309        #mindist = distance [jmin, imin]
1310
1311        # Compute azimuth
1312        dlon = lon_data-lon_grid[jmin,imin]
1313        arg  = np.sin (RAD*dlon) / (
1314                  np.cos(RAD*lat_data)*np.tan(RAD*lat_grid[jmin,imin])
1315                - np.sin(RAD*lat_data)*np.cos(RAD*dlon))
1316        azimuth = DAR*np.arctan (arg)
1317        print ( f'I={imin:d} J={jmin:d} - Data:{lat_data:5.1f}°N {lon_data:5.1f}°E' \
1318            +   f'- Grid:{lat_grid[jmin,imin]:4.1f}°N '   \
1319            +   f'{lon_grid[jmin,imin]:4.1f}°E - Dist: {RA*distance[jmin,imin]:6.1f}km' \
1320            +   f' {DAR*distance[jmin,imin]:5.2f}° ' \
1321            +   f'- Azimuth: {RAD*azimuth:3.2f}RAD - {azimuth:5.1f}°' )
1322
1323    if   out=='dict'                     :
1324        return {'x':imin, 'y':jmin}
1325    elif out in ['array', 'numpy', 'np'] :
1326        return np.array ( [jmin, imin] )
1327    elif out in ['xarray', 'xr']         :
1328        return xr.DataArray ( [jmin, imin] )
1329    elif out=='list'                     :
1330        return [jmin, imin]
1331    elif out=='tuple'                    :
1332        return jmin, imin
1333    else                                 :
1334        return jmin, imin
1335
1336def curl (tx, ty, e1f, e2f, nperio=None) :
1337    '''Returns curl of a vector field
1338    '''
1339    ax = __find_axis__ (tx, 'x')[0]
1340    ay = __find_axis__ (ty, 'y')[0]
1341
1342    tx_0  = lbc_add (tx , nperio=nperio, cd_type='U', psgn=-1)
1343    ty_0  = lbc_add (ty , nperio=nperio, cd_type='V', psgn=-1)
1344    e1f_0 = lbc_add (e1f, nperio=nperio, cd_type='U', psgn=-1)
1345    e2f_0 = lbc_add (e2f, nperio=nperio, cd_type='V', psgn=-1)
1346
1347    tx_1  = tx_0.roll ( {ay:-1} )
1348    ty_1  = ty_0.roll ( {ax:-1} )
1349    tx_1  = lbc (tx_1, nperio=nperio, cd_type='U', psgn=-1)
1350    ty_1  = lbc (ty_1, nperio=nperio, cd_type='V', psgn=-1)
1351
1352    zcurl = (ty_1 - ty_0)/e1f_0 - (tx_1 - tx_0)/e2f_0
1353
1354    mask = np.logical_or (
1355             np.logical_or ( np.isnan(tx_0), np.isnan(tx_1)),
1356             np.logical_or ( np.isnan(ty_0), np.isnan(ty_1)) )
1357
1358    zcurl = zcurl.where (np.logical_not (mask), np.nan)
1359
1360    zcurl = lbc_del (zcurl, nperio=nperio, cd_type='F', psgn=1)
1361    zcurl = lbc (zcurl, nperio=nperio, cd_type='F', psgn=1)
1362
1363    return zcurl
1364
1365def div (ux, uy, e1t, e2t, nperio=None) :
1366    '''Returns divergence of a vector field
1367    '''
1368    ax = __find_axis__ (ux, 'x')[0]
1369    ay = __find_axis__ (ux, 'y')[0]
1370
1371    ux_0  = lbc_add (ux , nperio=nperio, cd_type='U', psgn=-1)
1372    uy_0  = lbc_add (uy , nperio=nperio, cd_type='V', psgn=-1)
1373    e1t_0 = lbc_add (e1t, nperio=nperio, cd_type='U', psgn=-1)
1374    e2t_0 = lbc_add (e2t, nperio=nperio, cd_type='V', psgn=-1)
1375
1376    ux_1 = ux_0.roll ( {ay:1} )
1377    uy_1 = uy_0.roll ( {ax:1} )
1378    ux_1 = lbc (ux_1, nperio=nperio, cd_type='U', psgn=-1)
1379    uy_1 = lbc (uy_1, nperio=nperio, cd_type='V', psgn=-1)
1380
1381    zdiv = (ux_0 - ux_1)/e2t_0 + (uy_0 - uy_1)/e1t_0
1382
1383    mask = np.logical_or (
1384             np.logical_or ( np.isnan(ux_0), np.isnan(ux_1)),
1385             np.logical_or ( np.isnan(uy_0), np.isnan(uy_1)) )
1386
1387    zdiv = zdiv.where (np.logical_not (mask), np.nan)
1388
1389    zdiv = lbc_del (zdiv, nperio=nperio, cd_type='T', psgn=1)
1390    zdiv = lbc (zdiv, nperio=nperio, cd_type='T', psgn=1)
1391
1392    return zdiv
1393
1394def geo2en (pxx, pyy, pzz, glam, gphi) :
1395    '''Change vector from geocentric to east/north
1396
1397    Inputs :
1398        pxx, pyy, pzz : components on the geocentric system
1399        glam, gphi : longitude and latitude of the points
1400    '''
1401
1402    gsinlon = np.sin (RAD * glam)
1403    gcoslon = np.cos (RAD * glam)
1404    gsinlat = np.sin (RAD * gphi)
1405    gcoslat = np.cos (RAD * gphi)
1406
1407    pte = - pxx * gsinlon            + pyy * gcoslon
1408    ptn = - pxx * gcoslon * gsinlat  - pyy * gsinlon * gsinlat + pzz * gcoslat
1409
1410    return pte, ptn
1411
1412def en2geo (pte, ptn, glam, gphi) :
1413    '''Change vector from east/north to geocentric
1414
1415    Inputs :
1416        pte, ptn   : eastward/northward components
1417        glam, gphi : longitude and latitude of the points
1418    '''
1419
1420    gsinlon = np.sin (RAD * glam)
1421    gcoslon = np.cos (RAD * glam)
1422    gsinlat = np.sin (RAD * gphi)
1423    gcoslat = np.cos (RAD * gphi)
1424
1425    pxx = - pte * gsinlon - ptn * gcoslon * gsinlat
1426    pyy =   pte * gcoslon - ptn * gsinlon * gsinlat
1427    pzz =   ptn * gcoslat
1428
1429    return pxx, pyy, pzz
1430
1431
1432def clo_lon (lon, lon0=0., rad=False, deg=True) :
1433    '''Choose closest to lon0 longitude, adding/substacting 360°
1434    if needed
1435    '''
1436    mmath = __mmath__ (lon, np)
1437    if rad :
1438        lon_range = 2.*np.pi
1439    if deg :
1440        lon_range = 360.
1441    c_lon = lon
1442    c_lon = mmath.where (c_lon > lon0 + lon_range*0.5,
1443                             c_lon-lon_range, clo_lon)
1444    c_lon = mmath.where (c_lon < lon0 - lon_range*0.5,
1445                             c_lon+lon_range, clo_lon)
1446    c_lon = mmath.where (c_lon > lon0 + lon_range*0.5,
1447                             c_lon-lon_range, clo_lon)
1448    c_lon = mmath.where (c_lon < lon0 - lon_range*0.5,
1449                             c_lon+lon_range, clo_lon)
1450    if c_lon.shape == () :
1451        c_lon = c_lon.item ()
1452    if mmath == xr :
1453        if lon.attrs :
1454            c_lon.attrs.update ( lon.attrs )
1455    return c_lon
1456
1457def index2depth (pk, gdept_0) :
1458    '''From index (real, continuous), get depth
1459
1460    Needed to use transforms in Matplotlib
1461    '''
1462    jpk = gdept_0.shape[0]
1463    kk = xr.DataArray(pk)
1464    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1465    k0 = np.floor (k).astype (int)
1466    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1467    zz = k - k0
1468    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1469    return gz.values
1470
1471def depth2index (pz, gdept_0) :
1472    '''From depth, get index (real, continuous)
1473
1474    Needed to use transforms in Matplotlib
1475    '''
1476    jpk  = gdept_0.shape[0]
1477    if   isinstance (pz, xr.core.dataarray.DataArray ) :
1478        zz   = xr.DataArray (pz.values, dims=('zz',))
1479    elif isinstance (pz,  np.ndarray) :
1480        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1481    else :
1482        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1483    zz   = np.minimum (gdept_0[-1], np.maximum (0, zz))
1484
1485    idk1 = np.minimum ( (gdept_0-zz), 0.).argmax (axis=0).astype(int)
1486    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1487    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
1488
1489    zff = (zz - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1490    idk = zff*idk1 + (1.0-zff)*idk2
1491    idk = xr.where ( np.isnan(idk), idk1, idk)
1492    return idk.values
1493
1494def index2depth_panels (pk, gdept_0, depth0, fact) :
1495    '''From  index (real, continuous), get depth, with bottom part compressed
1496
1497    Needed to use transforms in Matplotlib
1498    '''
1499    jpk = gdept_0.shape[0]
1500    kk = xr.DataArray (pk)
1501    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1502    k0 = np.floor (k).astype (int)
1503    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1504    zz = k - k0
1505    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1506    gz = xr.where ( gz<depth0, gz, depth0 + (gz-depth0)*fact)
1507    return gz.values
1508
1509def depth2index_panels (pz, gdept_0, depth0, fact) :
1510    '''From  index (real, continuous), get depth, with bottom part compressed
1511
1512    Needed to use transforms in Matplotlib
1513    '''
1514    jpk = gdept_0.shape[0]
1515    if isinstance (pz, xr.core.dataarray.DataArray) :
1516        zz   = xr.DataArray (pz.values , dims=('zz',))
1517    elif isinstance (pz, np.ndarray) :
1518        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1519    else :
1520        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1521    zz         = np.minimum (gdept_0[-1], np.maximum (0, zz))
1522    gdept_comp = xr.where ( gdept_0>depth0,
1523                                (gdept_0-depth0)*fact+depth0, gdept_0)
1524    zz_comp    = xr.where ( zz     >depth0, (zz     -depth0)*fact+depth0,
1525                                zz     )
1526    zz_comp    = np.minimum (gdept_comp[-1], np.maximum (0, zz_comp))
1527
1528    idk1 = np.minimum ( (gdept_0-zz_comp), 0.).argmax (axis=0).astype(int)
1529    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1530    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
1531
1532    zff = (zz_comp - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1533    idk = zff*idk1 + (1.0-zff)*idk2
1534    idk = xr.where ( np.isnan(idk), idk1, idk)
1535    return idk.values
1536
1537def depth2comp (pz, depth0, fact ) :
1538    '''Form depth, get compressed depth, with bottom part compressed
1539
1540    Needed to use transforms in Matplotlib
1541    '''
1542    #print ('start depth2comp')
1543    if isinstance (pz, xr.core.dataarray.DataArray) :
1544        zz   = pz.values
1545    elif isinstance(pz, list) :
1546        zz = np.array (pz)
1547    else :
1548        zz   = pz
1549    gz = np.where ( zz>depth0, (zz-depth0)*fact+depth0, zz)
1550    #print ( f'depth2comp : {gz=}' )
1551    if type (pz) in [int, float] :
1552        return gz.item()
1553    else :
1554        return gz
1555
1556def comp2depth (pz, depth0, fact ) :
1557    '''Form compressed depth, get depth, with bottom part compressed
1558
1559    Needed to use transforms in Matplotlib
1560    '''
1561    if isinstance (pz, xr.core.dataarray.DataArray) :
1562        zz   = pz.values
1563    elif isinstance (pz, list) :
1564        zz = np.array (pz)
1565    else :
1566        zz   = pz
1567    gz = np.where ( zz>depth0, (zz-depth0)/fact+depth0, zz)
1568    if type (pz) in [int, float] :
1569        gz = gz.item()
1570
1571    return gz
1572
1573def index2lon (pi, plon_1d) :
1574    '''From index (real, continuous), get longitude
1575
1576    Needed to use transforms in Matplotlib
1577    '''
1578    jpi = plon_1d.shape[0]
1579    ii = xr.DataArray (pi)
1580    i =  np.maximum (0, np.minimum (jpi-1, ii    ))
1581    i0 = np.floor (i).astype (int)
1582    i1 = np.maximum (0, np.minimum (jpi-1,  i0+1))
1583    xx = i - i0
1584    gx = (1.0-xx)*plon_1d[i0]+ xx*plon_1d[i1]
1585    return gx.values
1586
1587def lon2index (px, plon_1d) :
1588    '''From longitude, get index (real, continuous)
1589
1590    Needed to use transforms in Matplotlib
1591    '''
1592    jpi  = plon_1d.shape[0]
1593    if isinstance (px, xr.core.dataarray.DataArray) :
1594        xx   = xr.DataArray (px.values , dims=('xx',))
1595    elif isinstance (px, np.ndarray) :
1596        xx   = xr.DataArray (px.ravel(), dims=('xx',))
1597    else :
1598        xx   = xr.DataArray (np.array([px]).ravel(), dims=('xx',))
1599    xx   = xr.where ( xx>plon_1d.max(), xx-360.0, xx)
1600    xx   = xr.where ( xx<plon_1d.min(), xx+360.0, xx)
1601    xx   = np.minimum (plon_1d.max(), np.maximum(xx, plon_1d.min() ))
1602    idi1 = np.minimum ( (plon_1d-xx), 0.).argmax (axis=0).astype(int)
1603    idi1 = np.maximum (0, np.minimum (jpi-1,  idi1  ))
1604    idi2 = np.maximum (0, np.minimum (jpi-1,  idi1-1))
1605
1606    zff = (xx - plon_1d[idi2])/(plon_1d[idi1]-plon_1d[idi2])
1607    idi = zff*idi1 + (1.0-zff)*idi2
1608    idi = xr.where ( np.isnan(idi), idi1, idi)
1609    return idi.values
1610
1611def index2lat (pj, plat_1d) :
1612    '''From index (real, continuous), get latitude
1613
1614    Needed to use transforms in Matplotlib
1615    '''
1616    jpj = plat_1d.shape[0]
1617    jj  = xr.DataArray (pj)
1618    j   = np.maximum (0, np.minimum (jpj-1, jj    ))
1619    j0  = np.floor (j).astype (int)
1620    j1  = np.maximum (0, np.minimum (jpj-1,  j0+1))
1621    yy  = j - j0
1622    gy  = (1.0-yy)*plat_1d[j0]+ yy*plat_1d[j1]
1623    return gy.values
1624
1625def lat2index (py, plat_1d) :
1626    '''From latitude, get index (real, continuous)
1627
1628    Needed to use transforms in Matplotlib
1629    '''
1630    jpj = plat_1d.shape[0]
1631    if isinstance (py, xr.core.dataarray.DataArray) :
1632        yy   = xr.DataArray (py.values , dims=('yy',))
1633    elif isinstance (py, np.ndarray) :
1634        yy   = xr.DataArray (py.ravel(), dims=('yy',))
1635    else :
1636        yy   = xr.DataArray (np.array([py]).ravel(), dims=('yy',))
1637    yy   = np.minimum (plat_1d.max(), np.minimum(yy, plat_1d.max() ))
1638    idj1 = np.minimum ( (plat_1d-yy), 0.).argmax (axis=0).astype(int)
1639    idj1 = np.maximum (0, np.minimum (jpj-1,  idj1  ))
1640    idj2 = np.maximum (0, np.minimum (jpj-1,  idj1-1))
1641
1642    zff = (yy - plat_1d[idj2])/(plat_1d[idj1]-plat_1d[idj2])
1643    idj = zff*idj1 + (1.0-zff)*idj2
1644    idj = xr.where ( np.isnan(idj), idj1, idj)
1645    return idj.values
1646
1647def angle_full (glamt, gphit, glamu, gphiu, glamv, gphiv,
1648                glamf, gphif, nperio=None) :
1649    '''Computes sinus and cosinus of model line direction with
1650    respect to east
1651    '''
1652    mmath = __mmath__ (glamt)
1653
1654    zlamt = lbc_add (glamt, nperio, 'T', 1.)
1655    zphit = lbc_add (gphit, nperio, 'T', 1.)
1656    zlamu = lbc_add (glamu, nperio, 'U', 1.)
1657    zphiu = lbc_add (gphiu, nperio, 'U', 1.)
1658    zlamv = lbc_add (glamv, nperio, 'V', 1.)
1659    zphiv = lbc_add (gphiv, nperio, 'V', 1.)
1660    zlamf = lbc_add (glamf, nperio, 'F', 1.)
1661    zphif = lbc_add (gphif, nperio, 'F', 1.)
1662
1663    # north pole direction & modulous (at T-point)
1664    zxnpt = 0. - 2.0 * np.cos (RAD*zlamt) * np.tan (RPI/4.0 - RAD*zphit/2.0)
1665    zynpt = 0. - 2.0 * np.sin (RAD*zlamt) * np.tan (RPI/4.0 - RAD*zphit/2.0)
1666    znnpt = zxnpt*zxnpt + zynpt*zynpt
1667
1668    # north pole direction & modulous (at U-point)
1669    zxnpu = 0. - 2.0 * np.cos (RAD*zlamu) * np.tan (RPI/4.0 - RAD*zphiu/2.0)
1670    zynpu = 0. - 2.0 * np.sin (RAD*zlamu) * np.tan (RPI/4.0 - RAD*zphiu/2.0)
1671    znnpu = zxnpu*zxnpu + zynpu*zynpu
1672
1673    # north pole direction & modulous (at V-point)
1674    zxnpv = 0. - 2.0 * np.cos (RAD*zlamv) * np.tan (RPI/4.0 - RAD*zphiv/2.0)
1675    zynpv = 0. - 2.0 * np.sin (RAD*zlamv) * np.tan (RPI/4.0 - RAD*zphiv/2.0)
1676    znnpv = zxnpv*zxnpv + zynpv*zynpv
1677
1678    # north pole direction & modulous (at F-point)
1679    zxnpf = 0. - 2.0 * np.cos( RAD*zlamf ) * np.tan ( RPI/4. - RAD*zphif/2. )
1680    zynpf = 0. - 2.0 * np.sin( RAD*zlamf ) * np.tan ( RPI/4. - RAD*zphif/2. )
1681    znnpf = zxnpf*zxnpf + zynpf*zynpf
1682
1683    # j-direction: v-point segment direction (around T-point)
1684    zlam = zlamv
1685    zphi = zphiv
1686    zlan = np.roll ( zlamv, axis=-2, shift=1)  # glamv (ji,jj-1)
1687    zphh = np.roll ( zphiv, axis=-2, shift=1)  # gphiv (ji,jj-1)
1688    zxvvt =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1689          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1690    zyvvt =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1691          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1692    znvvt = np.sqrt ( znnpt * ( zxvvt*zxvvt + zyvvt*zyvvt )  )
1693
1694    # j-direction: f-point segment direction (around u-point)
1695    zlam = zlamf
1696    zphi = zphif
1697    zlan = np.roll (zlamf, axis=-2, shift=1) # glamf (ji,jj-1)
1698    zphh = np.roll (zphif, axis=-2, shift=1) # gphif (ji,jj-1)
1699    zxffu =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1700          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1701    zyffu =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1702          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1703    znffu = np.sqrt ( znnpu * ( zxffu*zxffu + zyffu*zyffu )  )
1704
1705    # i-direction: f-point segment direction (around v-point)
1706    zlam = zlamf
1707    zphi = zphif
1708    zlan = np.roll (zlamf, axis=-1, shift=1) # glamf (ji-1,jj)
1709    zphh = np.roll (zphif, axis=-1, shift=1) # gphif (ji-1,jj)
1710    zxffv =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1711          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1712    zyffv =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1713          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1714    znffv = np.sqrt ( znnpv * ( zxffv*zxffv + zyffv*zyffv )  )
1715
1716    # j-direction: u-point segment direction (around f-point)
1717    zlam = np.roll (zlamu, axis=-2, shift=-1) # glamu (ji,jj+1)
1718    zphi = np.roll (zphiu, axis=-2, shift=-1) # gphiu (ji,jj+1)
1719    zlan = zlamu
1720    zphh = zphiu
1721    zxuuf =  2. * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1722          -  2. * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1723    zyuuf =  2. * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1724          -  2. * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1725    znuuf = np.sqrt ( znnpf * ( zxuuf*zxuuf + zyuuf*zyuuf )  )
1726
1727
1728    # cosinus and sinus using scalar and vectorial products
1729    gsint = ( zxnpt*zyvvt - zynpt*zxvvt ) / znvvt
1730    gcost = ( zxnpt*zxvvt + zynpt*zyvvt ) / znvvt
1731
1732    gsinu = ( zxnpu*zyffu - zynpu*zxffu ) / znffu
1733    gcosu = ( zxnpu*zxffu + zynpu*zyffu ) / znffu
1734
1735    gsinf = ( zxnpf*zyuuf - zynpf*zxuuf ) / znuuf
1736    gcosf = ( zxnpf*zxuuf + zynpf*zyuuf ) / znuuf
1737
1738    gsinv = ( zxnpv*zxffv + zynpv*zyffv ) / znffv
1739    # (caution, rotation of 90 degres)
1740    gcosv =-( zxnpv*zyffv - zynpv*zxffv ) / znffv
1741
1742    gsint = lbc_del (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1743    gcost = lbc_del (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1744    gsinu = lbc_del (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1745    gcosu = lbc_del (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1746    gsinv = lbc_del (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1747    gcosv = lbc_del (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1748    gsinf = lbc_del (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1749    gcosf = lbc_del (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1750
1751    if mmath == xr :
1752        gsint = gsint.assign_coords ( glamt.coords )
1753        gcost = gcost.assign_coords ( glamt.coords )
1754        gsinu = gsinu.assign_coords ( glamu.coords )
1755        gcosu = gcosu.assign_coords ( glamu.coords )
1756        gsinv = gsinv.assign_coords ( glamv.coords )
1757        gcosv = gcosv.assign_coords ( glamv.coords )
1758        gsinf = gsinf.assign_coords ( glamf.coords )
1759        gcosf = gcosf.assign_coords ( glamf.coords )
1760
1761    return gsint, gcost, gsinu, gcosu, gsinv, gcosv, gsinf, gcosf
1762
1763def angle (glam, gphi, nperio, cd_type='T') :
1764    '''Computes sinus and cosinus of model line direction with
1765    respect to east
1766    '''
1767    mmath = __mmath__ (glam)
1768
1769    zlam = lbc_add (glam, nperio, cd_type, 1.)
1770    zphi = lbc_add (gphi, nperio, cd_type, 1.)
1771
1772    # north pole direction & modulous
1773    zxnp = 0. - 2.0 * np.cos (RAD*zlam) * np.tan (RPI/4.0 - RAD*zphi/2.0)
1774    zynp = 0. - 2.0 * np.sin (RAD*zlam) * np.tan (RPI/4.0 - RAD*zphi/2.0)
1775    znnp = zxnp*zxnp + zynp*zynp
1776
1777    # j-direction: segment direction (around point)
1778    zlan_n = np.roll (zlam, axis=-2, shift=-1) # glam [jj+1, ji]
1779    zphh_n = np.roll (zphi, axis=-2, shift=-1) # gphi [jj+1, ji]
1780    zlan_s = np.roll (zlam, axis=-2, shift= 1) # glam [jj-1, ji]
1781    zphh_s = np.roll (zphi, axis=-2, shift= 1) # gphi [jj-1, ji]
1782
1783    zxff = 2.0 * np.cos (RAD*zlan_n) * np.tan (RPI/4.0 - RAD*zphh_n/2.0) \
1784        -  2.0 * np.cos (RAD*zlan_s) * np.tan (RPI/4.0 - RAD*zphh_s/2.0)
1785    zyff = 2.0 * np.sin (RAD*zlan_n) * np.tan (RPI/4.0 - RAD*zphh_n/2.0) \
1786        -  2.0 * np.sin (RAD*zlan_s) * np.tan (RPI/4.0 - RAD*zphh_s/2.0)
1787    znff = np.sqrt (znnp * (zxff*zxff + zyff*zyff) )
1788
1789    gsin = (zxnp*zyff - zynp*zxff) / znff
1790    gcos = (zxnp*zxff + zynp*zyff) / znff
1791
1792    gsin = lbc_del (gsin, cd_type=cd_type, nperio=nperio, psgn=-1.)
1793    gcos = lbc_del (gcos, cd_type=cd_type, nperio=nperio, psgn=-1.)
1794
1795    if mmath == xr :
1796        gsin = gsin.assign_coords ( glam.coords )
1797        gcos = gcos.assign_coords ( glam.coords )
1798
1799    return gsin, gcos
1800
1801def rot_en2ij ( u_e, v_n, gsin, gcos, nperio, cd_type ) :
1802    '''Rotates the Repere: Change vector componantes between
1803    geographic grid --> stretched coordinates grid.
1804
1805    All components are on the same grid (T, U, V or F)
1806    '''
1807
1808    u_i = + u_e * gcos + v_n * gsin
1809    v_j = - u_e * gsin + v_n * gcos
1810
1811    u_i = lbc (u_i, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1812    v_j = lbc (v_j, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1813
1814    return u_i, v_j
1815
1816def rot_ij2en ( u_i, v_j, gsin, gcos, nperio, cd_type='T' ) :
1817    '''Rotates the Repere: Change vector componantes from
1818    stretched coordinates grid --> geographic grid
1819
1820    All components are on the same grid (T, U, V or F)
1821    '''
1822    u_e = + u_i * gcos - v_j * gsin
1823    v_n = + u_i * gsin + v_j * gcos
1824
1825    u_e = lbc (u_e, nperio=nperio, cd_type=cd_type, psgn=1.0)
1826    v_n = lbc (v_n, nperio=nperio, cd_type=cd_type, psgn=1.0)
1827
1828    return u_e, v_n
1829
1830def rot_uv2en ( uo, vo, gsint, gcost, nperio, zdim=None ) :
1831    '''Rotate the Repere: Change vector componantes from
1832    stretched coordinates grid --> geographic grid
1833
1834    uo : velocity along i at the U grid point
1835    vo : valocity along j at the V grid point
1836   
1837    Returns east-north components on the T grid point
1838    '''
1839    ut = u2t (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1840    vt = v2t (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1841
1842    u_e = + ut * gcost - vt * gsint
1843    v_n = + ut * gsint + vt * gcost
1844
1845    u_e = lbc (u_e, nperio=nperio, cd_type='T', psgn=1.0)
1846    v_n = lbc (v_n, nperio=nperio, cd_type='T', psgn=1.0)
1847
1848    return u_e, v_n
1849
1850def rot_uv2enf ( uo, vo, gsinf, gcosf, nperio, zdim=None ) :
1851    '''Rotates the Repere: Change vector componantes from
1852    stretched coordinates grid --> geographic grid
1853
1854    uo : velocity along i at the U grid point
1855    vo : valocity along j at the V grid point
1856   
1857    Returns east-north components on the F grid point
1858    '''
1859    uf = u2f (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1860    vf = v2f (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1861
1862    u_e = + uf * gcosf - vf * gsinf
1863    v_n = + uf * gsinf + vf * gcosf
1864
1865    u_e = lbc (u_e, nperio=nperio, cd_type='F', psgn= 1.0)
1866    v_n = lbc (v_n, nperio=nperio, cd_type='F', psgn= 1.0)
1867
1868    return u_e, v_n
1869
1870def u2t (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1871    '''Interpolates an array from U grid to T grid (i-mean)
1872    '''
1873    mmath = __mmath__ (utab)
1874    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
1875    #lperio, aperio = lbc_diag (nperio)
1876    utab_0 = lbc_add (utab_0, nperio=nperio, cd_type='U', psgn=psgn)
1877    ax, ix = __find_axis__ (utab_0, 'x')
1878    az     = __find_axis__ (utab_0, 'z')[0]
1879
1880    if ax :
1881        if action == 'ave' :
1882            ttab = 0.5 *      (utab_0 + np.roll (utab_0, axis=ix, shift=1))
1883        if action == 'min' :
1884            ttab = np.minimum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1885        if action == 'max' :
1886            ttab = np.maximum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1887        if action == 'mult':
1888            ttab =             utab_0 * np.roll (utab_0, axis=ix, shift=1)
1889        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1890    else :
1891        ttab = lbc_del (utab_0, nperio=nperio, cd_type='T', psgn=psgn)
1892
1893    if mmath == xr :
1894        if ax :
1895            ttab = ttab.assign_coords({ax:np.arange (ttab.shape[ix])+1.})
1896        if zdim and az :
1897            if az != zdim :
1898                ttab = ttab.rename( {az:zdim})
1899    return ttab
1900
1901def v2t (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1902    '''Interpolates an array from V grid to T grid (j-mean)
1903    '''
1904    mmath = __mmath__ (vtab)
1905    #lperio, aperio = lbc_diag (nperio)
1906    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1907    vtab_0 = lbc_add (vtab_0, nperio=nperio, cd_type='V', psgn=psgn)
1908    ay, jy = __find_axis__ (vtab_0, 'y')
1909    az     = __find_axis__ (vtab_0, 'z')[0]
1910    if ay :
1911        if action == 'ave'  :
1912            ttab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=jy, shift=1))
1913        if action == 'min'  :
1914            ttab = np.minimum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1915        if action == 'max'  :
1916            ttab = np.maximum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1917        if action == 'mult' :
1918            ttab =             vtab_0 * np.roll (vtab_0, axis=jy, shift=1)
1919        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1920    else :
1921        ttab = lbc_del (vtab_0, nperio=nperio, cd_type='T', psgn=psgn)
1922
1923    if mmath == xr :
1924        if ay :
1925            ttab = ttab.assign_coords({ay:np.arange(ttab.shape[jy])+1.})
1926        if zdim and az :
1927            if az != zdim :
1928                ttab = ttab.rename( {az:zdim})
1929    return ttab
1930
1931def f2t (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1932    '''Interpolates an array from F grid to T grid (i- and j- means)
1933    '''
1934    mmath = __mmath__ (ftab)
1935    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1936    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1937    ttab = v2t (f2v (ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action),
1938                     nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1939    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1940
1941def t2u (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1942    '''Interpolates an array from T grid to U grid (i-mean)
1943    '''
1944    mmath = __mmath__ (ttab)
1945    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1946    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1947    ax, ix = __find_axis__ (ttab_0, 'x')[0]
1948    az     = __find_axis__ (ttab_0, 'z')
1949    if ix :
1950        if action == 'ave'  :
1951            utab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=ix, shift=-1))
1952        if action == 'min'  :
1953            utab = np.minimum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1954        if action == 'max'  :
1955            utab = np.maximum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1956        if action == 'mult' :
1957            utab =             ttab_0 * np.roll (ttab_0, axis=ix, shift=-1)
1958        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
1959    else :
1960        utab = lbc_del (ttab_0, nperio=nperio, cd_type='U', psgn=psgn)
1961
1962    if mmath == xr :
1963        if ax :
1964            utab = ttab.assign_coords({ax:np.arange(utab.shape[ix])+1.})
1965        if zdim and az :
1966            if az != zdim :
1967                utab = utab.rename( {az:zdim})
1968    return utab
1969
1970def t2v (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1971    '''Interpolates an array from T grid to V grid (j-mean)
1972    '''
1973    mmath = __mmath__ (ttab)
1974    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1975    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1976    ay, jy = __find_axis__ (ttab_0, 'y')
1977    az     = __find_axis__ (ttab_0, 'z')[0]
1978    if jy :
1979        if action == 'ave'  :
1980            vtab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=jy, shift=-1))
1981        if action == 'min'  :
1982            vtab = np.minimum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1983        if action == 'max'  :
1984            vtab = np.maximum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1985        if action == 'mult' :
1986            vtab =             ttab_0 * np.roll (ttab_0, axis=jy, shift=-1)
1987        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
1988    else :
1989        vtab = lbc_del (ttab_0, nperio=nperio, cd_type='V', psgn=psgn)
1990
1991    if mmath == xr :
1992        if ay :
1993            vtab = vtab.assign_coords({ay:np.arange(vtab.shape[jy])+1.})
1994        if zdim and az :
1995            if az != zdim :
1996                vtab = vtab.rename( {az:zdim})
1997    return vtab
1998
1999def v2f (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
2000    '''Interpolates an array from V grid to F grid (i-mean)
2001    '''
2002    mmath = __mmath__ (vtab)
2003    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
2004    vtab_0 = lbc_add (vtab_0 , nperio=nperio, cd_type='V', psgn=psgn)
2005    ax, ix = __find_axis__ (vtab_0, 'x')
2006    az     = __find_axis__ (vtab_0, 'z')[0]
2007    if ix :
2008        if action == 'ave'  :
2009            ftab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=ix, shift=-1))
2010        if action == 'min'  :
2011            ftab = np.minimum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
2012        if action == 'max'  :
2013            ftab = np.maximum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
2014        if action == 'mult' :
2015            ftab =             vtab_0 * np.roll (vtab_0, axis=ix, shift=-1)
2016        ftab = lbc_del (ftab  , nperio=nperio, cd_type='F', psgn=psgn)
2017    else :
2018        ftab = lbc_del (vtab_0, nperio=nperio, cd_type='F', psgn=psgn)
2019
2020    if mmath == xr :
2021        if ax :
2022            ftab = ftab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
2023        if zdim and az :
2024            if az != zdim :
2025                ftab = ftab.rename( {az:zdim})
2026    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
2027
2028def u2f (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
2029    '''Interpolates an array from U grid to F grid i-mean)
2030    '''
2031    mmath = __mmath__ (utab)
2032    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
2033    utab_0 = lbc_add (utab_0 , nperio=nperio, cd_type='U', psgn=psgn)
2034    ay, jy = __find_axis__ (utab_0, 'y')
2035    az     = __find_axis__ (utab_0, 'z')[0]
2036    if jy :
2037        if action == 'ave'  :
2038            ftab = 0.5 *      (utab_0 + np.roll (utab_0, axis=jy, shift=-1))
2039        if action == 'min'  :
2040            ftab = np.minimum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
2041        if action == 'max'  :
2042            ftab = np.maximum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
2043        if action == 'mult' :
2044            ftab =             utab_0 * np.roll (utab_0, axis=jy, shift=-1)
2045        ftab = lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
2046    else :
2047        ftab = lbc_del (utab_0, nperio=nperio, cd_type='F', psgn=psgn)
2048
2049    if mmath == xr :
2050        if ay :
2051            ftab = ftab.assign_coords({'y':np.arange(ftab.shape[jy])+1.})
2052        if zdim and az :
2053            if az != zdim :
2054                ftab = ftab.rename( {az:zdim})
2055    return ftab
2056
2057def t2f (ttab, nperio=None, psgn=1.0, zdim=None, action='mean') :
2058    '''Interpolates an array on T grid to F grid (i- and j- means)
2059    '''
2060    mmath = __mmath__ (ttab)
2061    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
2062    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
2063    ftab = t2u (u2f (ttab, nperio=nperio, psgn=psgn, zdim=zdim, action=action),
2064                     nperio=nperio, psgn=psgn, zdim=zdim, action=action)
2065
2066    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
2067
2068def f2u (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
2069    '''Interpolates an array on F grid to U grid (j-mean)
2070    '''
2071    mmath = __mmath__ (ftab)
2072    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
2073    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
2074    ay, jy = __find_axis__ (ftab_0, 'y')
2075    az     = __find_axis__ (ftab_0, 'z')[0]
2076    if jy :
2077        if action == 'ave'  :
2078            utab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=jy, shift=-1))
2079        if action == 'min'  :
2080            utab = np.minimum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
2081        if action == 'max'  :
2082            utab = np.maximum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
2083        if action == 'mult' :
2084            utab =             ftab_0 * np.roll (ftab_0, axis=jy, shift=-1)
2085        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
2086    else :
2087        utab = lbc_del (ftab_0, nperio=nperio, cd_type='U', psgn=psgn)
2088
2089    if mmath == xr :
2090        utab = utab.assign_coords({ay:np.arange(ftab.shape[jy])+1.})
2091        if zdim and az and az != zdim :
2092            utab = utab.rename( {az:zdim})
2093    return utab
2094
2095def f2v (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
2096    '''Interpolates an array from F grid to V grid (i-mean)
2097    '''
2098    mmath = __mmath__ (ftab)
2099    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
2100    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
2101    ax, ix = __find_axis__ (ftab_0, 'x')
2102    az     = __find_axis__ (ftab_0, 'z')[0]
2103    if ix :
2104        if action == 'ave'  :
2105            vtab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=ix, shift=-1))
2106        if action == 'min'  :
2107            vtab = np.minimum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
2108        if action == 'max'  :
2109            vtab = np.maximum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
2110        if action == 'mult' :
2111            vtab =             ftab_0 * np.roll (ftab_0, axis=ix, shift=-1)
2112        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
2113    else :
2114        vtab = lbc_del (ftab_0, nperio=nperio, cd_type='V', psgn=psgn)
2115
2116    if mmath == xr :
2117        vtab = vtab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
2118        if zdim and az :
2119            if az != zdim :
2120                vtab = vtab.rename( {az:zdim})
2121    return vtab
2122
2123def w2t (wtab, zcoord=None, zdim=None, sval=np.nan) :
2124    '''Interpolates an array on W grid to T grid (k-mean)
2125
2126    sval is the bottom value
2127    '''
2128    mmath = __mmath__ (wtab)
2129    wtab_0 = mmath.where ( np.isnan(wtab), 0., wtab)
2130
2131    az, kz = __find_axis__ (wtab_0, 'z')
2132
2133    if kz :
2134        ttab = 0.5 * ( wtab_0 + np.roll (wtab_0, axis=kz, shift=-1) )
2135    else :
2136        ttab = wtab_0
2137
2138    if mmath == xr :
2139        ttab[{az:kz}] = sval
2140        if zdim and az :
2141            if az != zdim :
2142                ttab = ttab.rename ( {az:zdim} )
2143        if zcoord is not None :
2144            ttab = ttab.assign_coords ( {zdim:zcoord} )
2145    else :
2146        ttab[..., -1, :, :] = sval
2147
2148    return ttab
2149
2150def t2w (ttab, zcoord=None, zdim=None, sval=np.nan, extrap_surf=False) :
2151    '''Interpolates an array from T grid to W grid (k-mean)
2152
2153    sval is the surface value
2154    if extrap_surf==True, surface value is taken from 1st level value.
2155    '''
2156    mmath = __mmath__ (ttab)
2157    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
2158    az, kz = __find_axis__ (ttab_0, 'z')
2159    wtab = 0.5 * ( ttab_0 + np.roll (ttab_0, axis=kz, shift=1) )
2160
2161    if mmath == xr :
2162        if extrap_surf :
2163            wtab[{az:0}] = ttab[{az:0}]
2164        else           :
2165            wtab[{az:0}] = sval
2166    else :
2167        if extrap_surf :
2168            wtab[..., 0, :, :] = ttab[..., 0, :, :]
2169        else           :
2170            wtab[..., 0, :, :] = sval
2171
2172    if mmath == xr :
2173        if zdim and az and az != zdim :
2174            wtab = wtab.rename ( {az:zdim})
2175        if zcoord is not None :
2176            wtab = wtab.assign_coords ( {zdim:zcoord})
2177        else :
2178            wtab = wtab.assign_coords ( {zdim:np.arange(ttab.shape[kz])+1.} )
2179    return wtab
2180
2181def fill (ptab, nperio, cd_type='T', npass=1, sval=np.nan) :
2182    '''Fills np.nan values with mean of neighbours
2183
2184    Inputs :
2185       ptab : input field to fill
2186       nperio, cd_type : periodicity characteristics
2187    '''
2188
2189    mmath = __mmath__ (ptab)
2190
2191    do_perio  = False
2192    lperio    = nperio
2193    if nperio == 4.2 :
2194        do_perio, lperio = True, 4
2195    if nperio == 6.2 :
2196        do_perio, lperio = True, 6
2197
2198    if do_perio :
2199        ztab = lbc_add (ptab, nperio=nperio)
2200    else :
2201        ztab = ptab
2202
2203    if np.isnan (sval) :
2204        ztab   = mmath.where (np.isnan(ztab), np.nan, ztab)
2205    else :
2206        ztab   = mmath.where (ztab==sval    , np.nan, ztab)
2207
2208    for _ in np.arange (npass) :
2209        zmask = mmath.where ( np.isnan(ztab), 0., 1.   )
2210        ztab0 = mmath.where ( np.isnan(ztab), 0., ztab )
2211        # Compte du nombre de voisins
2212        zcount = 1./6. * ( zmask \
2213          + np.roll(zmask, shift=1, axis=-1) + np.roll(zmask, shift=-1, axis=-1) \
2214          + np.roll(zmask, shift=1, axis=-2) + np.roll(zmask, shift=-1, axis=-2) \
2215          + 0.5 * ( \
2216                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift= 1, axis=-1) \
2217                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift= 1, axis=-1) \
2218                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift=-1, axis=-1) \
2219                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift=-1, axis=-1) ) )
2220
2221        znew =1./6. * ( ztab0 \
2222           + np.roll(ztab0, shift=1, axis=-1) + np.roll(ztab0, shift=-1, axis=-1) \
2223           + np.roll(ztab0, shift=1, axis=-2) + np.roll(ztab0, shift=-1, axis=-2) \
2224           + 0.5 * ( \
2225                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift= 1, axis=-1) \
2226                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift= 1, axis=-1) \
2227                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift=-1, axis=-1) \
2228                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift=-1, axis=-1) ) )
2229
2230        zcount = lbc (zcount, nperio=lperio, cd_type=cd_type)
2231        znew   = lbc (znew  , nperio=lperio, cd_type=cd_type)
2232
2233        ztab = mmath.where (np.logical_and (zmask==0., zcount>0), znew/zcount, ztab)
2234
2235    ztab = mmath.where (zcount==0, sval, ztab)
2236    if do_perio :
2237        ztab = lbc_del (ztab, nperio=lperio)
2238
2239    return ztab
2240
2241def correct_uv (u, v, lat) :
2242    '''
2243    Corrects a Cartopy bug in orthographic projection
2244
2245    See https://github.com/SciTools/cartopy/issues/1179
2246
2247    The correction is needed with cartopy <= 0.20
2248    It seems that version 0.21 will correct the bug (https://github.com/SciTools/cartopy/pull/1926)
2249    Later note : the bug is still present in Cartopy 0.22
2250
2251    Inputs :
2252       u, v : eastward/northward components
2253       lat  : latitude of the point (degrees north)
2254
2255    Outputs :
2256       modified eastward/nothward components to have correct polar projections in cartopy
2257    '''
2258    uv = np.sqrt (u*u + v*v)           # Original modulus
2259    zu = u
2260    zv = v * np.cos (RAD*lat)
2261    zz = np.sqrt ( zu*zu + zv*zv )     # Corrected modulus
2262    uc = zu*uv/zz
2263    vc = zv*uv/zz      # Final corrected values
2264    return uc, vc
2265
2266def norm_uv (u, v) :
2267    '''Returns norm of a 2 components vector
2268    '''
2269    return np.sqrt (u*u + v*v)
2270
2271def normalize_uv (u, v) :
2272    '''Normalizes 2 components vector
2273    '''
2274    uv = norm_uv (u, v)
2275    return u/uv, v/uv
2276
2277def msf (vv, e1v_e3v, plat1d, depthw) :
2278    '''Computes the meridonal stream function
2279
2280    vv : meridional_velocity
2281    e1v_e3v : prodcut of scale factors e1v*e3v
2282    '''
2283
2284    v_e1v_e3v = vv * e1v_e3v
2285    v_e1v_e3v.attrs = vv.attrs
2286
2287    ax = __find_axis__ (v_e1v_e3v, 'x')[0]
2288    az = __find_axis__ (v_e1v_e3v, 'z')[0]
2289    if az == 'olevel' :
2290        new_az = 'olevel'
2291    else :
2292        new_az = 'depthw'
2293
2294    zomsf = -v_e1v_e3v.cumsum ( dim=az, keep_attrs=True).sum (dim=ax, keep_attrs=True)*1.E-6
2295    zomsf = zomsf - zomsf.isel ( { az:-1} )
2296
2297    ay = __find_axis__ (zomsf, 'y' )[0]
2298    zomsf = zomsf.assign_coords ( {az:depthw.values, ay:plat1d.values})
2299
2300    zomsf = zomsf.rename ( {ay:'lat'})
2301    if az != new_az :
2302        zomsf = zomsf.rename ( {az:new_az} )
2303    zomsf.attrs['standard_name'] = 'Meridional stream function'
2304    zomsf.attrs['long_name'] = 'Meridional stream function'
2305    zomsf.attrs['units'] = 'Sv'
2306    zomsf[new_az].attrs  = depthw.attrs
2307    zomsf.lat.attrs=plat1d.attrs
2308
2309    return zomsf
2310
2311def bsf (uu, e2u_e3u, mask, nperio=None, bsf0=None ) :
2312    '''Computes the barotropic stream function
2313
2314    uu      : zonal_velocity
2315    e2u_e3u : product of scales factor e2u*e3u
2316    bsf0    : the point with bsf=0
2317    (ex: bsf0={'x':3, 'y':120} for orca2,
2318         bsf0={'x':5, 'y':300} for eORCA1
2319    '''
2320    u_e2u_e3u       = uu * e2u_e3u
2321    u_e2u_e3u.attrs = uu.attrs
2322
2323    ay = __find_axis__ (u_e2u_e3u, 'y')[0]
2324    az = __find_axis__ (u_e2u_e3u, 'z')[0]
2325
2326    zbsf = -u_e2u_e3u.cumsum ( dim=ay, keep_attrs=True )
2327    zbsf = zbsf.sum (dim=az, keep_attrs=True)*1.E-6
2328
2329    if bsf0 :
2330        zbsf = zbsf - zbsf.isel (bsf0)
2331
2332    zbsf = zbsf.where (mask !=0, np.nan)
2333    zbsf.attrs.update (uu.attrs)
2334    zbsf.attrs['standard_name'] = 'barotropic_stream_function'
2335    zbsf.attrs['long_name']     = 'Barotropic stream function'
2336    zbsf.attrs['units']         = 'Sv'
2337    zbsf = lbc (zbsf, nperio=nperio, cd_type='F')
2338
2339    return zbsf
2340
2341if f90nml :
2342    def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
2343        '''Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
2344
2345        ref : file with reference namelist, or a f90nml.namelist.Namelist object
2346        cfg : file with config namelist, or a f90nml.namelist.Namelist object
2347        At least one namelist neaded
2348
2349        out:
2350        'dict' to return a dictonnary
2351        'xr'   to return an xarray dataset
2352        flat : only for dict output. Output a flat dictionary with all values.
2353
2354        '''
2355        if ref :
2356            if isinstance (ref, str) :
2357                nml_ref = f90nml.read (ref)
2358            if isinstance (ref, f90nml.namelist.Namelist) :
2359                nml_ref = ref
2360
2361        if cfg :
2362            if isinstance (cfg, str) :
2363                nml_cfg = f90nml.read (cfg)
2364            if isinstance (cfg, f90nml.namelist.Namelist) :
2365                nml_cfg = cfg
2366
2367        if out == 'dict' :
2368            dict_namelist = {}
2369        if out == 'xr'   :
2370            xr_namelist = xr.Dataset ()
2371
2372        list_nml     = []
2373        list_comment = []
2374
2375        if ref :
2376            list_nml.append (nml_ref)
2377            list_comment.append ('ref')
2378        if cfg :
2379            list_nml.append (nml_cfg)
2380            list_comment.append ('cfg')
2381
2382        for nml, comment in zip (list_nml, list_comment) :
2383            if verbose :
2384                print (comment)
2385            if flat and out =='dict' :
2386                for nam in nml.keys () :
2387                    if verbose :
2388                        print (nam)
2389                    for value in nml[nam] :
2390                        if out == 'dict' :
2391                            dict_namelist[value] = nml[nam][value]
2392                        if verbose :
2393                            print (nam, ':', value, ':', nml[nam][value])
2394            else :
2395                for nam in nml.keys () :
2396                    if verbose :
2397                        print (nam)
2398                    if out == 'dict' :
2399                        if nam not in dict_namelist.keys () :
2400                            dict_namelist[nam] = {}
2401                    for value in nml[nam] :
2402                        if out == 'dict' :
2403                            dict_namelist[nam][value] = nml[nam][value]
2404                        if out == 'xr'   :
2405                            xr_namelist[value] = nml[nam][value]
2406                        if verbose :
2407                            print (nam, ':', value, ':', nml[nam][value])
2408
2409        if out == 'dict' :
2410            return dict_namelist
2411        if out == 'xr'   :
2412            return xr_namelist
2413else :
2414     def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
2415        '''Shadow version of namelist read, when f90nm module was not found
2416
2417        namelist_read :
2418        Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
2419        '''
2420        print ( 'Error : module f90nml not found' )
2421        print ( 'Cannot call namelist_read' )
2422        print ( 'Call parameters where : ')
2423        print ( f'{err=} {ref=} {cfg=} {out=} {flat=} {verbose=}' )
2424
2425def fill_closed_seas (imask, nperio=None,  cd_type='T') :
2426    '''Fill closed seas with image processing library
2427
2428    imask : mask, 1 on ocean, 0 on land
2429    '''
2430    from scipy import ndimage
2431
2432    imask_filled = ndimage.binary_fill_holes ( lbc (imask, nperio=nperio, cd_type=cd_type))
2433    imask_filled = lbc ( imask_filled, nperio=nperio, cd_type=cd_type)
2434
2435    return imask_filled
2436
2437# ======================================================
2438# Sea water state function parameters from NEMO code
2439
2440RDELTAS = 32.
2441R1_S0   = 0.875/35.16504
2442R1_T0   = 1./40.
2443R1_Z0   = 1.e-4
2444
2445EOS000 =  8.0189615746e+02
2446EOS100 =  8.6672408165e+02
2447EOS200 = -1.7864682637e+03
2448EOS300 =  2.0375295546e+03
2449EOS400 = -1.2849161071e+03
2450EOS500 =  4.3227585684e+02
2451EOS600 = -6.0579916612e+01
2452EOS010 =  2.6010145068e+01
2453EOS110 = -6.5281885265e+01
2454EOS210 =  8.1770425108e+01
2455EOS310 = -5.6888046321e+01
2456EOS410 =  1.7681814114e+01
2457EOS510 = -1.9193502195
2458EOS020 = -3.7074170417e+01
2459EOS120 =  6.1548258127e+01
2460EOS220 = -6.0362551501e+01
2461EOS320 =  2.9130021253e+01
2462EOS420 = -5.4723692739
2463EOS030 =  2.1661789529e+01
2464EOS130 = -3.3449108469e+01
2465EOS230 =  1.9717078466e+01
2466EOS330 = -3.1742946532
2467EOS040 = -8.3627885467
2468EOS140 =  1.1311538584e+01
2469EOS240 = -5.3563304045
2470EOS050 =  5.4048723791e-01
2471EOS150 =  4.8169980163e-01
2472EOS060 = -1.9083568888e-01
2473EOS001 =  1.9681925209e+01
2474EOS101 = -4.2549998214e+01
2475EOS201 =  5.0774768218e+01
2476EOS301 = -3.0938076334e+01
2477EOS401 =  6.6051753097
2478EOS011 = -1.3336301113e+01
2479EOS111 = -4.4870114575
2480EOS211 =  5.0042598061
2481EOS311 = -6.5399043664e-01
2482EOS021 =  6.7080479603
2483EOS121 =  3.5063081279
2484EOS221 = -1.8795372996
2485EOS031 = -2.4649669534
2486EOS131 = -5.5077101279e-01
2487EOS041 =  5.5927935970e-01
2488EOS002 =  2.0660924175
2489EOS102 = -4.9527603989
2490EOS202 =  2.5019633244
2491EOS012 =  2.0564311499
2492EOS112 = -2.1311365518e-01
2493EOS022 = -1.2419983026
2494EOS003 = -2.3342758797e-02
2495EOS103 = -1.8507636718e-02
2496EOS013 =  3.7969820455e-01
2497
2498def rhop ( ptemp, psal ) :
2499    '''Returns potential density referenced to surface
2500
2501    Computation from NEMO code
2502    '''
2503    zt      = ptemp * R1_T0                                  # Temperature (°C)
2504    zs      = np.sqrt ( np.abs( psal + RDELTAS ) * R1_S0 )   # Square root of salinity (PSS)
2505    #
2506    prhop = (
2507      (((((EOS060*zt
2508         + EOS150*zs     + EOS050)*zt
2509         + (EOS240*zs    + EOS140)*zs + EOS040)*zt
2510         + ((EOS330*zs   + EOS230)*zs + EOS130)*zs + EOS030)*zt
2511         + (((EOS420*zs  + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt
2512         + ((((EOS510*zs + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt
2513         + (((((EOS600*zs+ EOS500)*zs + EOS400)*zs + EOS300)*zs + EOS200)*zs + EOS100)*zs + EOS000 )
2514    #
2515    return prhop
2516
2517def rho ( pdep, ptemp, psal ) :
2518    '''Returns in situ density
2519
2520    Computation from NEMO code
2521    '''
2522    zh      = pdep  * R1_Z0                                  # Depth (m)
2523    zt      = ptemp * R1_T0                                  # Temperature (°C)
2524    zs      = np.sqrt ( np.abs( psal + RDELTAS ) * R1_S0 )   # Square root salinity (PSS)
2525    #
2526    zn3 = EOS013*zt + EOS103*zs+EOS003
2527    #
2528    zn2 = (EOS022*zt + EOS112*zs+EOS012)*zt + (EOS202*zs+EOS102)*zs+EOS002
2529    #
2530    zn1 = (
2531      (((EOS041*zt
2532       + EOS131*zs   + EOS031)*zt
2533       + (EOS221*zs  + EOS121)*zs + EOS021)*zt
2534       + ((EOS311*zs  + EOS211)*zs + EOS111)*zs + EOS011)*zt
2535       + (((EOS401*zs + EOS301)*zs + EOS201)*zs + EOS101)*zs + EOS001 )
2536    #
2537    zn0 = (
2538      (((((EOS060*zt
2539         + EOS150*zs      + EOS050)*zt
2540         + (EOS240*zs     + EOS140)*zs + EOS040)*zt
2541         + ((EOS330*zs    + EOS230)*zs + EOS130)*zs + EOS030)*zt
2542         + (((EOS420*zs   + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt
2543         + ((((EOS510*zs  + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt
2544         + (((((EOS600*zs + EOS500)*zs + EOS400)*zs + EOS300)*zs +
2545                                       EOS200)*zs + EOS100)*zs + EOS000 )
2546    #
2547    prho  = ( ( zn3 * zh + zn2 ) * zh + zn1 ) * zh + zn0
2548    #
2549    return prho
2550
2551## ===========================================================================
2552##
2553##                               That's all folk's !!!
2554##
2555## ===========================================================================
2556
2557# def __is_orca_north_fold__ ( Xtest, cname_long='T' ) :
2558#     '''
2559#     Ported (pirated !!?) from Sosie
2560
2561#     Tell if there is a 2/point band overlaping folding at the north pole typical of the ORCA grid
2562
2563#     0 => not an orca grid (or unknown one)
2564#     4 => North fold T-point pivot (ex: ORCA2)
2565#     6 => North fold F-point pivot (ex: ORCA1)
2566
2567#     We need all this 'cname_long' stuff because with our method, there is a
2568#     confusion between "Grid_U with T-fold" and "Grid_V with F-fold"
2569#     => so knowing the name of the longitude array (as in namelist, and hence as
2570#     in netcdf file) might help taking the righ decision !!! UGLY!!!
2571#     => not implemented yet
2572#     '''
2573
2574#     ifld_nord =  0 ; cgrd_type = 'X'
2575#     ny, nx = Xtest.shape[-2:]
2576
2577#     if ny > 3 : # (case if called with a 1D array, ignoring...)
2578#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2579#           ifld_nord = 4 ; cgrd_type = 'T' # T-pivot, grid_T
2580
2581#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2582#             if cnlon == 'U' : ifld_nord = 4 ;  cgrd_type = 'U' # T-pivot, grid_T
2583#                 ## LOLO: PROBLEM == 6, V !!!
2584
2585#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2586#             ifld_nord = 4 ; cgrd_type = 'V' # T-pivot, grid_V
2587
2588#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-2, nx-1-1:nx-nx//2:-1] ).sum() == 0. :
2589#             ifld_nord = 6 ; cgrd_type = 'T'# F-pivot, grid_T
2590
2591#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-1, nx-1:nx-nx//2-1:-1] ).sum() == 0. :
2592#             ifld_nord = 6 ;  cgrd_type = 'U' # F-pivot, grid_U
2593
2594#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2595#             if cnlon == 'V' : ifld_nord = 6 ; cgrd_type = 'V' # F-pivot, grid_V
2596#                 ## LOLO: PROBLEM == 4, U !!!
2597
2598#     return ifld_nord, cgrd_type
Note: See TracBrowser for help on using the repository browser.