source: TOOLS/MOSAIX/nemo.py @ 6764

Last change on this file since 6764 was 6666, checked in by omamce, 8 months ago

O.M. : MOSAIX

Improved code with pylint analysis

  • Property svn:keywords set to Date Revision HeadURL Author Id
File size: 93.0 KB
RevLine 
[6064]1# -*- coding: utf-8 -*-
2## ===========================================================================
3##
[6245]4##  This software is governed by the CeCILL  license under French law and
5##  abiding by the rules of distribution of free software.  You can  use,
6##  modify and/ or redistribute the software under the terms of the CeCILL
7##  license as circulated by CEA, CNRS and INRIA at the following URL
8##  "http://www.cecill.info".
[6064]9##
[6245]10##  Warning, to install, configure, run, use any of Olivier Marti's
11##  software or to read the associated documentation you'll need at least
12##  one (1) brain in a reasonably working order. Lack of this implement
13##  will void any warranties (either express or implied).
14##  O. Marti assumes no responsability for errors, omissions,
15##  data loss, or any other consequences caused directly or indirectly by
16##  the usage of his software by incorrectly or partially configured
17##  personal.
[6190]18##
19## ===========================================================================
[6666]20'''Utilities to plot NEMO ORCA fields,
[6064]21
[6666]22Handles periodicity and other stuff
[6064]23
[6666]24- Lots of tests for xarray object
25- Not much tested for numpy objects
26
27Author: olivier.marti@lsce.ipsl.fr
28
[6190]29## SVN information
[6666]30Author   = "$Author$"
31Date     = "$Date$"
32Revision = "$Revision$"
33Id       = "$Id$"
34HeadURL  = "$HeadURL$"
35'''
[6190]36
[6245]37import numpy as np
[6666]38import xarray as xr
[6245]39
[6666]40# Tries to import some moldules that are available in all Python installations
41#  and used in only a few specific functions
42try :
43    from sklearn.impute import SimpleImputer
44except ImportError as err :
45    print ( f'===> Warning : Module nemo : Import error of sklearn.impute.SimpleImputer : {err}' )
46    SimpleImputer = None
[6064]47
[6666]48try :
49    import f90nml
50except ImportError as err :
51    print ( f'===> Warning : Module nemo : Import error of f90nml : {err}' )
52    f90nml = None
[6245]53
[6666]54# SVN information
55__SVN__ = ({
56    'Author'   : '$Author$',
57    'Date'     : '$Date$',
58    'Revision' : '$Revision$',
59    'Id'       : '$Id$',
60    'HeadURL'  : '$HeadURL$',
61    'SVN_Date' : '$SVN_Date: $',
62    })
63##
64   
65RPI   = np.pi
66RAD   = np.deg2rad (1.0)
67DAR   = np.rad2deg (1.0)
68REPSI = np.finfo (1.0).eps
[6245]69
[6666]70NPERIO_VALID_RANGE = [0, 1, 4, 4.2, 5, 6, 6.2]
[6064]71
[6666]72RAAMO    =      12          # Number of months in one year
73RJJHH    =      24          # Number of hours in one day
74RHHMM    =      60          # Number of minutes in one hour
75RMMSS    =      60          # Number of seconds in one minute
76RA       = 6371229.0        # Earth radius                                  [m]
77GRAV     =       9.80665    # Gravity                                       [m/s2]
78RT0      =     273.15       # Freezing point of fresh water                 [Kelvin]
79RAU0     =    1026.0        # Volumic mass of sea water                     [kg/m3]
80SICE     =       6.0        # Salinity of ice (for pisces)                  [psu]
81SOCE     =      34.7        # Salinity of sea (for pisces and isf)          [psu]
82RLEVAP   =       2.5e+6     # Latent heat of evaporation (water)            [J/K]
83VKARMN   =       0.4        # Von Karman constant
84STEFAN   =       5.67e-8    # Stefan-Boltzmann constant                     [W/m2/K4]
85RHOS     =     330.         # Volumic mass of snow                          [kg/m3]
86RHOI     =     917.         # Volumic mass of sea ice                       [kg/m3]
87RHOW     =    1000.         # Volumic mass of freshwater in melt ponds      [kg/m3]
88RCND_I   =       2.034396   # Thermal conductivity of fresh ice             [W/m/K]
89RCPI     =    2067.0        # Specific heat of fresh ice                    [J/kg/K]
90RLSUB    =       2.834e+6   # Pure ice latent heat of sublimation           [J/kg]
91RLFUS    =       0.334e+6   # Latent heat of fusion of fresh ice            [J/kg]
92RTMLT    =       0.054      # Decrease of seawater meltpoint with salinity
[6064]93
[6666]94RDAY     = RJJHH * RHHMM * RMMSS                # Day length               [s]
95RSIYEA   = 365.25 * RDAY * 2. * RPI / 6.283076  # Sideral year length      [s]
96RSIDAY   = RDAY / (1. + RDAY / RSIYEA)          # Sideral day length       [s]
97OMEGA    = 2. * RPI / RSIDAY                    # Earth rotation parameter [s-1]
[6190]98
[6666]99## Default names of dimensions
100UDIMS = {'x':'x', 'y':'y', 'z':'olevel', 't':'time_counter'}
[6190]101
[6666]102## All possibles name of dimensions in Nemo files
103XNAME = [ 'x', 'X', 'X1', 'xx', 'XX',
104              'x_grid_T', 'x_grid_U', 'x_grid_V', 'x_grid_F', 'x_grid_W',
105              'lon', 'nav_lon', 'longitude', 'X1', 'x_c', 'x_f', ]
106YNAME = [ 'y', 'Y', 'Y1', 'yy', 'YY',
107              'y_grid_T', 'y_grid_U', 'y_grid_V', 'y_grid_F', 'y_grid_W',
108              'lat', 'nav_lat', 'latitude' , 'Y1', 'y_c', 'y_f', ]
109ZNAME = [ 'z', 'Z', 'Z1', 'zz', 'ZZ', 'depth', 'tdepth', 'udepth',
110              'vdepth', 'wdepth', 'fdepth', 'deptht', 'depthu',
111              'depthv', 'depthw', 'depthf', 'olevel', 'z_c', 'z_f', ]
112TNAME = [ 't', 'T', 'tt', 'TT', 'time', 'time_counter', 'time_centered', ]
113
114## All possibles name of units of dimensions in Nemo files
115XUNIT = [ 'degrees_east', ]
116YUNIT = [ 'degrees_north', ]
117ZUNIT = [ 'm', 'meter', ]
118TUNIT = [ 'second', 'minute', 'hour', 'day', 'month', 'year', ]
119
120## All possibles size of dimensions in Orca files
121XLENGTH = [ 180, 182, 360, 362, 1440 ]
122YLENGTH = [ 148, 149, 331, 332 ]
123ZLENGTH = [ 31, 75]
124
[6245]125## ===========================================================================
[6666]126def __mmath__ (ptab, default=None) :
127    '''Determines the type of tab : xarray, numpy or numpy.ma object ?
128
129    Returns type
130    '''
[6245]131    mmath = default
[6666]132    if isinstance (ptab, xr.core.dataarray.DataArray) :
133        mmath = xr
134    if isinstance (ptab, np.ndarray) :
135        mmath = np
136    if isinstance (ptab, np.ma.MaskType) :
137        mmath = np.ma
[6190]138
[6245]139    return mmath
[6190]140
[6666]141def __guess_nperio__ (jpj, jpi, nperio=None, out='nperio') :
142    '''Tries to guess the value of nperio (periodicity parameter.
[6245]143
[6666]144    See NEMO documentation for details)
[6064]145    Inputs
[6105]146    jpj    : number of latitudes
[6064]147    jpi    : number of longitudes
148    nperio : periodicity parameter
149    '''
[6666]150    if nperio is None :
151        nperio = __guess_config__ (jpj, jpi, nperio=None, out=out)
[6245]152    return nperio
153
[6666]154def __guess_config__ (jpj, jpi, nperio=None, config=None, out='nperio') :
155    '''Tries to guess the value of nperio (periodicity parameter).
156
157    See NEMO documentation for details)
[6245]158    Inputs
159    jpj    : number of latitudes
160    jpi    : number of longitudes
161    nperio : periodicity parameter
162    '''
[6666]163    print ( jpi, jpj)
164    if nperio is None :
[6105]165        ## Values for NEMO version < 4.2
[6666]166        if ( (jpj == 149 and jpi == 182) or (jpj is None and jpi == 182) or
167             (jpj == 149 or jpi is None) ) :
168            # ORCA2. We choose legacy orca2.
169            config, nperio, iperio, jperio, nfold, nftype = 'ORCA2.3' , 4, 1, 0, 1, 'T'
170        if ((jpj == 332 and jpi == 362) or (jpj is None and jpi == 362) or
171            (jpj ==  332 and jpi is None) ) : # eORCA1.
172            config, nperio, iperio, jperio, nfold, nftype = 'eORCA1.2', 6, 1, 0, 1, 'F'
[6190]173        if jpi == 1442 :  # ORCA025.
[6666]174            config, nperio, iperio, jperio, nfold, nftype = 'ORCA025' , 6, 1, 0, 1, 'F'
[6190]175        if jpj ==  294 : # ORCA1
[6666]176            config, nperio, iperio, jperio, nfold, nftype = 'ORCA1'   , 6, 1, 0, 1, 'F'
177
[6105]178        ## Values for NEMO version >= 4.2. No more halo points
[6666]179        if  (jpj == 148 and jpi == 180) or (jpj is None and jpi == 180) or \
180            (jpj == 148 and jpi is None) :  # ORCA2. We choose legacy orca2.
181            config, nperio, iperio, jperio, nfold, nftype = 'ORCA2.4' , 4.2, 1, 0, 1, 'F'
182        if  (jpj == 331 and jpi == 360) or (jpj is None and jpi == 360) or \
183            (jpj == 331 and jpi is None) : # eORCA1.
184            config, nperio, iperio, jperio, nfold, nftype = 'eORCA1.4', 6.2, 1, 0, 1, 'F'
[6190]185        if jpi == 1440 : # ORCA025.
[6666]186            config, nperio, iperio, jperio, nfold, nftype = 'ORCA025' , 6.2, 1, 0, 1, 'F'
187
188        if nperio is None :
189            raise ValueError ('in nemo module : nperio not found, and cannot by guessed')
190
191        if nperio in NPERIO_VALID_RANGE :
192            print ( f'nperio set as {nperio} (deduced from {jpj=} and {jpi=})' )
[6064]193        else :
[6666]194            raise ValueError ( f'nperio set as {nperio} (deduced from {jpi=} and {jpj=}) : \n'+
195                                'nemo.py is not ready for this value' )
[6064]196
[6666]197    if out == 'nperio' :
198        return nperio
199    if out == 'config' :
200        return config
201    if out == 'perio'  :
202        return iperio, jperio, nfold, nftype
203    if out in ['full', 'all'] :
204        return {'nperio':nperio, 'iperio':iperio, 'jperio':jperio, 'nfold':nfold, 'nftype':nftype}
205
206def __guess_point__ (ptab) :
207    '''Tries to guess the grid point (periodicity parameter.
208
209    See NEMO documentation for details)
[6064]210    For array conforments with xgcm requirements
211
212    Inputs
[6190]213         ptab : xarray array
[6064]214
215    Credits : who is the original author ?
216    '''
[6666]217
218    gp = None
[6245]219    mmath = __mmath__ (ptab)
220    if mmath == xr :
[6666]221        if ('x_c' in ptab.dims and 'y_c' in ptab.dims ) :
222            gp = 'T'
223        if ('x_f' in ptab.dims and 'y_c' in ptab.dims ) :
224            gp = 'U'
225        if ('x_c' in ptab.dims and 'y_f' in ptab.dims ) :
226            gp = 'V'
227        if ('x_f' in ptab.dims and 'y_f' in ptab.dims ) :
228            gp = 'F'
229        if ('x_c' in ptab.dims and 'y_c' in ptab.dims
230              and 'z_c' in ptab.dims )                :
231            gp = 'T'
232        if ('x_c' in ptab.dims and 'y_c' in ptab.dims
233                and 'z_f' in ptab.dims )                :
234            gp = 'W'
235        if ('x_f' in ptab.dims and 'y_c' in ptab.dims
236                and 'z_f' in ptab.dims )                :
237            gp = 'U'
238        if ('x_c' in ptab.dims and 'y_f' in ptab.dims
239                and 'z_f' in ptab.dims ) :
240            gp = 'V'
241        if ('x_f' in ptab.dims and 'y_f' in ptab.dims
242                and 'z_f' in ptab.dims ) :
243            gp = 'F'
244
245        if gp is None :
246            raise AttributeError ('in nemo module : cd_type not found, and cannot by guessed')
247        print ( f'Grid set as {gp} deduced from dims {ptab.dims}' )
248        return gp
[6064]249    else :
[6666]250        raise AttributeError  ('in nemo module : cd_type not found, input is not an xarray data')
[6064]251
[6666]252def get_shape ( ptab ) :
253    '''Get shape of ptab return a string with axes names
254
255    shape may contain X, Y, Z or T
256    Y is missing for a latitudinal slice
257    X is missing for on longitudinal slice
258    etc ...
259    '''
260
261    g_shape = ''
262    if __find_axis__ (ptab, 'x')[0] :
263        g_shape = 'X'
264    if __find_axis__ (ptab, 'y')[0] :
265        g_shape = 'Y' + g_shape
266    if __find_axis__ (ptab, 'z')[0] :
267        g_shape = 'Z' + g_shape
268    if __find_axis__ (ptab, 't')[0] :
269        g_shape = 'T' + g_shape
270    return g_shape
271
[6245]272def lbc_diag (nperio) :
[6666]273    '''Useful to switch between field with and without halo'''
274    lperio, aperio = nperio, False
[6245]275    if nperio == 4.2 :
[6666]276        lperio, aperio = 4, True
[6245]277    if nperio == 6.2 :
[6666]278        lperio, aperio = 6, True
279    return lperio, aperio
[6245]280
[6666]281def __find_axis__ (ptab, axis='z', back=True) :
282    '''Returns name and name of the requested axis'''
283    mmath = __mmath__ (ptab)
284    ax, ix = None, None
[6190]285
[6666]286    if axis in XNAME :
287        ax_name, unit_list, length = XNAME, XUNIT, XLENGTH
288    if axis in YNAME :
289        ax_name, unit_list, length = YNAME, YUNIT, YLENGTH
290    if axis in ZNAME :
291        ax_name, unit_list, length = ZNAME, ZUNIT, ZLENGTH
292    if axis in TNAME :
293        ax_name, unit_list, length = TNAME, TUNIT, None
294
[6245]295    if mmath == xr :
[6666]296        # Try by name
297        for dim in ax_name :
298            if dim in ptab.dims :
299                ix, ax = ptab.dims.index (dim), dim
[6190]300
[6666]301        # If not found, try by axis attributes
302        if not ix :
303            for i, dim in enumerate (ptab.dims) :
304                if 'axis' in ptab.coords[dim].attrs.keys() :
305                    l_axis = ptab.coords[dim].attrs['axis']
306                    if axis in ax_name and l_axis == 'X' :
307                        ix, ax = (i, dim)
308                    if axis in ax_name and l_axis == 'Y' :
309                        ix, ax = (i, dim)
310                    if axis in ax_name and l_axis == 'Z' :
311                        ix, ax = (i, dim)
312                    if axis in ax_name and l_axis == 'T' :
313                        ix, ax = (i, dim)
[6245]314
[6666]315        # If not found, try by units
316        if not ix :
317            for i, dim in enumerate (ptab.dims) :
318                if 'units' in ptab.coords[dim].attrs.keys() :
319                    for name in unit_list :
320                        if name in ptab.coords[dim].attrs['units'] :
321                            ix, ax = i, dim
[6064]322
[6666]323    # If numpy array or dimension not found, try by length
324    if mmath != xr or not ix :
325        if length :
326            l_shape = ptab.shape
327            for nn in np.arange ( len(l_shape) ) :
328                if l_shape[nn] in length :
329                    ix = nn
330
331    if ix and back :
332        ix -= len(ptab.shape)
333
334    return ax, ix
335
336def find_axis ( ptab, axis='z', back=True ) :
337    '''Version of find_axis with no __'''
338    ix, xx = __find_axis__ (ptab, axis, back)
339    return xx, ix
340
341def fixed_lon (plon, center_lon=0.0) :
342    '''Returns corrected longitudes for nicer plots
343
[6190]344    lon        : longitudes of the grid. At least 2D.
345    center_lon : center longitude. Default=0.
[6064]346
[6666]347    Designed by Phil Pelson.
348    See https://gist.github.com/pelson/79cf31ef324774c97ae7
[6064]349    '''
[6666]350    mmath = __mmath__ (plon)
[6064]351
[6666]352    f_lon = plon.copy ()
353
354    f_lon = mmath.where (f_lon > center_lon+180., f_lon-360.0, f_lon)
355    f_lon = mmath.where (f_lon < center_lon-180., f_lon+360.0, f_lon)
356
357    for i, start in enumerate (np.argmax (np.abs (np.diff (f_lon, axis=-1)) > 180., axis=-1)) :
358        f_lon [..., i, start+1:] += 360.
359
[6064]360    # Special case for eORCA025
[6666]361    if f_lon.shape [-1] == 1442 :
362        f_lon [..., -2, :] = f_lon [..., -3, :]
363    if f_lon.shape [-1] == 1440 :
364        f_lon [..., -1, :] = f_lon [..., -2, :]
[6190]365
[6666]366    if f_lon.min () > center_lon :
367        f_lon += -360.0
368    if f_lon.max () < center_lon :
369        f_lon +=  360.0
[6064]370
[6666]371    if f_lon.min () < center_lon-360.0 :
372        f_lon +=  360.0
373    if f_lon.max () > center_lon+360.0 :
374        f_lon += -360.0
375
376    return f_lon
377
378def bounds_clolon ( pbounds_lon, plon, rad=False, deg=True) :
379    '''Choose closest to lon0 longitude, adding/substacting 360° if needed
[6245]380    '''
381
[6666]382    if rad :
383        lon_range = 2.0*np.pi
384    if deg :
385        lon_range = 360.0
386    b_clolon = pbounds_lon.copy ()
387
388    b_clolon = xr.where ( b_clolon < plon-lon_range/2.,
389                          b_clolon+lon_range,
390                          b_clolon )
391    b_clolon = xr.where ( b_clolon > plon+lon_range/2.,
392                          b_clolon-lon_range,
393                          b_clolon )
394    return b_clolon
395
396def unify_dims ( dd, x='x', y='y', z='olevel', t='time_counter', verbose=False ) :
397    '''Rename dimensions to unify them between NEMO versions
398    '''
399    for xx in XNAME :
400        if xx in dd.dims and xx != x :
401            if verbose :
402                print ( f"{xx} renamed to {x}" )
403            dd = dd.rename ( {xx:x})
404
405    for yy in YNAME :
406        if yy in dd.dims and yy != y  :
407            if verbose :
408                print ( f"{yy} renamed to {y}" )
409            dd = dd.rename ( {yy:y} )
410
411    for zz in ZNAME :
412        if zz in dd.dims and zz != z :
413            if verbose :
414                print ( f"{zz} renamed to {z}" )
415            dd = dd.rename ( {zz:z} )
416
417    for tt in TNAME  :
418        if tt in dd.dims and tt != t :
419            if verbose :
420                print ( f"{tt} renamed to {t}" )
421            dd = dd.rename ( {tt:t} )
422
423    return dd
424
425
426if SimpleImputer : 
427  def fill_empty (ptab, sval=np.nan, transpose=False) :
428        '''Fill empty values
429
430        Useful when NEMO has run with no wet points options :
431        some parts of the domain, with no ocean points, have no
432        values
433        '''
434        mmath = __mmath__ (ptab)
435
436        imp = SimpleImputer (missing_values=sval, strategy='mean')
437        if transpose :
438            imp.fit (ptab.T)
439            ztab = imp.transform (ptab.T).T
440        else :
441            imp.fit (ptab)
442            ztab = imp.transform (ptab)
443
444        if mmath == xr :
445            ztab = xr.DataArray (ztab, dims=ztab.dims, coords=ztab.coords)
446            ztab.attrs.update (ptab.attrs)
447
448        return ztab
449
450   
451else : 
452    print ("Import error of sklearn.impute.SimpleImputer")
453    def fill_empty (ptab, sval=np.nan, transpose=False) :
454        '''Void version of fill_empy, because module sklearn.impute.SimpleImputer is not available
455
456        fill_empty :
457          Fill values
458
459          Useful when NEMO has run with no wet points options :
460          some parts of the domain, with no ocean points, have no
461          values
462        '''
463        print ( 'Error : module sklearn.impute.SimpleImputer not found' )
464        print ( 'Can not call fill_empty' )
465        print ( 'Call arguments where : ' )
466        print ( f'{ptab.shape=} {sval=} {transpose=}' )
467 
468def fill_lonlat (plon, plat, sval=-1) :
469    '''Fill longitude/latitude values
470
471    Useful when NEMO has run with no wet points options :
472    some parts of the domain, with no ocean points, have no
[6245]473    lon/lat values
474    '''
[6666]475    from sklearn.impute import SimpleImputer
476    mmath = __mmath__ (plon)
[6245]477
478    imp = SimpleImputer (missing_values=sval, strategy='mean')
[6666]479    imp.fit (plon)
480    zlon = imp.transform (plon)
481    imp.fit (plat.T)
482    zlat = imp.transform (plat.T).T
483
[6245]484    if mmath == xr :
[6666]485        zlon = xr.DataArray (zlon, dims=plon.dims, coords=plon.coords)
486        zlat = xr.DataArray (zlat, dims=plat.dims, coords=plat.coords)
487        zlon.attrs.update (plon.attrs)
488        zlat.attrs.update (plat.attrs)
[6245]489
[6666]490    zlon = fixed_lon (zlon)
[6245]491
[6666]492    return zlon, zlat
493
494def fill_bounds_lonlat (pbounds_lon, pbounds_lat, sval=-1) :
495    '''Fill longitude/latitude bounds values
496
497    Useful when NEMO has run with no wet points options :
[6245]498    some parts of the domain, with no ocean points, as no
499    lon/lat values
500    '''
[6666]501    mmath = __mmath__ (pbounds_lon)
[6245]502
[6666]503    z_bounds_lon = np.empty ( pbounds_lon.shape )
504    z_bounds_lat = np.empty ( pbounds_lat.shape )
505
[6245]506    imp = SimpleImputer (missing_values=sval, strategy='mean')
507
[6666]508    for n in np.arange (4) :
509        imp.fit (pbounds_lon[:,:,n])
510        z_bounds_lon[:,:,n] = imp.transform (pbounds_lon[:,:,n])
511        imp.fit (pbounds_lat[:,:,n].T)
512        z_bounds_lat[:,:,n] = imp.transform (pbounds_lat[:,:,n].T).T
513
[6245]514    if mmath == xr :
[6666]515        z_bounds_lon = xr.DataArray (pbounds_lon, dims=pbounds_lon.dims,
516                                         coords=pbounds_lon.coords)
517        z_bounds_lat = xr.DataArray (pbounds_lat, dims=pbounds_lat.dims,
518                                         coords=pbounds_lat.coords)
519        z_bounds_lon.attrs.update (pbounds_lat.attrs)
520        z_bounds_lat.attrs.update (pbounds_lat.attrs)
[6245]521
[6666]522    return z_bounds_lon, z_bounds_lat
523
524def jeq (plat) :
525    '''Returns j index of equator in the grid
526
[6064]527    lat : latitudes of the grid. At least 2D.
528    '''
[6666]529    mmath = __mmath__ (plat)
530    jy = __find_axis__ (plat, 'y')[-1]
[6190]531
[6245]532    if mmath == xr :
[6666]533        jj = int ( np.mean ( np.argmin (np.abs (np.float64 (plat)),
534                                                axis=jy) ) )
535    else :
536        jj = np.argmin (np.abs (np.float64 (plat[...,:, 0])))
[6064]537
[6666]538    return jj
539
540def lon1d (plon, plat=None) :
541    '''Returns 1D longitude for simple plots.
542
543    plon : longitudes of the grid
544    plat (optionnal) : latitudes of the grid
[6064]545    '''
[6666]546    mmath = __mmath__ (plon)
547    jpj, jpi  = plon.shape [-2:]
548    if np.max (plat) :
549        je    = jeq (plat)
550        lon0 = plon [..., je, 0].copy()
551        dlon = plon [..., je, 1].copy() - plon [..., je, 0].copy()
552        lon_1d = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
[6064]553    else :
[6666]554        lon0 = plon [..., jpj//3, 0].copy()
555        dlon = plon [..., jpj//3, 1].copy() - plon [..., jpj//3, 0].copy()
556        lon_1d = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
[6064]557
[6666]558    #start = np.argmax (np.abs (np.diff (lon1D, axis=-1)) > 180.0, axis=-1)
559    #lon1D [..., start+1:] += 360
[6245]560
561    if mmath == xr :
[6666]562        lon_1d = xr.DataArray( lon_1d, dims=('lon',), coords=(lon_1d,))
563        lon_1d.attrs.update (plon.attrs)
564        lon_1d.attrs['units']         = 'degrees_east'
565        lon_1d.attrs['standard_name'] = 'longitude'
566        lon_1d.attrs['long_name :']   = 'Longitude'
[6064]567
[6666]568    return lon_1d
569
570def latreg (plat, diff=0.1) :
571    '''Returns maximum j index where gridlines are along latitudes
572    in the northern hemisphere
573
[6064]574    lat : latitudes of the grid (2D)
575    diff [optional] : tolerance
576    '''
[6666]577    #mmath = __mmath__ (plat)
578    if diff is None :
579        dy = np.float64 (np.mean (np.abs (plat -
580                        np.roll (plat,shift=1,axis=-2, roll_coords=False))))
581        print ( f'{dy=}' )
[6064]582        diff = dy/100.
583
[6666]584    je     = jeq (plat)
585    jreg   = np.where (plat[...,je:,:].max(axis=-1) -
586                       plat[...,je:,:].min(axis=-1)< diff)[-1][-1] + je
587    lareg  = np.float64 (plat[...,jreg,:].mean(axis=-1))
[6064]588
[6666]589    return jreg, lareg
[6064]590
[6666]591def lat1d (plat) :
592    '''Returns 1D latitudes for zonal means and simple plots.
593
594    plat : latitudes of the grid (2D)
[6064]595    '''
[6666]596    mmath = __mmath__ (plat)
597    iy = __find_axis__ (plat, 'y')[-1]
598    jpj = plat.shape[iy]
[6064]599
[6666]600    dy     = np.float64 (np.mean (np.abs (plat - np.roll (plat, shift=1,axis=-2))))
601    je     = jeq (plat)
602    lat_eq = np.float64 (plat[...,je,:].mean(axis=-1))
[6064]603
[6666]604    jreg, lat_reg = latreg (plat)
605    lat_ave = np.mean (plat, axis=-1)
606
607    if np.abs (lat_eq) < dy/100. : # T, U or W grid
608        if jpj-1 > jreg :
609            dys = (90.-lat_reg) / (jpj-jreg-1)*0.5
610        else            :
611            dys = (90.-lat_reg) / 2.0
[6064]612        yrange = 90.-dys-lat_reg
[6190]613    else                           :  # V or F grid
[6666]614        yrange = 90.-lat_reg
615
616    if jpj-1 > jreg :
617        lat_1d = mmath.where (lat_ave<lat_reg,
618                 lat_ave,
619                 lat_reg + yrange * (np.arange(jpj)-jreg)/(jpj-jreg-1) )
620    else :
621        lat_1d = lat_ave
622    lat_1d[-1] = 90.0
623
[6245]624    if mmath == xr :
[6666]625        lat_1d = xr.DataArray( lat_1d.values, dims=('lat',), coords=(lat_1d,))
626        lat_1d.attrs.update (plat.attrs)
627        lat_1d.attrs ['units']         = 'degrees_north'
628        lat_1d.attrs ['standard_name'] = 'latitude'
629        lat_1d.attrs ['long_name :']   = 'Latitude'
[6064]630
[6666]631    return lat_1d
[6064]632
[6666]633def latlon1d (plat, plon) :
634    '''Returns simple latitude and longitude (1D) for simple plots.
635
636    plat, plon : latitudes and longitudes of the grid (2D)
[6064]637    '''
[6666]638    return lat1d (plat),  lon1d (plon, plat)
[6064]639
[6666]640def ff (plat) :
641    '''Returns Coriolis factor
[6064]642    '''
[6666]643    zff   = np.sin (RAD * plat) * OMEGA
644    return zff
[6064]645
[6666]646def beta (plat) :
647    '''Return Beta factor (derivative of Coriolis factor)
648    '''
649    zbeta = np.cos (RAD * plat) * OMEGA / RA
650    return zbeta
651
[6064]652def mask_lonlat (ptab, x0, x1, y0, y1, lon, lat, sval=np.nan) :
[6666]653    '''Returns masked values outside a lat/lon box
654    '''
[6245]655    mmath = __mmath__ (ptab)
[6666]656    if mmath == xr :
[6190]657        lon = lon.copy().to_masked_array()
658        lat = lat.copy().to_masked_array()
[6666]659
660    mask = np.logical_and (np.logical_and(lat>y0, lat<y1),
661            np.logical_or (np.logical_or (
662                np.logical_and(lon>x0, lon<x1),
663                np.logical_and(lon+360>x0, lon+360<x1)),
664                np.logical_and(lon-360>x0, lon-360<x1)))
665    tab = mmath.where (mask, ptab, sval)
666
[6064]667    return tab
[6245]668
[6666]669def extend (ptab, blon=False, jplus=25, jpi=None, nperio=4) :
670    '''Returns extended field eastward to have better plots,
671    and box average crossing the boundary
672
[6064]673    Works only for xarray and numpy data (?)
[6666]674    Useful for plotting vertical sections in OCE and ATM.
[6064]675
[6666]676    ptab : field to extend.
677    blon  : (optional, default=False) : if True, add 360 in the extended
678          parts of the field
679    jpi   : normal longitude dimension of the field. extend does nothing
680          if the actual size of the field != jpi
681          (avoid to extend several times in notebooks)
682    jplus (optional, default=25) : number of points added on
683          the east side of the field
684
[6064]685    '''
[6666]686    mmath = __mmath__ (ptab)
[6064]687
[6666]688    if ptab.shape[-1] == 1 :
689        tabex = ptab
690
[6064]691    else :
[6666]692        if jpi is None :
693            jpi = ptab.shape[-1]
[6064]694
[6666]695        if blon :
696            xplus = -360.0
697        else   :
698            xplus =    0.0
[6064]699
[6666]700        if ptab.shape[-1] > jpi :
701            tabex = ptab
[6064]702        else :
[6666]703            if nperio in [ 0, 4.2 ] :
704                istart, le, la = 0, jpi+1, 0
[6064]705            if nperio == 1 :
[6666]706                istart, le, la = 0, jpi+1, 0
707            if nperio in [4, 6] : # OPA case with two halo points for periodicity
708                # Perfect, except at the pole that should be masked by lbc_plot
709                istart, le, la = 1, jpi-2, 1
[6245]710            if mmath == xr :
[6666]711                tabex = np.concatenate (
712                     (ptab.values[..., istart   :istart+le+1    ] + xplus,
713                      ptab.values[..., istart+la:istart+la+jplus]         ),
714                      axis=-1)
715                lon    = ptab.dims[-1]
[6190]716                new_coords = []
[6666]717                for coord in ptab.dims :
718                    if coord == lon :
719                        new_coords.append ( np.arange( tabex.shape[-1]))
720                    else            :
721                        new_coords.append ( ptab.coords[coord].values)
722                tabex = xr.DataArray ( tabex, dims=ptab.dims,
723                                           coords=new_coords )
724            else :
725                tabex = np.concatenate (
726                    (ptab [..., istart   :istart+le+1    ] + xplus,
727                     ptab [..., istart+la:istart+la+jplus]          ),
728                     axis=-1)
729    return tabex
[6064]730
[6666]731def orca2reg (dd, lat_name='nav_lat', lon_name='nav_lon',
732                  y_name='y', x_name='x') :
733    '''Assign an ORCA dataset on a regular grid.
734
[6064]735    For use in the tropical region.
[6666]736    Inputs :
[6064]737      ff : xarray dataset
738      lat_name, lon_name : name of latitude and longitude 2D field in ff
739      y_name, x_name     : namex of dimensions in ff
[6666]740
[6245]741      Returns : xarray dataset with rectangular grid. Incorrect above 20°N
[6064]742    '''
743    # Compute 1D longitude and latitude
[6666]744    (zlat, zlon) = latlon1d ( dd[lat_name], dd[lon_name])
[6245]745
[6666]746    zdd = dd
[6064]747    # Assign lon and lat as dimensions of the dataset
[6666]748    if y_name in zdd.dims :
749        zlat = xr.DataArray (zlat, coords=[zlat,], dims=['lat',])
750        zdd  = zdd.rename_dims ({y_name: "lat",}).assign_coords (lat=zlat)
751    if x_name in zdd.dims :
752        zlon = xr.DataArray (zlon, coords=[zlon,], dims=['lon',])
753        zdd  = zdd.rename_dims ({x_name: "lon",}).assign_coords (lon=zlon)
[6064]754    # Force dimensions to be in the right order
755    coord_order = ['lat', 'lon']
756    for dim in [ 'depthw', 'depthv', 'depthu', 'deptht', 'depth', 'z',
[6666]757                 'time_counter', 'time', 'tbnds',
[6064]758                 'bnds', 'axis_nbounds', 'two2', 'two1', 'two', 'four',] :
[6666]759        if dim in zdd.dims :
760            coord_order.insert (0, dim)
[6064]761
[6666]762    zdd = zdd.transpose (*coord_order)
763    return zdd
764
[6190]765def lbc_init (ptab, nperio=None) :
[6666]766    '''Prepare for all lbc calls
767
[6064]768    Set periodicity on input field
769    nperio    : Type of periodicity
770      0       : No periodicity
771      1, 4, 6 : Cyclic on i dimension (generaly longitudes) with 2 points halo
772      2       : Obsolete (was symmetric condition at southern boundary ?)
773      3, 4    : North fold T-point pivot (legacy ORCA2)
774      5, 6    : North fold F-point pivot (ORCA1, ORCA025, ORCA2 with new grid for paleo)
775    cd_type   : Grid specification : T, U, V or F
776
777    See NEMO documentation for further details
[6190]778    '''
[6666]779    jpi, jpj = None, None
780    ax, ix = __find_axis__ (ptab, 'x')
781    ay, jy = __find_axis__ (ptab, 'y')
782    if ax :
783        jpi = ptab.shape[ix]
784    if ay :
785        jpj = ptab.shape[jy]
[6064]786
[6666]787    if nperio is None :
788        nperio = __guess_nperio__ (jpj, jpi, nperio)
789
790    if nperio not in NPERIO_VALID_RANGE :
791        raise AttributeError ( f'{nperio=} is not in the valid range {NPERIO_VALID_RANGE}' )
792
[6190]793    return jpj, jpi, nperio
[6666]794
795def lbc (ptab, nperio=None, cd_type='T', psgn=1.0, nemo_4u_bug=False) :
796    '''Set periodicity on input field
797
[6190]798    ptab      : Input array (works for rank 2 at least : ptab[...., lat, lon])
[6064]799    nperio    : Type of periodicity
800    cd_type   : Grid specification : T, U, V or F
801    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
[6666]802
[6064]803    See NEMO documentation for further details
[6190]804    '''
[6666]805    jpi, nperio = lbc_init (ptab, nperio)[1:]
806    ax = __find_axis__ (ptab, 'x')[0]
807    ay = __find_axis__ (ptab, 'y')[0]
[6064]808    psgn   = ptab.dtype.type (psgn)
[6666]809    mmath  = __mmath__ (ptab)
810
811    if mmath == xr :
812        ztab = ptab.values.copy ()
813    else           :
814        ztab = ptab.copy ()
815
816    if ax :
817        #
818        #> East-West boundary conditions
819        # ------------------------------
820        if nperio in [1, 4, 6] :
[6064]821        # ... cyclic
[6666]822            ztab [...,  0] = ztab [..., -2]
823            ztab [..., -1] = ztab [...,  1]
[6105]824
[6666]825        if ay :
826            #
827            #> North-South boundary conditions
828            # --------------------------------
829            if nperio in [3, 4] :  # North fold T-point pivot
830                if cd_type in [ 'T', 'W' ] : # T-, W-point
831                    ztab [..., -1, 1:       ] = psgn * ztab [..., -3, -1:0:-1      ]
832                    ztab [..., -1, 0        ] = psgn * ztab [..., -3, 2            ]
833                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2:0:-1  ]
[6190]834
[6666]835                if cd_type == 'U' :
836                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
837                    ztab [..., -1,  0       ] = psgn * ztab [..., -3,  1           ]
838                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, -2           ]
[6064]839
[6666]840                    if nemo_4u_bug :
841                        ztab [..., -2, jpi//2+1:-1] = psgn * ztab [..., -2, jpi//2-2:0:-1]
842                        ztab [..., -2, jpi//2-1   ] = psgn * ztab [..., -2, jpi//2       ]
843                    else :
844                        ztab [..., -2, jpi//2-1:-1] = psgn * ztab [..., -2, jpi//2:0:-1]
[6190]845
[6666]846                if cd_type == 'V' :
847                    ztab [..., -2, 1:       ] = psgn * ztab [..., -3, jpi-1:0:-1   ]
848                    ztab [..., -1, 1:       ] = psgn * ztab [..., -4, -1:0:-1      ]
849                    ztab [..., -1, 0        ] = psgn * ztab [..., -4, 2            ]
850
851                if cd_type == 'F' :
852                    ztab [..., -2, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
853                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -4, -1:0:-1      ]
854                    ztab [..., -1,  0       ] = psgn * ztab [..., -4,  1           ]
855                    ztab [..., -1, -1       ] = psgn * ztab [..., -4, -2           ]
856
857            if nperio in [4.2] :  # North fold T-point pivot
858                if cd_type in [ 'T', 'W' ] : # T-, W-point
859                    ztab [..., -1, jpi//2:  ] = psgn * ztab [..., -1, jpi//2:0:-1  ]
860
861                if cd_type == 'U' :
862                    ztab [..., -1, jpi//2-1:-1] = psgn * ztab [..., -1, jpi//2:0:-1]
863
864                if cd_type == 'V' :
865                    ztab [..., -1, 1:       ] = psgn * ztab [..., -2, jpi-1:0:-1   ]
866
867                if cd_type == 'F' :
868                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -1:0:-1      ]
869
870            if nperio in [5, 6] :            #  North fold F-point pivot
871                if cd_type in ['T', 'W']  :
872                    ztab [..., -1, 0:       ] = psgn * ztab [..., -2, -1::-1       ]
873
874                if cd_type == 'U' :
875                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -2::-1       ]
876                    ztab [..., -1, -1       ] = psgn * ztab [..., -2, 0            ] # Bug ?
877
878                if cd_type == 'V' :
879                    ztab [..., -1, 0:       ] = psgn * ztab [..., -3, -1::-1       ]
880                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2-1::-1 ]
881
882                if cd_type == 'F' :
883                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -2::-1       ]
884                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, 0            ]
885                    ztab [..., -2, jpi//2:-1] = psgn * ztab [..., -2, jpi//2-2::-1 ]
886
887            #
888            #> East-West boundary conditions
889            # ------------------------------
890            if nperio in [1, 4, 6] :
891                # ... cyclic
892                ztab [...,  0] = ztab [..., -2]
893                ztab [..., -1] = ztab [...,  1]
894
[6245]895    if mmath == xr :
[6666]896        ztab = xr.DataArray ( ztab, dims=ptab.dims, coords=ptab.coords )
[6245]897        ztab.attrs = ptab.attrs
[6666]898
[6190]899    return ztab
900
[6064]901def lbc_mask (ptab, nperio=None, cd_type='T', sval=np.nan) :
[6666]902    '''Mask fields on duplicated points
903
[6190]904    ptab      : Input array. Rank 2 at least : ptab [...., lat, lon]
[6064]905    nperio    : Type of periodicity
906    cd_type   : Grid specification : T, U, V or F
[6666]907
[6064]908    See NEMO documentation for further details
[6190]909    '''
[6666]910    jpi, nperio = lbc_init (ptab, nperio)[1:]
911    ax = __find_axis__ (ptab, 'x')[0]
912    ay = __find_axis__ (ptab, 'y')[0]
[6190]913    ztab = ptab.copy ()
[6064]914
[6666]915    if ax :
916        #
917        #> East-West boundary conditions
918        # ------------------------------
919        if nperio in [1, 4, 6] :
[6064]920        # ... cyclic
[6666]921            ztab [...,  0] = sval
922            ztab [..., -1] = sval
[6064]923
[6666]924        if ay :
925            #
926            #> South (in which nperio cases ?)
927            # --------------------------------
928            if nperio in [1, 3, 4, 5, 6] :
929                ztab [..., 0, :] = sval
[6105]930
[6666]931            #
932            #> North-South boundary conditions
933            # --------------------------------
934            if nperio in [3, 4] :  # North fold T-point pivot
935                if cd_type in [ 'T', 'W' ] : # T-, W-point
936                    ztab [..., -1,  :         ] = sval
937                    ztab [..., -2, :jpi//2  ] = sval
[6064]938
[6666]939                if cd_type == 'U' :
940                    ztab [..., -1,  :         ] = sval
941                    ztab [..., -2, jpi//2+1:  ] = sval
942
943                if cd_type == 'V' :
944                    ztab [..., -2, :       ] = sval
945                    ztab [..., -1, :       ] = sval
946
947                if cd_type == 'F' :
948                    ztab [..., -2, :       ] = sval
949                    ztab [..., -1, :       ] = sval
950
951            if nperio in [4.2] :  # North fold T-point pivot
952                if cd_type in [ 'T', 'W' ] : # T-, W-point
953                    ztab [..., -1, jpi//2  :  ] = sval
954
955                if cd_type == 'U' :
956                    ztab [..., -1, jpi//2-1:-1] = sval
957
958                if cd_type == 'V' :
959                    ztab [..., -1, 1:       ] = sval
960
961                if cd_type == 'F' :
962                    ztab [..., -1, 0:-1     ] = sval
963
964            if nperio in [5, 6] :            #  North fold F-point pivot
965                if cd_type in ['T', 'W']  :
966                    ztab [..., -1, 0:       ] = sval
967
968                if cd_type == 'U' :
969                    ztab [..., -1, 0:-1     ] = sval
970                    ztab [..., -1, -1       ] = sval
971
972                if cd_type == 'V' :
973                    ztab [..., -1, 0:       ] = sval
974                    ztab [..., -2, jpi//2:  ] = sval
975
976                if cd_type == 'F' :
977                    ztab [..., -1, 0:-1       ] = sval
978                    ztab [..., -1, -1         ] = sval
979                    ztab [..., -2, jpi//2+1:-1] = sval
980
[6190]981    return ztab
[6064]982
983def lbc_plot (ptab, nperio=None, cd_type='T', psgn=1.0, sval=np.nan) :
[6666]984    '''Set periodicity on input field, for plotting for any cartopy projection
985
986      Points at the north fold are masked
987      Points for zonal periodicity are kept
[6190]988    ptab      : Input array. Rank 2 at least : ptab[...., lat, lon]
[6064]989    nperio    : Type of periodicity
990    cd_type   : Grid specification : T, U, V or F
[6666]991    psgn      : For change of sign for vector components
992           (1 for scalars, -1 for vector components)
993
[6064]994    See NEMO documentation for further details
[6190]995    '''
[6666]996    jpi, nperio = lbc_init (ptab, nperio)[1:]
997    ax = __find_axis__ (ptab, 'x')[0]
998    ay = __find_axis__ (ptab, 'y')[0]
[6064]999    psgn   = ptab.dtype.type (psgn)
[6190]1000    ztab   = ptab.copy ()
1001
[6666]1002    if ax :
1003        #
1004        #> East-West boundary conditions
1005        # ------------------------------
1006        if nperio in [1, 4, 6] :
1007            # ... cyclic
1008            ztab [..., :,  0] = ztab [..., :, -2]
1009            ztab [..., :, -1] = ztab [..., :,  1]
[6064]1010
[6666]1011        if ay :
1012            #> Masks south
1013            # ------------
1014            if nperio in [4, 6] :
1015                ztab [..., 0, : ] = sval
[6105]1016
[6666]1017            #
1018            #> North-South boundary conditions
1019            # --------------------------------
1020            if nperio in [3, 4] :  # North fold T-point pivot
1021                if cd_type in [ 'T', 'W' ] : # T-, W-point
1022                    ztab [..., -1,  :      ] = sval
1023                    #ztab [..., -2, jpi//2: ] = sval
1024                    ztab [..., -2, :jpi//2 ] = sval # Give better plots than above
1025                if cd_type == 'U' :
1026                    ztab [..., -1, : ] = sval
[6064]1027
[6666]1028                if cd_type == 'V' :
1029                    ztab [..., -2, : ] = sval
1030                    ztab [..., -1, : ] = sval
1031
1032                if cd_type == 'F' :
1033                    ztab [..., -2, : ] = sval
1034                    ztab [..., -1, : ] = sval
1035
1036            if nperio in [4.2] :  # North fold T-point pivot
1037                if cd_type in [ 'T', 'W' ] : # T-, W-point
1038                    ztab [..., -1, jpi//2:  ] = sval
1039
1040                if cd_type == 'U' :
1041                    ztab [..., -1, jpi//2-1:-1] = sval
1042
1043                if cd_type == 'V' :
1044                    ztab [..., -1, 1:       ] = sval
1045
1046                if cd_type == 'F' :
1047                    ztab [..., -1, 0:-1     ] = sval
1048
1049            if nperio in [5, 6] :            #  North fold F-point pivot
1050                if cd_type in ['T', 'W']  :
1051                    ztab [..., -1, : ] = sval
1052
1053                if cd_type == 'U' :
1054                    ztab [..., -1, : ] = sval
1055
1056                if cd_type == 'V' :
1057                    ztab [..., -1, :        ] = sval
1058                    ztab [..., -2, jpi//2:  ] = sval
1059
1060                if cd_type == 'F' :
1061                    ztab [..., -1, :          ] = sval
1062                    ztab [..., -2, jpi//2+1:-1] = sval
1063
[6190]1064    return ztab
1065
[6666]1066def lbc_add (ptab, nperio=None, cd_type=None, psgn=1) :
1067    '''Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
1068
1069    Periodicity and north fold halos has been removed in NEMO 4.2
[6190]1070    This routine adds the halos if needed
1071
[6666]1072    ptab      : Input array (works
[6190]1073      rank 2 at least : ptab[...., lat, lon]
1074    nperio    : Type of periodicity
[6666]1075
[6190]1076    See NEMO documentation for further details
1077    '''
[6666]1078    mmath = __mmath__ (ptab)
1079    nperio = lbc_init (ptab, nperio)[-1]
1080    lshape = get_shape (ptab)
1081    ix = __find_axis__ (ptab, 'x')[-1]
1082    jy = __find_axis__ (ptab, 'y')[-1]
[6190]1083
1084    t_shape = np.array (ptab.shape)
1085
[6666]1086    if nperio in [4.2, 6.2] :
[6190]1087
[6666]1088        ext_shape = t_shape.copy()
1089        if 'X' in lshape :
1090            ext_shape[ix] = ext_shape[ix] + 2
1091        if 'Y' in lshape :
1092            ext_shape[jy] = ext_shape[jy] + 1
1093
[6245]1094        if mmath == xr :
[6666]1095            ptab_ext = xr.DataArray (np.zeros (ext_shape), dims=ptab.dims)
1096            if 'X' in lshape and 'Y' in lshape :
1097                ptab_ext.values[..., :-1, 1:-1] = ptab.values.copy ()
1098            else :
1099                if 'X' in lshape     :
1100                    ptab_ext.values[...,      1:-1] = ptab.values.copy ()
1101                if 'Y' in lshape     :
1102                    ptab_ext.values[..., :-1      ] = ptab.values.copy ()
[6245]1103        else           :
1104            ptab_ext =               np.zeros (ext_shape)
[6666]1105            if 'X' in lshape and 'Y' in lshape :
1106                ptab_ext       [..., :-1, 1:-1] = ptab.copy ()
1107            else :
1108                if 'X' in lshape     :
1109                    ptab_ext       [...,      1:-1] = ptab.copy ()
1110                if 'Y' in lshape     :
1111                    ptab_ext       [..., :-1      ] = ptab.copy ()
1112
1113        if nperio == 4.2 :
1114            ptab_ext = lbc (ptab_ext, nperio=4, cd_type=cd_type, psgn=psgn)
1115        if nperio == 6.2 :
1116            ptab_ext = lbc (ptab_ext, nperio=6, cd_type=cd_type, psgn=psgn)
1117
[6245]1118        if mmath == xr :
1119            ptab_ext.attrs = ptab.attrs
[6666]1120            az = __find_axis__ (ptab, 'z')[0]
1121            at = __find_axis__ (ptab, 't')[0]
1122            if az :
1123                ptab_ext = ptab_ext.assign_coords ( {az:ptab.coords[az]} )
1124            if at :
1125                ptab_ext = ptab_ext.assign_coords ( {at:ptab.coords[at]} )
[6190]1126
[6245]1127    else : ptab_ext = lbc (ptab, nperio=nperio, cd_type=cd_type, psgn=psgn)
[6666]1128
[6245]1129    return ptab_ext
[6064]1130
[6245]1131def lbc_del (ptab, nperio=None, cd_type='T', psgn=1) :
[6666]1132    '''Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
1133
1134    Periodicity and north fold halos has been removed in NEMO 4.2
[6190]1135    This routine removes the halos if needed
1136
[6666]1137    ptab      : Input array (works
[6190]1138      rank 2 at least : ptab[...., lat, lon]
1139    nperio    : Type of periodicity
[6666]1140
[6190]1141    See NEMO documentation for further details
1142    '''
[6666]1143    nperio = lbc_init (ptab, nperio)[-1]
1144    #lshape = get_shape (ptab)
1145    ax = __find_axis__ (ptab, 'x')[0]
1146    ay = __find_axis__ (ptab, 'y')[0]
[6190]1147
[6666]1148    if nperio in [4.2, 6.2] :
1149        if ax or ay :
1150            if ax and ay :
1151                ztab = lbc (ptab[..., :-1, 1:-1],
1152                            nperio=nperio, cd_type=cd_type, psgn=psgn)
1153            else :
1154                if ax :
1155                    ztab = lbc (ptab[...,      1:-1],
1156                                nperio=nperio, cd_type=cd_type, psgn=psgn)
1157                if ay :
1158                    ztab = lbc (ptab[..., -1],
1159                                nperio=nperio, cd_type=cd_type, psgn=psgn)
1160        else :
1161            ztab = ptab
[6190]1162    else :
[6666]1163        ztab = ptab
[6190]1164
[6666]1165    return ztab
1166
[6190]1167def lbc_index (jj, ii, jpj, jpi, nperio=None, cd_type='T') :
[6666]1168    '''For indexes of a NEMO point, give the corresponding point
1169        inside the domain (i.e. not in the halo)
1170
[6190]1171    jj, ii    : indexes
1172    jpi, jpi  : size of domain
1173    nperio    : type of periodicity
1174    cd_type   : grid specification : T, U, V or F
[6666]1175
[6190]1176    See NEMO documentation for further details
1177    '''
1178
[6666]1179    if nperio is None :
1180        nperio = __guess_nperio__ (jpj, jpi, nperio)
[6190]1181
[6666]1182    ## For the sake of simplicity, switch to the convention of original
1183    ## lbc Fortran routine from NEMO : starts indexes at 1
1184    jy = jj + 1
1185    ix = ii + 1
1186
[6245]1187    mmath = __mmath__ (jj)
[6666]1188    if mmath is None :
1189        mmath=np
[6190]1190
1191    #
1192    #> East-West boundary conditions
1193    # ------------------------------
1194    if nperio in [1, 4, 6] :
1195        #... cyclic
[6245]1196        ix = mmath.where (ix==jpi, 2   , ix)
1197        ix = mmath.where (ix== 1 ,jpi-1, ix)
[6190]1198
1199    #
[6666]1200    def mod_ij (cond, jy_new, ix_new) :
[6245]1201        jy_r = mmath.where (cond, jy_new, jy)
1202        ix_r = mmath.where (cond, ix_new, ix)
[6190]1203        return jy_r, ix_r
1204    #
1205    #> North-South boundary conditions
1206    # --------------------------------
1207    if nperio in [ 3 , 4 ]  :
1208        if cd_type in  [ 'T' , 'W' ] :
[6666]1209            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix>=2       ), jpj-2, jpi-ix+2)
1210            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1       ), jpj-1, 3       )
1211            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1),
1212                                  jy , jpi-ix+2)
[6190]1213
1214        if cd_type in [ 'U' ] :
[6666]1215            jy, ix = mod_ij (np.logical_and (
1216                      jy==jpj  ,
1217                      np.logical_and (ix>=1, ix <= jpi-1)   ),
1218                                         jy   , jpi-ix+1)
1219            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1  ) , jpj-2, 2       )
1220            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi) , jpj-2, jpi-1   )
1221            jy, ix = mod_ij (np.logical_and (jy==jpj-1,
1222                            np.logical_and (ix>=jpi//2, ix<=jpi-1)), jy   , jpi-ix+1)
1223
[6190]1224        if cd_type in [ 'V' ] :
[6666]1225            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=2  ), jpj-2, jpi-ix+2)
1226            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix>=2  ), jpj-3, jpi-ix+2)
1227            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1  ), jpj-3,  3      )
1228
[6190]1229        if cd_type in [ 'F' ] :
[6666]1230            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix<=jpi-1), jpj-2, jpi-ix+1)
1231            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1), jpj-3, jpi-ix+1)
1232            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==1    ), jpj-3, 2       )
1233            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi  ), jpj-3, jpi-1   )
[6190]1234
1235    if nperio in [ 5 , 6 ] :
1236        if cd_type in [ 'T' , 'W' ] :                        # T-, W-point
[6666]1237            jy, ix = mod_ij (jy==jpj, jpj-1, jpi-ix+1)
1238
[6190]1239        if cd_type in [ 'U' ] :                              # U-point
[6666]1240            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-1, jpi-ix  )
1241            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix==jpi     ), jpi-1, 1       )
1242
[6190]1243        if cd_type in [ 'V' ] :    # V-point
[6666]1244            jy, ix = mod_ij (jy==jpj                                 , jy   , jpi-ix+1)
1245            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+1)
1246
[6190]1247        if cd_type in [ 'F' ] :                              # F-point
[6666]1248            jy, ix = mod_ij (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-2, jpi-ix  )
1249            jy, ix = mod_ij (np.logical_and (ix==jpj  , ix==jpi     ), jpj-2, 1       )
1250            jy, ix = mod_ij (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix  )
[6190]1251
1252    ## Restore convention to Python/C : indexes start at 0
[6666]1253    jy += -1
1254    ix += -1
[6190]1255
[6666]1256    if isinstance (jj, int) :
1257        jy = jy.item ()
1258    if isinstance (ii, int) :
1259        ix = ix.item ()
[6190]1260
1261    return jy, ix
[6666]1262
1263def find_ji (lat_data, lon_data, lat_grid, lon_grid, mask=1.0, verbose=False, out=None) :
[6190]1264    '''
[6666]1265    Description: seeks J,I indices of the grid point which is the closest
1266       of a given point
[6064]1267
[6666]1268    Usage: go FindJI  <data latitude> <data longitude> <grid latitudes> <grid longitudes> [mask]
1269    <grid latitudes><grid longitudes> are 2D fields on J/I (Y/X) dimensions
1270    mask : if given, seek only non masked grid points (i.e with mask=1)
1271
1272    Example : findIJ (40, -20, nav_lat, nav_lon, mask=1.0)
1273
1274    Note : all longitudes and latitudes in degrees
1275
1276    Note : may work with 1D lon/lat (?)
1277    '''
1278    # Get grid dimensions
1279    if len (lon_grid.shape) == 2 :
1280        jpi = lon_grid.shape[-1]
1281    else                         :
1282        jpi = len(lon_grid)
1283
1284    #mmath = __mmath__ (lat_grid)
1285
1286    # Compute distance from the point to all grid points (in RADian)
1287    arg      = ( np.sin (RAD*lat_data) * np.sin (RAD*lat_grid)
1288               + np.cos (RAD*lat_data) * np.cos (RAD*lat_grid) *
1289                 np.cos(RAD*(lon_data-lon_grid)) )
1290    # Send masked points to 'infinite'
1291    distance = np.arccos (arg) + 4.0*RPI*(1.0-mask)
1292
1293    # Truncates to alleviate some precision problem with some grids
1294    prec = int (1E7)
1295    distance = (distance*prec).astype(int) / prec
1296
1297    # Compute minimum of distance, and index of minimum
1298    #
1299    #distance_min = distance.min    ()
1300    jimin        = int (distance.argmin ())
1301
1302    # Compute 2D indices (Python/C flavor : starting at 0)
1303    jmin = jimin // jpi
1304    imin = jimin - jmin*jpi
1305
1306    # Result
1307    if verbose :
1308        # Compute distance achieved
1309        #mindist = distance [jmin, imin]
1310
1311        # Compute azimuth
1312        dlon = lon_data-lon_grid[jmin,imin]
1313        arg  = np.sin (RAD*dlon) / (
1314                  np.cos(RAD*lat_data)*np.tan(RAD*lat_grid[jmin,imin])
1315                - np.sin(RAD*lat_data)*np.cos(RAD*dlon))
1316        azimuth = DAR*np.arctan (arg)
1317        print ( f'I={imin:d} J={jmin:d} - Data:{lat_data:5.1f}°N {lon_data:5.1f}°E' \
1318            +   f'- Grid:{lat_grid[jmin,imin]:4.1f}°N '   \
1319            +   f'{lon_grid[jmin,imin]:4.1f}°E - Dist: {RA*distance[jmin,imin]:6.1f}km' \
1320            +   f' {DAR*distance[jmin,imin]:5.2f}° ' \
1321            +   f'- Azimuth: {RAD*azimuth:3.2f}RAD - {azimuth:5.1f}°' )
1322
1323    if   out=='dict'                     :
1324        return {'x':imin, 'y':jmin}
1325    elif out in ['array', 'numpy', 'np'] :
1326        return np.array ( [jmin, imin] )
1327    elif out in ['xarray', 'xr']         :
1328        return xr.DataArray ( [jmin, imin] )
1329    elif out=='list'                     :
1330        return [jmin, imin]
1331    elif out=='tuple'                    :
1332        return jmin, imin
1333    else                                 :
1334        return jmin, imin
1335
1336def curl (tx, ty, e1f, e2f, nperio=None) :
1337    '''Returns curl of a vector field
1338    '''
1339    ax = __find_axis__ (tx, 'x')[0]
1340    ay = __find_axis__ (ty, 'y')[0]
1341
1342    tx_0  = lbc_add (tx , nperio=nperio, cd_type='U', psgn=-1)
1343    ty_0  = lbc_add (ty , nperio=nperio, cd_type='V', psgn=-1)
1344    e1f_0 = lbc_add (e1f, nperio=nperio, cd_type='U', psgn=-1)
1345    e2f_0 = lbc_add (e2f, nperio=nperio, cd_type='V', psgn=-1)
1346
1347    tx_1  = tx_0.roll ( {ay:-1} )
1348    ty_1  = ty_0.roll ( {ax:-1} )
1349    tx_1  = lbc (tx_1, nperio=nperio, cd_type='U', psgn=-1)
1350    ty_1  = lbc (ty_1, nperio=nperio, cd_type='V', psgn=-1)
1351
1352    zcurl = (ty_1 - ty_0)/e1f_0 - (tx_1 - tx_0)/e2f_0
1353
1354    mask = np.logical_or (
1355             np.logical_or ( np.isnan(tx_0), np.isnan(tx_1)),
1356             np.logical_or ( np.isnan(ty_0), np.isnan(ty_1)) )
1357
1358    zcurl = zcurl.where (np.logical_not (mask), np.nan)
1359
1360    zcurl = lbc_del (zcurl, nperio=nperio, cd_type='F', psgn=1)
1361    zcurl = lbc (zcurl, nperio=nperio, cd_type='F', psgn=1)
1362
1363    return zcurl
1364
1365def div (ux, uy, e1t, e2t, nperio=None) :
1366    '''Returns divergence of a vector field
1367    '''
1368    ax = __find_axis__ (ux, 'x')[0]
1369    ay = __find_axis__ (ux, 'y')[0]
1370
1371    ux_0  = lbc_add (ux , nperio=nperio, cd_type='U', psgn=-1)
1372    uy_0  = lbc_add (uy , nperio=nperio, cd_type='V', psgn=-1)
1373    e1t_0 = lbc_add (e1t, nperio=nperio, cd_type='U', psgn=-1)
1374    e2t_0 = lbc_add (e2t, nperio=nperio, cd_type='V', psgn=-1)
1375
1376    ux_1 = ux_0.roll ( {ay:1} )
1377    uy_1 = uy_0.roll ( {ax:1} )
1378    ux_1 = lbc (ux_1, nperio=nperio, cd_type='U', psgn=-1)
1379    uy_1 = lbc (uy_1, nperio=nperio, cd_type='V', psgn=-1)
1380
1381    zdiv = (ux_0 - ux_1)/e2t_0 + (uy_0 - uy_1)/e1t_0
1382
1383    mask = np.logical_or (
1384             np.logical_or ( np.isnan(ux_0), np.isnan(ux_1)),
1385             np.logical_or ( np.isnan(uy_0), np.isnan(uy_1)) )
1386
1387    zdiv = zdiv.where (np.logical_not (mask), np.nan)
1388
1389    zdiv = lbc_del (zdiv, nperio=nperio, cd_type='T', psgn=1)
1390    zdiv = lbc (zdiv, nperio=nperio, cd_type='T', psgn=1)
1391
1392    return zdiv
1393
1394def geo2en (pxx, pyy, pzz, glam, gphi) :
1395    '''Change vector from geocentric to east/north
1396
[6190]1397    Inputs :
1398        pxx, pyy, pzz : components on the geocentric system
1399        glam, gphi : longitude and latitude of the points
1400    '''
[6064]1401
[6666]1402    gsinlon = np.sin (RAD * glam)
1403    gcoslon = np.cos (RAD * glam)
1404    gsinlat = np.sin (RAD * gphi)
1405    gcoslat = np.cos (RAD * gphi)
1406
[6064]1407    pte = - pxx * gsinlon            + pyy * gcoslon
1408    ptn = - pxx * gcoslon * gsinlat  - pyy * gsinlon * gsinlat + pzz * gcoslat
1409
1410    return pte, ptn
1411
[6190]1412def en2geo (pte, ptn, glam, gphi) :
[6666]1413    '''Change vector from east/north to geocentric
[6190]1414
[6666]1415    Inputs :
[6190]1416        pte, ptn   : eastward/northward components
1417        glam, gphi : longitude and latitude of the points
1418    '''
[6064]1419
[6666]1420    gsinlon = np.sin (RAD * glam)
1421    gcoslon = np.cos (RAD * glam)
1422    gsinlat = np.sin (RAD * gphi)
1423    gcoslat = np.cos (RAD * gphi)
1424
[6064]1425    pxx = - pte * gsinlon - ptn * gcoslon * gsinlat
1426    pyy =   pte * gcoslon - ptn * gsinlon * gsinlat
1427    pzz =   ptn * gcoslat
[6666]1428
[6064]1429    return pxx, pyy, pzz
1430
[6666]1431
1432def clo_lon (lon, lon0=0., rad=False, deg=True) :
1433    '''Choose closest to lon0 longitude, adding/substacting 360°
1434    if needed
[6190]1435    '''
[6666]1436    mmath = __mmath__ (lon, np)
1437    if rad :
1438        lon_range = 2.*np.pi
1439    if deg :
1440        lon_range = 360.
1441    c_lon = lon
1442    c_lon = mmath.where (c_lon > lon0 + lon_range*0.5,
1443                             c_lon-lon_range, clo_lon)
1444    c_lon = mmath.where (c_lon < lon0 - lon_range*0.5,
1445                             c_lon+lon_range, clo_lon)
1446    c_lon = mmath.where (c_lon > lon0 + lon_range*0.5,
1447                             c_lon-lon_range, clo_lon)
1448    c_lon = mmath.where (c_lon < lon0 - lon_range*0.5,
1449                             c_lon+lon_range, clo_lon)
1450    if c_lon.shape == () :
1451        c_lon = c_lon.item ()
1452    if mmath == xr :
1453        if lon.attrs :
1454            c_lon.attrs.update ( lon.attrs )
1455    return c_lon
[6190]1456
[6666]1457def index2depth (pk, gdept_0) :
1458    '''From index (real, continuous), get depth
1459
1460    Needed to use transforms in Matplotlib
[6190]1461    '''
[6666]1462    jpk = gdept_0.shape[0]
1463    kk = xr.DataArray(pk)
1464    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1465    k0 = np.floor (k).astype (int)
1466    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1467    zz = k - k0
1468    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1469    return gz.values
[6190]1470
[6666]1471def depth2index (pz, gdept_0) :
1472    '''From depth, get index (real, continuous)
[6190]1473
[6666]1474    Needed to use transforms in Matplotlib
1475    '''
1476    jpk  = gdept_0.shape[0]
1477    if   isinstance (pz, xr.core.dataarray.DataArray ) :
1478        zz   = xr.DataArray (pz.values, dims=('zz',))
1479    elif isinstance (pz,  np.ndarray) :
1480        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1481    else :
1482        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1483    zz   = np.minimum (gdept_0[-1], np.maximum (0, zz))
[6190]1484
[6666]1485    idk1 = np.minimum ( (gdept_0-zz), 0.).argmax (axis=0).astype(int)
1486    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1487    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
[6190]1488
[6666]1489    zff = (zz - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1490    idk = zff*idk1 + (1.0-zff)*idk2
1491    idk = xr.where ( np.isnan(idk), idk1, idk)
1492    return idk.values
[6190]1493
[6666]1494def index2depth_panels (pk, gdept_0, depth0, fact) :
1495    '''From  index (real, continuous), get depth, with bottom part compressed
[6190]1496
[6666]1497    Needed to use transforms in Matplotlib
1498    '''
1499    jpk = gdept_0.shape[0]
1500    kk = xr.DataArray (pk)
1501    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1502    k0 = np.floor (k).astype (int)
1503    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1504    zz = k - k0
1505    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1506    gz = xr.where ( gz<depth0, gz, depth0 + (gz-depth0)*fact)
1507    return gz.values
[6190]1508
[6666]1509def depth2index_panels (pz, gdept_0, depth0, fact) :
1510    '''From  index (real, continuous), get depth, with bottom part compressed
1511
1512    Needed to use transforms in Matplotlib
1513    '''
1514    jpk = gdept_0.shape[0]
1515    if isinstance (pz, xr.core.dataarray.DataArray) :
1516        zz   = xr.DataArray (pz.values , dims=('zz',))
1517    elif isinstance (pz, np.ndarray) :
1518        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1519    else :
1520        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1521    zz         = np.minimum (gdept_0[-1], np.maximum (0, zz))
1522    gdept_comp = xr.where ( gdept_0>depth0,
1523                                (gdept_0-depth0)*fact+depth0, gdept_0)
1524    zz_comp    = xr.where ( zz     >depth0, (zz     -depth0)*fact+depth0,
1525                                zz     )
1526    zz_comp    = np.minimum (gdept_comp[-1], np.maximum (0, zz_comp))
1527
1528    idk1 = np.minimum ( (gdept_0-zz_comp), 0.).argmax (axis=0).astype(int)
1529    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1530    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
1531
1532    zff = (zz_comp - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1533    idk = zff*idk1 + (1.0-zff)*idk2
1534    idk = xr.where ( np.isnan(idk), idk1, idk)
1535    return idk.values
1536
1537def depth2comp (pz, depth0, fact ) :
1538    '''Form depth, get compressed depth, with bottom part compressed
1539
1540    Needed to use transforms in Matplotlib
1541    '''
1542    #print ('start depth2comp')
1543    if isinstance (pz, xr.core.dataarray.DataArray) :
1544        zz   = pz.values
1545    elif isinstance(pz, list) :
1546        zz = np.array (pz)
1547    else :
1548        zz   = pz
1549    gz = np.where ( zz>depth0, (zz-depth0)*fact+depth0, zz)
1550    #print ( f'depth2comp : {gz=}' )
1551    if type (pz) in [int, float] :
1552        return gz.item()
1553    else :
1554        return gz
1555
1556def comp2depth (pz, depth0, fact ) :
1557    '''Form compressed depth, get depth, with bottom part compressed
1558
1559    Needed to use transforms in Matplotlib
1560    '''
1561    if isinstance (pz, xr.core.dataarray.DataArray) :
1562        zz   = pz.values
1563    elif isinstance (pz, list) :
1564        zz = np.array (pz)
1565    else :
1566        zz   = pz
1567    gz = np.where ( zz>depth0, (zz-depth0)/fact+depth0, zz)
1568    if type (pz) in [int, float] :
1569        gz = gz.item()
1570
1571    return gz
1572
1573def index2lon (pi, plon_1d) :
1574    '''From index (real, continuous), get longitude
1575
1576    Needed to use transforms in Matplotlib
1577    '''
1578    jpi = plon_1d.shape[0]
1579    ii = xr.DataArray (pi)
1580    i =  np.maximum (0, np.minimum (jpi-1, ii    ))
1581    i0 = np.floor (i).astype (int)
1582    i1 = np.maximum (0, np.minimum (jpi-1,  i0+1))
1583    xx = i - i0
1584    gx = (1.0-xx)*plon_1d[i0]+ xx*plon_1d[i1]
1585    return gx.values
1586
1587def lon2index (px, plon_1d) :
1588    '''From longitude, get index (real, continuous)
1589
1590    Needed to use transforms in Matplotlib
1591    '''
1592    jpi  = plon_1d.shape[0]
1593    if isinstance (px, xr.core.dataarray.DataArray) :
1594        xx   = xr.DataArray (px.values , dims=('xx',))
1595    elif isinstance (px, np.ndarray) :
1596        xx   = xr.DataArray (px.ravel(), dims=('xx',))
1597    else :
1598        xx   = xr.DataArray (np.array([px]).ravel(), dims=('xx',))
1599    xx   = xr.where ( xx>plon_1d.max(), xx-360.0, xx)
1600    xx   = xr.where ( xx<plon_1d.min(), xx+360.0, xx)
1601    xx   = np.minimum (plon_1d.max(), np.maximum(xx, plon_1d.min() ))
1602    idi1 = np.minimum ( (plon_1d-xx), 0.).argmax (axis=0).astype(int)
1603    idi1 = np.maximum (0, np.minimum (jpi-1,  idi1  ))
1604    idi2 = np.maximum (0, np.minimum (jpi-1,  idi1-1))
1605
1606    zff = (xx - plon_1d[idi2])/(plon_1d[idi1]-plon_1d[idi2])
1607    idi = zff*idi1 + (1.0-zff)*idi2
1608    idi = xr.where ( np.isnan(idi), idi1, idi)
1609    return idi.values
1610
1611def index2lat (pj, plat_1d) :
1612    '''From index (real, continuous), get latitude
1613
1614    Needed to use transforms in Matplotlib
1615    '''
1616    jpj = plat_1d.shape[0]
1617    jj  = xr.DataArray (pj)
1618    j   = np.maximum (0, np.minimum (jpj-1, jj    ))
1619    j0  = np.floor (j).astype (int)
1620    j1  = np.maximum (0, np.minimum (jpj-1,  j0+1))
1621    yy  = j - j0
1622    gy  = (1.0-yy)*plat_1d[j0]+ yy*plat_1d[j1]
1623    return gy.values
1624
1625def lat2index (py, plat_1d) :
1626    '''From latitude, get index (real, continuous)
1627
1628    Needed to use transforms in Matplotlib
1629    '''
1630    jpj = plat_1d.shape[0]
1631    if isinstance (py, xr.core.dataarray.DataArray) :
1632        yy   = xr.DataArray (py.values , dims=('yy',))
1633    elif isinstance (py, np.ndarray) :
1634        yy   = xr.DataArray (py.ravel(), dims=('yy',))
1635    else :
1636        yy   = xr.DataArray (np.array([py]).ravel(), dims=('yy',))
1637    yy   = np.minimum (plat_1d.max(), np.minimum(yy, plat_1d.max() ))
1638    idj1 = np.minimum ( (plat_1d-yy), 0.).argmax (axis=0).astype(int)
1639    idj1 = np.maximum (0, np.minimum (jpj-1,  idj1  ))
1640    idj2 = np.maximum (0, np.minimum (jpj-1,  idj1-1))
1641
1642    zff = (yy - plat_1d[idj2])/(plat_1d[idj1]-plat_1d[idj2])
1643    idj = zff*idj1 + (1.0-zff)*idj2
1644    idj = xr.where ( np.isnan(idj), idj1, idj)
1645    return idj.values
1646
1647def angle_full (glamt, gphit, glamu, gphiu, glamv, gphiv,
1648                glamf, gphif, nperio=None) :
1649    '''Computes sinus and cosinus of model line direction with
1650    respect to east
1651    '''
[6245]1652    mmath = __mmath__ (glamt)
1653
1654    zlamt = lbc_add (glamt, nperio, 'T', 1.)
1655    zphit = lbc_add (gphit, nperio, 'T', 1.)
1656    zlamu = lbc_add (glamu, nperio, 'U', 1.)
1657    zphiu = lbc_add (gphiu, nperio, 'U', 1.)
1658    zlamv = lbc_add (glamv, nperio, 'V', 1.)
1659    zphiv = lbc_add (gphiv, nperio, 'V', 1.)
1660    zlamf = lbc_add (glamf, nperio, 'F', 1.)
1661    zphif = lbc_add (gphif, nperio, 'F', 1.)
[6666]1662
[6190]1663    # north pole direction & modulous (at T-point)
[6666]1664    zxnpt = 0. - 2.0 * np.cos (RAD*zlamt) * np.tan (RPI/4.0 - RAD*zphit/2.0)
1665    zynpt = 0. - 2.0 * np.sin (RAD*zlamt) * np.tan (RPI/4.0 - RAD*zphit/2.0)
[6190]1666    znnpt = zxnpt*zxnpt + zynpt*zynpt
[6666]1667
[6190]1668    # north pole direction & modulous (at U-point)
[6666]1669    zxnpu = 0. - 2.0 * np.cos (RAD*zlamu) * np.tan (RPI/4.0 - RAD*zphiu/2.0)
1670    zynpu = 0. - 2.0 * np.sin (RAD*zlamu) * np.tan (RPI/4.0 - RAD*zphiu/2.0)
[6190]1671    znnpu = zxnpu*zxnpu + zynpu*zynpu
[6666]1672
[6190]1673    # north pole direction & modulous (at V-point)
[6666]1674    zxnpv = 0. - 2.0 * np.cos (RAD*zlamv) * np.tan (RPI/4.0 - RAD*zphiv/2.0)
1675    zynpv = 0. - 2.0 * np.sin (RAD*zlamv) * np.tan (RPI/4.0 - RAD*zphiv/2.0)
[6190]1676    znnpv = zxnpv*zxnpv + zynpv*zynpv
1677
1678    # north pole direction & modulous (at F-point)
[6666]1679    zxnpf = 0. - 2.0 * np.cos( RAD*zlamf ) * np.tan ( RPI/4. - RAD*zphif/2. )
1680    zynpf = 0. - 2.0 * np.sin( RAD*zlamf ) * np.tan ( RPI/4. - RAD*zphif/2. )
[6190]1681    znnpf = zxnpf*zxnpf + zynpf*zynpf
1682
1683    # j-direction: v-point segment direction (around T-point)
[6666]1684    zlam = zlamv
[6245]1685    zphi = zphiv
1686    zlan = np.roll ( zlamv, axis=-2, shift=1)  # glamv (ji,jj-1)
1687    zphh = np.roll ( zphiv, axis=-2, shift=1)  # gphiv (ji,jj-1)
[6666]1688    zxvvt =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1689          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1690    zyvvt =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1691          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
[6190]1692    znvvt = np.sqrt ( znnpt * ( zxvvt*zxvvt + zyvvt*zyvvt )  )
1693
1694    # j-direction: f-point segment direction (around u-point)
[6245]1695    zlam = zlamf
1696    zphi = zphif
1697    zlan = np.roll (zlamf, axis=-2, shift=1) # glamf (ji,jj-1)
1698    zphh = np.roll (zphif, axis=-2, shift=1) # gphif (ji,jj-1)
[6666]1699    zxffu =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1700          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1701    zyffu =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1702          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
[6190]1703    znffu = np.sqrt ( znnpu * ( zxffu*zxffu + zyffu*zyffu )  )
1704
1705    # i-direction: f-point segment direction (around v-point)
[6666]1706    zlam = zlamf
[6245]1707    zphi = zphif
1708    zlan = np.roll (zlamf, axis=-1, shift=1) # glamf (ji-1,jj)
1709    zphh = np.roll (zphif, axis=-1, shift=1) # gphif (ji-1,jj)
[6666]1710    zxffv =  2.0 * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1711          -  2.0 * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1712    zyffv =  2.0 * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1713          -  2.0 * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
[6190]1714    znffv = np.sqrt ( znnpv * ( zxffv*zxffv + zyffv*zyffv )  )
1715
1716    # j-direction: u-point segment direction (around f-point)
[6666]1717    zlam = np.roll (zlamu, axis=-2, shift=-1) # glamu (ji,jj+1)
[6245]1718    zphi = np.roll (zphiu, axis=-2, shift=-1) # gphiu (ji,jj+1)
1719    zlan = zlamu
1720    zphh = zphiu
[6666]1721    zxuuf =  2. * np.cos ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1722          -  2. * np.cos ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
1723    zyuuf =  2. * np.sin ( RAD*zlam ) * np.tan ( RPI/4. - RAD*zphi/2. ) \
1724          -  2. * np.sin ( RAD*zlan ) * np.tan ( RPI/4. - RAD*zphh/2. )
[6190]1725    znuuf = np.sqrt ( znnpf * ( zxuuf*zxuuf + zyuuf*zyuuf )  )
1726
[6666]1727
[6190]1728    # cosinus and sinus using scalar and vectorial products
1729    gsint = ( zxnpt*zyvvt - zynpt*zxvvt ) / znvvt
1730    gcost = ( zxnpt*zxvvt + zynpt*zyvvt ) / znvvt
[6666]1731
[6190]1732    gsinu = ( zxnpu*zyffu - zynpu*zxffu ) / znffu
1733    gcosu = ( zxnpu*zxffu + zynpu*zyffu ) / znffu
[6666]1734
[6190]1735    gsinf = ( zxnpf*zyuuf - zynpf*zxuuf ) / znuuf
1736    gcosf = ( zxnpf*zxuuf + zynpf*zyuuf ) / znuuf
[6666]1737
[6190]1738    gsinv = ( zxnpv*zxffv + zynpv*zyffv ) / znffv
[6666]1739    # (caution, rotation of 90 degres)
1740    gcosv =-( zxnpv*zyffv - zynpv*zxffv ) / znffv
[6190]1741
[6245]1742    gsint = lbc_del (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1743    gcost = lbc_del (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1744    gsinu = lbc_del (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1745    gcosu = lbc_del (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1746    gsinv = lbc_del (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1747    gcosv = lbc_del (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1748    gsinf = lbc_del (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1749    gcosf = lbc_del (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1750
1751    if mmath == xr :
[6190]1752        gsint = gsint.assign_coords ( glamt.coords )
1753        gcost = gcost.assign_coords ( glamt.coords )
1754        gsinu = gsinu.assign_coords ( glamu.coords )
1755        gcosu = gcosu.assign_coords ( glamu.coords )
1756        gsinv = gsinv.assign_coords ( glamv.coords )
1757        gcosv = gcosv.assign_coords ( glamv.coords )
1758        gsinf = gsinf.assign_coords ( glamf.coords )
1759        gcosf = gcosf.assign_coords ( glamf.coords )
1760
1761    return gsint, gcost, gsinu, gcosu, gsinv, gcosv, gsinf, gcosf
1762
1763def angle (glam, gphi, nperio, cd_type='T') :
[6666]1764    '''Computes sinus and cosinus of model line direction with
1765    respect to east
1766    '''
[6245]1767    mmath = __mmath__ (glam)
1768
1769    zlam = lbc_add (glam, nperio, cd_type, 1.)
1770    zphi = lbc_add (gphi, nperio, cd_type, 1.)
[6666]1771
[6190]1772    # north pole direction & modulous
[6666]1773    zxnp = 0. - 2.0 * np.cos (RAD*zlam) * np.tan (RPI/4.0 - RAD*zphi/2.0)
1774    zynp = 0. - 2.0 * np.sin (RAD*zlam) * np.tan (RPI/4.0 - RAD*zphi/2.0)
[6190]1775    znnp = zxnp*zxnp + zynp*zynp
1776
1777    # j-direction: segment direction (around point)
[6245]1778    zlan_n = np.roll (zlam, axis=-2, shift=-1) # glam [jj+1, ji]
1779    zphh_n = np.roll (zphi, axis=-2, shift=-1) # gphi [jj+1, ji]
1780    zlan_s = np.roll (zlam, axis=-2, shift= 1) # glam [jj-1, ji]
1781    zphh_s = np.roll (zphi, axis=-2, shift= 1) # gphi [jj-1, ji]
[6666]1782
1783    zxff = 2.0 * np.cos (RAD*zlan_n) * np.tan (RPI/4.0 - RAD*zphh_n/2.0) \
1784        -  2.0 * np.cos (RAD*zlan_s) * np.tan (RPI/4.0 - RAD*zphh_s/2.0)
1785    zyff = 2.0 * np.sin (RAD*zlan_n) * np.tan (RPI/4.0 - RAD*zphh_n/2.0) \
1786        -  2.0 * np.sin (RAD*zlan_s) * np.tan (RPI/4.0 - RAD*zphh_s/2.0)
[6190]1787    znff = np.sqrt (znnp * (zxff*zxff + zyff*zyff) )
[6666]1788
[6190]1789    gsin = (zxnp*zyff - zynp*zxff) / znff
1790    gcos = (zxnp*zxff + zynp*zyff) / znff
1791
[6245]1792    gsin = lbc_del (gsin, cd_type=cd_type, nperio=nperio, psgn=-1.)
1793    gcos = lbc_del (gcos, cd_type=cd_type, nperio=nperio, psgn=-1.)
[6190]1794
[6245]1795    if mmath == xr :
[6190]1796        gsin = gsin.assign_coords ( glam.coords )
1797        gcos = gcos.assign_coords ( glam.coords )
[6666]1798
[6190]1799    return gsin, gcos
1800
1801def rot_en2ij ( u_e, v_n, gsin, gcos, nperio, cd_type ) :
[6666]1802    '''Rotates the Repere: Change vector componantes between
[6190]1803    geographic grid --> stretched coordinates grid.
[6666]1804
1805    All components are on the same grid (T, U, V or F)
[6190]1806    '''
1807
1808    u_i = + u_e * gcos + v_n * gsin
1809    v_j = - u_e * gsin + v_n * gcos
[6666]1810
[6190]1811    u_i = lbc (u_i, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1812    v_j = lbc (v_j, nperio=nperio, cd_type=cd_type, psgn=-1.0)
[6666]1813
[6190]1814    return u_i, v_j
1815
1816def rot_ij2en ( u_i, v_j, gsin, gcos, nperio, cd_type='T' ) :
[6666]1817    '''Rotates the Repere: Change vector componantes from
[6190]1818    stretched coordinates grid --> geographic grid
[6666]1819
1820    All components are on the same grid (T, U, V or F)
[6190]1821    '''
1822    u_e = + u_i * gcos - v_j * gsin
1823    v_n = + u_i * gsin + v_j * gcos
[6666]1824
1825    u_e = lbc (u_e, nperio=nperio, cd_type=cd_type, psgn=1.0)
1826    v_n = lbc (v_n, nperio=nperio, cd_type=cd_type, psgn=1.0)
1827
[6190]1828    return u_e, v_n
1829
[6666]1830def rot_uv2en ( uo, vo, gsint, gcost, nperio, zdim=None ) :
1831    '''Rotate the Repere: Change vector componantes from
[6190]1832    stretched coordinates grid --> geographic grid
[6666]1833
1834    uo : velocity along i at the U grid point
1835    vo : valocity along j at the V grid point
1836   
1837    Returns east-north components on the T grid point
[6190]1838    '''
[6666]1839    ut = u2t (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1840    vt = v2t (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
[6190]1841
1842    u_e = + ut * gcost - vt * gsint
1843    v_n = + ut * gsint + vt * gcost
1844
1845    u_e = lbc (u_e, nperio=nperio, cd_type='T', psgn=1.0)
1846    v_n = lbc (v_n, nperio=nperio, cd_type='T', psgn=1.0)
[6666]1847
[6190]1848    return u_e, v_n
1849
[6666]1850def rot_uv2enf ( uo, vo, gsinf, gcosf, nperio, zdim=None ) :
1851    '''Rotates the Repere: Change vector componantes from
[6190]1852    stretched coordinates grid --> geographic grid
[6666]1853
1854    uo : velocity along i at the U grid point
1855    vo : valocity along j at the V grid point
1856   
1857    Returns east-north components on the F grid point
[6190]1858    '''
[6666]1859    uf = u2f (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1860    vf = v2f (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
[6190]1861
1862    u_e = + uf * gcosf - vf * gsinf
1863    v_n = + uf * gsinf + vf * gcosf
1864
1865    u_e = lbc (u_e, nperio=nperio, cd_type='F', psgn= 1.0)
1866    v_n = lbc (v_n, nperio=nperio, cd_type='F', psgn= 1.0)
[6666]1867
[6190]1868    return u_e, v_n
[6245]1869
[6666]1870def u2t (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1871    '''Interpolates an array from U grid to T grid (i-mean)
1872    '''
[6245]1873    mmath = __mmath__ (utab)
1874    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
[6666]1875    #lperio, aperio = lbc_diag (nperio)
[6245]1876    utab_0 = lbc_add (utab_0, nperio=nperio, cd_type='U', psgn=psgn)
[6666]1877    ax, ix = __find_axis__ (utab_0, 'x')
1878    az     = __find_axis__ (utab_0, 'z')[0]
1879
1880    if ax :
1881        if action == 'ave' :
1882            ttab = 0.5 *      (utab_0 + np.roll (utab_0, axis=ix, shift=1))
1883        if action == 'min' :
1884            ttab = np.minimum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1885        if action == 'max' :
1886            ttab = np.maximum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1887        if action == 'mult':
1888            ttab =             utab_0 * np.roll (utab_0, axis=ix, shift=1)
1889        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1890    else :
1891        ttab = lbc_del (utab_0, nperio=nperio, cd_type='T', psgn=psgn)
1892
[6245]1893    if mmath == xr :
[6666]1894        if ax :
[6190]1895            ttab = ttab.assign_coords({ax:np.arange (ttab.shape[ix])+1.})
[6666]1896        if zdim and az :
1897            if az != zdim :
1898                ttab = ttab.rename( {az:zdim})
[6190]1899    return ttab
1900
[6666]1901def v2t (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1902    '''Interpolates an array from V grid to T grid (j-mean)
1903    '''
[6245]1904    mmath = __mmath__ (vtab)
[6666]1905    #lperio, aperio = lbc_diag (nperio)
[6245]1906    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1907    vtab_0 = lbc_add (vtab_0, nperio=nperio, cd_type='V', psgn=psgn)
[6666]1908    ay, jy = __find_axis__ (vtab_0, 'y')
1909    az     = __find_axis__ (vtab_0, 'z')[0]
1910    if ay :
1911        if action == 'ave'  :
1912            ttab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=jy, shift=1))
1913        if action == 'min'  :
1914            ttab = np.minimum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1915        if action == 'max'  :
1916            ttab = np.maximum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1917        if action == 'mult' :
1918            ttab =             vtab_0 * np.roll (vtab_0, axis=jy, shift=1)
1919        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1920    else :
1921        ttab = lbc_del (vtab_0, nperio=nperio, cd_type='T', psgn=psgn)
1922
[6245]1923    if mmath == xr :
[6666]1924        if ay :
1925            ttab = ttab.assign_coords({ay:np.arange(ttab.shape[jy])+1.})
1926        if zdim and az :
1927            if az != zdim :
1928                ttab = ttab.rename( {az:zdim})
[6190]1929    return ttab
1930
[6666]1931def f2t (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1932    '''Interpolates an array from F grid to T grid (i- and j- means)
1933    '''
[6245]1934    mmath = __mmath__ (ftab)
1935    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1936    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
[6666]1937    ttab = v2t (f2v (ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action),
1938                     nperio=nperio, psgn=psgn, zdim=zdim, action=action)
[6245]1939    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
[6190]1940
[6666]1941def t2u (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1942    '''Interpolates an array from T grid to U grid (i-mean)
1943    '''
[6245]1944    mmath = __mmath__ (ttab)
1945    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1946    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
[6666]1947    ax, ix = __find_axis__ (ttab_0, 'x')[0]
1948    az     = __find_axis__ (ttab_0, 'z')
1949    if ix :
1950        if action == 'ave'  :
1951            utab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=ix, shift=-1))
1952        if action == 'min'  :
1953            utab = np.minimum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1954        if action == 'max'  :
1955            utab = np.maximum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1956        if action == 'mult' :
1957            utab =             ttab_0 * np.roll (ttab_0, axis=ix, shift=-1)
1958        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
1959    else :
1960        utab = lbc_del (ttab_0, nperio=nperio, cd_type='U', psgn=psgn)
[6190]1961
[6666]1962    if mmath == xr :
1963        if ax :
[6190]1964            utab = ttab.assign_coords({ax:np.arange(utab.shape[ix])+1.})
[6666]1965        if zdim and az :
1966            if az != zdim :
1967                utab = utab.rename( {az:zdim})
[6190]1968    return utab
1969
[6666]1970def t2v (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1971    '''Interpolates an array from T grid to V grid (j-mean)
1972    '''
[6245]1973    mmath = __mmath__ (ttab)
1974    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1975    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
[6666]1976    ay, jy = __find_axis__ (ttab_0, 'y')
1977    az     = __find_axis__ (ttab_0, 'z')[0]
1978    if jy :
1979        if action == 'ave'  :
1980            vtab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=jy, shift=-1))
1981        if action == 'min'  :
1982            vtab = np.minimum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1983        if action == 'max'  :
1984            vtab = np.maximum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1985        if action == 'mult' :
1986            vtab =             ttab_0 * np.roll (ttab_0, axis=jy, shift=-1)
1987        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
1988    else :
1989        vtab = lbc_del (ttab_0, nperio=nperio, cd_type='V', psgn=psgn)
[6190]1990
[6245]1991    if mmath == xr :
[6666]1992        if ay :
1993            vtab = vtab.assign_coords({ay:np.arange(vtab.shape[jy])+1.})
1994        if zdim and az :
1995            if az != zdim :
1996                vtab = vtab.rename( {az:zdim})
[6190]1997    return vtab
1998
[6666]1999def v2f (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
2000    '''Interpolates an array from V grid to F grid (i-mean)
2001    '''
[6245]2002    mmath = __mmath__ (vtab)
2003    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
2004    vtab_0 = lbc_add (vtab_0 , nperio=nperio, cd_type='V', psgn=psgn)
[6666]2005    ax, ix = __find_axis__ (vtab_0, 'x')
2006    az     = __find_axis__ (vtab_0, 'z')[0]
2007    if ix :
2008        if action == 'ave'  :
2009            ftab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=ix, shift=-1))
2010        if action == 'min'  :
2011            ftab = np.minimum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
2012        if action == 'max'  :
2013            ftab = np.maximum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
2014        if action == 'mult' :
2015            ftab =             vtab_0 * np.roll (vtab_0, axis=ix, shift=-1)
2016        ftab = lbc_del (ftab  , nperio=nperio, cd_type='F', psgn=psgn)
2017    else :
2018        ftab = lbc_del (vtab_0, nperio=nperio, cd_type='F', psgn=psgn)
2019
[6245]2020    if mmath == xr :
[6666]2021        if ax :
[6190]2022            ftab = ftab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
[6666]2023        if zdim and az :
2024            if az != zdim :
2025                ftab = ftab.rename( {az:zdim})
[6245]2026    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
[6190]2027
[6666]2028def u2f (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
2029    '''Interpolates an array from U grid to F grid i-mean)
2030    '''
[6245]2031    mmath = __mmath__ (utab)
2032    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
2033    utab_0 = lbc_add (utab_0 , nperio=nperio, cd_type='U', psgn=psgn)
[6666]2034    ay, jy = __find_axis__ (utab_0, 'y')
2035    az     = __find_axis__ (utab_0, 'z')[0]
2036    if jy :
2037        if action == 'ave'  :
2038            ftab = 0.5 *      (utab_0 + np.roll (utab_0, axis=jy, shift=-1))
2039        if action == 'min'  :
2040            ftab = np.minimum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
2041        if action == 'max'  :
2042            ftab = np.maximum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
2043        if action == 'mult' :
2044            ftab =             utab_0 * np.roll (utab_0, axis=jy, shift=-1)
2045        ftab = lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
2046    else :
2047        ftab = lbc_del (utab_0, nperio=nperio, cd_type='F', psgn=psgn)
[6190]2048
[6245]2049    if mmath == xr :
[6666]2050        if ay :
2051            ftab = ftab.assign_coords({'y':np.arange(ftab.shape[jy])+1.})
2052        if zdim and az :
2053            if az != zdim :
2054                ftab = ftab.rename( {az:zdim})
[6190]2055    return ftab
2056
[6666]2057def t2f (ttab, nperio=None, psgn=1.0, zdim=None, action='mean') :
2058    '''Interpolates an array on T grid to F grid (i- and j- means)
2059    '''
[6245]2060    mmath = __mmath__ (ttab)
2061    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
2062    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
[6666]2063    ftab = t2u (u2f (ttab, nperio=nperio, psgn=psgn, zdim=zdim, action=action),
2064                     nperio=nperio, psgn=psgn, zdim=zdim, action=action)
2065
[6245]2066    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
[6190]2067
[6666]2068def f2u (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
2069    '''Interpolates an array on F grid to U grid (j-mean)
2070    '''
[6245]2071    mmath = __mmath__ (ftab)
2072    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
2073    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
[6666]2074    ay, jy = __find_axis__ (ftab_0, 'y')
2075    az     = __find_axis__ (ftab_0, 'z')[0]
2076    if jy :
2077        if action == 'ave'  :
2078            utab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=jy, shift=-1))
2079        if action == 'min'  :
2080            utab = np.minimum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
2081        if action == 'max'  :
2082            utab = np.maximum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
2083        if action == 'mult' :
2084            utab =             ftab_0 * np.roll (ftab_0, axis=jy, shift=-1)
2085        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
2086    else :
2087        utab = lbc_del (ftab_0, nperio=nperio, cd_type='U', psgn=psgn)
[6245]2088
2089    if mmath == xr :
[6666]2090        utab = utab.assign_coords({ay:np.arange(ftab.shape[jy])+1.})
2091        if zdim and az and az != zdim :
2092            utab = utab.rename( {az:zdim})
[6190]2093    return utab
2094
[6666]2095def f2v (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
2096    '''Interpolates an array from F grid to V grid (i-mean)
2097    '''
[6245]2098    mmath = __mmath__ (ftab)
2099    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
2100    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
[6666]2101    ax, ix = __find_axis__ (ftab_0, 'x')
2102    az     = __find_axis__ (ftab_0, 'z')[0]
2103    if ix :
2104        if action == 'ave'  :
2105            vtab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=ix, shift=-1))
2106        if action == 'min'  :
2107            vtab = np.minimum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
2108        if action == 'max'  :
2109            vtab = np.maximum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
2110        if action == 'mult' :
2111            vtab =             ftab_0 * np.roll (ftab_0, axis=ix, shift=-1)
2112        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
2113    else :
2114        vtab = lbc_del (ftab_0, nperio=nperio, cd_type='V', psgn=psgn)
[6190]2115
[6245]2116    if mmath == xr :
[6190]2117        vtab = vtab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
[6666]2118        if zdim and az :
2119            if az != zdim :
2120                vtab = vtab.rename( {az:zdim})
[6190]2121    return vtab
2122
[6666]2123def w2t (wtab, zcoord=None, zdim=None, sval=np.nan) :
2124    '''Interpolates an array on W grid to T grid (k-mean)
2125
[6190]2126    sval is the bottom value
2127    '''
[6245]2128    mmath = __mmath__ (wtab)
2129    wtab_0 = mmath.where ( np.isnan(wtab), 0., wtab)
[6190]2130
[6666]2131    az, kz = __find_axis__ (wtab_0, 'z')
2132
2133    if kz :
2134        ttab = 0.5 * ( wtab_0 + np.roll (wtab_0, axis=kz, shift=-1) )
2135    else :
2136        ttab = wtab_0
2137
[6245]2138    if mmath == xr :
[6666]2139        ttab[{az:kz}] = sval
2140        if zdim and az :
2141            if az != zdim :
2142                ttab = ttab.rename ( {az:zdim} )
2143        if zcoord is not None :
2144            ttab = ttab.assign_coords ( {zdim:zcoord} )
[6190]2145    else :
2146        ttab[..., -1, :, :] = sval
2147
2148    return ttab
2149
[6666]2150def t2w (ttab, zcoord=None, zdim=None, sval=np.nan, extrap_surf=False) :
2151    '''Interpolates an array from T grid to W grid (k-mean)
2152
[6190]2153    sval is the surface value
2154    if extrap_surf==True, surface value is taken from 1st level value.
2155    '''
[6245]2156    mmath = __mmath__ (ttab)
2157    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
[6666]2158    az, kz = __find_axis__ (ttab_0, 'z')
2159    wtab = 0.5 * ( ttab_0 + np.roll (ttab_0, axis=kz, shift=1) )
[6190]2160
[6245]2161    if mmath == xr :
[6666]2162        if extrap_surf :
2163            wtab[{az:0}] = ttab[{az:0}]
2164        else           :
2165            wtab[{az:0}] = sval
2166    else :
2167        if extrap_surf :
2168            wtab[..., 0, :, :] = ttab[..., 0, :, :]
2169        else           :
2170            wtab[..., 0, :, :] = sval
[6190]2171
[6245]2172    if mmath == xr :
[6666]2173        if zdim and az and az != zdim :
2174            wtab = wtab.rename ( {az:zdim})
2175        if zcoord is not None :
2176            wtab = wtab.assign_coords ( {zdim:zcoord})
2177        else :
2178            wtab = wtab.assign_coords ( {zdim:np.arange(ttab.shape[kz])+1.} )
[6190]2179    return wtab
2180
[6666]2181def fill (ptab, nperio, cd_type='T', npass=1, sval=np.nan) :
2182    '''Fills np.nan values with mean of neighbours
2183
[6190]2184    Inputs :
2185       ptab : input field to fill
2186       nperio, cd_type : periodicity characteristics
[6666]2187    '''
[6190]2188
[6245]2189    mmath = __mmath__ (ptab)
[6190]2190
[6666]2191    do_perio  = False
2192    lperio    = nperio
[6245]2193    if nperio == 4.2 :
[6666]2194        do_perio, lperio = True, 4
[6245]2195    if nperio == 6.2 :
[6666]2196        do_perio, lperio = True, 6
2197
2198    if do_perio :
2199        ztab = lbc_add (ptab, nperio=nperio)
2200    else :
[6245]2201        ztab = ptab
[6666]2202
2203    if np.isnan (sval) :
[6245]2204        ztab   = mmath.where (np.isnan(ztab), np.nan, ztab)
2205    else :
2206        ztab   = mmath.where (ztab==sval    , np.nan, ztab)
[6666]2207
2208    for _ in np.arange (npass) :
[6245]2209        zmask = mmath.where ( np.isnan(ztab), 0., 1.   )
2210        ztab0 = mmath.where ( np.isnan(ztab), 0., ztab )
[6190]2211        # Compte du nombre de voisins
[6245]2212        zcount = 1./6. * ( zmask \
[6190]2213          + np.roll(zmask, shift=1, axis=-1) + np.roll(zmask, shift=-1, axis=-1) \
2214          + np.roll(zmask, shift=1, axis=-2) + np.roll(zmask, shift=-1, axis=-2) \
2215          + 0.5 * ( \
2216                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift= 1, axis=-1) \
2217                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift= 1, axis=-1) \
2218                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift=-1, axis=-1) \
[6245]2219                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift=-1, axis=-1) ) )
[6190]2220
[6245]2221        znew =1./6. * ( ztab0 \
2222           + np.roll(ztab0, shift=1, axis=-1) + np.roll(ztab0, shift=-1, axis=-1) \
2223           + np.roll(ztab0, shift=1, axis=-2) + np.roll(ztab0, shift=-1, axis=-2) \
2224           + 0.5 * ( \
2225                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift= 1, axis=-1) \
2226                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift= 1, axis=-1) \
2227                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift=-1, axis=-1) \
2228                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift=-1, axis=-1) ) )
[6190]2229
[6245]2230        zcount = lbc (zcount, nperio=lperio, cd_type=cd_type)
2231        znew   = lbc (znew  , nperio=lperio, cd_type=cd_type)
[6666]2232
[6245]2233        ztab = mmath.where (np.logical_and (zmask==0., zcount>0), znew/zcount, ztab)
[6190]2234
[6245]2235    ztab = mmath.where (zcount==0, sval, ztab)
[6666]2236    if do_perio :
2237        ztab = lbc_del (ztab, nperio=lperio)
[6190]2238
2239    return ztab
2240
2241def correct_uv (u, v, lat) :
2242    '''
[6666]2243    Corrects a Cartopy bug in orthographic projection
[6190]2244
2245    See https://github.com/SciTools/cartopy/issues/1179
2246
2247    The correction is needed with cartopy <= 0.20
2248    It seems that version 0.21 will correct the bug (https://github.com/SciTools/cartopy/pull/1926)
[6666]2249    Later note : the bug is still present in Cartopy 0.22
[6190]2250
2251    Inputs :
[6666]2252       u, v : eastward/northward components
[6190]2253       lat  : latitude of the point (degrees north)
2254
[6666]2255    Outputs :
[6245]2256       modified eastward/nothward components to have correct polar projections in cartopy
[6190]2257    '''
2258    uv = np.sqrt (u*u + v*v)           # Original modulus
2259    zu = u
[6666]2260    zv = v * np.cos (RAD*lat)
[6190]2261    zz = np.sqrt ( zu*zu + zv*zv )     # Corrected modulus
[6666]2262    uc = zu*uv/zz
2263    vc = zv*uv/zz      # Final corrected values
[6190]2264    return uc, vc
2265
[6666]2266def norm_uv (u, v) :
2267    '''Returns norm of a 2 components vector
[6245]2268    '''
[6666]2269    return np.sqrt (u*u + v*v)
2270
2271def normalize_uv (u, v) :
2272    '''Normalizes 2 components vector
[6245]2273    '''
[6666]2274    uv = norm_uv (u, v)
2275    return u/uv, v/uv
2276
2277def msf (vv, e1v_e3v, plat1d, depthw) :
2278    '''Computes the meridonal stream function
2279
2280    vv : meridional_velocity
2281    e1v_e3v : prodcut of scale factors e1v*e3v
2282    '''
2283
2284    v_e1v_e3v = vv * e1v_e3v
2285    v_e1v_e3v.attrs = vv.attrs
2286
2287    ax = __find_axis__ (v_e1v_e3v, 'x')[0]
2288    az = __find_axis__ (v_e1v_e3v, 'z')[0]
2289    if az == 'olevel' :
2290        new_az = 'olevel'
2291    else :
2292        new_az = 'depthw'
2293
2294    zomsf = -v_e1v_e3v.cumsum ( dim=az, keep_attrs=True).sum (dim=ax, keep_attrs=True)*1.E-6
2295    zomsf = zomsf - zomsf.isel ( { az:-1} )
2296
2297    ay = __find_axis__ (zomsf, 'y' )[0]
2298    zomsf = zomsf.assign_coords ( {az:depthw.values, ay:plat1d.values})
2299
2300    zomsf = zomsf.rename ( {ay:'lat'})
2301    if az != new_az :
2302        zomsf = zomsf.rename ( {az:new_az} )
2303    zomsf.attrs['standard_name'] = 'Meridional stream function'
[6245]2304    zomsf.attrs['long_name'] = 'Meridional stream function'
[6666]2305    zomsf.attrs['units'] = 'Sv'
2306    zomsf[new_az].attrs  = depthw.attrs
2307    zomsf.lat.attrs=plat1d.attrs
[6245]2308
2309    return zomsf
2310
[6666]2311def bsf (uu, e2u_e3u, mask, nperio=None, bsf0=None ) :
2312    '''Computes the barotropic stream function
[6245]2313
[6666]2314    uu      : zonal_velocity
2315    e2u_e3u : product of scales factor e2u*e3u
2316    bsf0    : the point with bsf=0
2317    (ex: bsf0={'x':3, 'y':120} for orca2,
2318         bsf0={'x':5, 'y':300} for eORCA1
[6245]2319    '''
[6666]2320    u_e2u_e3u       = uu * e2u_e3u
2321    u_e2u_e3u.attrs = uu.attrs
[6245]2322
[6666]2323    ay = __find_axis__ (u_e2u_e3u, 'y')[0]
2324    az = __find_axis__ (u_e2u_e3u, 'z')[0]
[6245]2325
[6666]2326    zbsf = -u_e2u_e3u.cumsum ( dim=ay, keep_attrs=True )
2327    zbsf = zbsf.sum (dim=az, keep_attrs=True)*1.E-6
2328
2329    if bsf0 :
2330        zbsf = zbsf - zbsf.isel (bsf0)
2331
2332    zbsf = zbsf.where (mask !=0, np.nan)
2333    zbsf.attrs.update (uu.attrs)
2334    zbsf.attrs['standard_name'] = 'barotropic_stream_function'
2335    zbsf.attrs['long_name']     = 'Barotropic stream function'
2336    zbsf.attrs['units']         = 'Sv'
2337    zbsf = lbc (zbsf, nperio=nperio, cd_type='F')
2338
2339    return zbsf
2340
2341if f90nml :
2342    def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
2343        '''Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
2344
2345        ref : file with reference namelist, or a f90nml.namelist.Namelist object
2346        cfg : file with config namelist, or a f90nml.namelist.Namelist object
2347        At least one namelist neaded
2348
2349        out:
[6245]2350        'dict' to return a dictonnary
2351        'xr'   to return an xarray dataset
[6666]2352        flat : only for dict output. Output a flat dictionary with all values.
[6245]2353
[6666]2354        '''
2355        if ref :
2356            if isinstance (ref, str) :
2357                nml_ref = f90nml.read (ref)
2358            if isinstance (ref, f90nml.namelist.Namelist) :
2359                nml_ref = ref
[6245]2360
[6666]2361        if cfg :
2362            if isinstance (cfg, str) :
2363                nml_cfg = f90nml.read (cfg)
2364            if isinstance (cfg, f90nml.namelist.Namelist) :
2365                nml_cfg = cfg
[6245]2366
[6666]2367        if out == 'dict' :
2368            dict_namelist = {}
2369        if out == 'xr'   :
2370            xr_namelist = xr.Dataset ()
[6245]2371
[6666]2372        list_nml     = []
2373        list_comment = []
[6245]2374
[6666]2375        if ref :
2376            list_nml.append (nml_ref)
2377            list_comment.append ('ref')
2378        if cfg :
2379            list_nml.append (nml_cfg)
2380            list_comment.append ('cfg')
[6245]2381
[6666]2382        for nml, comment in zip (list_nml, list_comment) :
2383            if verbose :
2384                print (comment)
2385            if flat and out =='dict' :
2386                for nam in nml.keys () :
2387                    if verbose :
2388                        print (nam)
2389                    for value in nml[nam] :
2390                        if out == 'dict' :
2391                            dict_namelist[value] = nml[nam][value]
2392                        if verbose :
2393                            print (nam, ':', value, ':', nml[nam][value])
2394            else :
2395                for nam in nml.keys () :
2396                    if verbose :
2397                        print (nam)
2398                    if out == 'dict' :
2399                        if nam not in dict_namelist.keys () :
2400                            dict_namelist[nam] = {}
2401                    for value in nml[nam] :
2402                        if out == 'dict' :
2403                            dict_namelist[nam][value] = nml[nam][value]
2404                        if out == 'xr'   :
2405                            xr_namelist[value] = nml[nam][value]
2406                        if verbose :
2407                            print (nam, ':', value, ':', nml[nam][value])
[6245]2408
[6666]2409        if out == 'dict' :
2410            return dict_namelist
2411        if out == 'xr'   :
2412            return xr_namelist
2413else :
2414     def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
2415        '''Shadow version of namelist read, when f90nm module was not found
2416
2417        namelist_read :
2418        Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
2419        '''
2420        print ( 'Error : module f90nml not found' )
2421        print ( 'Cannot call namelist_read' )
2422        print ( 'Call parameters where : ')
2423        print ( f'{err=} {ref=} {cfg=} {out=} {flat=} {verbose=}' )
2424
[6245]2425def fill_closed_seas (imask, nperio=None,  cd_type='T') :
2426    '''Fill closed seas with image processing library
[6666]2427
[6245]2428    imask : mask, 1 on ocean, 0 on land
2429    '''
2430    from scipy import ndimage
2431
2432    imask_filled = ndimage.binary_fill_holes ( lbc (imask, nperio=nperio, cd_type=cd_type))
2433    imask_filled = lbc ( imask_filled, nperio=nperio, cd_type=cd_type)
2434
2435    return imask_filled
2436
[6666]2437# ======================================================
2438# Sea water state function parameters from NEMO code
2439
2440RDELTAS = 32.
2441R1_S0   = 0.875/35.16504
2442R1_T0   = 1./40.
2443R1_Z0   = 1.e-4
2444
2445EOS000 =  8.0189615746e+02
2446EOS100 =  8.6672408165e+02
2447EOS200 = -1.7864682637e+03
2448EOS300 =  2.0375295546e+03
2449EOS400 = -1.2849161071e+03
2450EOS500 =  4.3227585684e+02
2451EOS600 = -6.0579916612e+01
2452EOS010 =  2.6010145068e+01
2453EOS110 = -6.5281885265e+01
2454EOS210 =  8.1770425108e+01
2455EOS310 = -5.6888046321e+01
2456EOS410 =  1.7681814114e+01
2457EOS510 = -1.9193502195
2458EOS020 = -3.7074170417e+01
2459EOS120 =  6.1548258127e+01
2460EOS220 = -6.0362551501e+01
2461EOS320 =  2.9130021253e+01
2462EOS420 = -5.4723692739
2463EOS030 =  2.1661789529e+01
2464EOS130 = -3.3449108469e+01
2465EOS230 =  1.9717078466e+01
2466EOS330 = -3.1742946532
2467EOS040 = -8.3627885467
2468EOS140 =  1.1311538584e+01
2469EOS240 = -5.3563304045
2470EOS050 =  5.4048723791e-01
2471EOS150 =  4.8169980163e-01
2472EOS060 = -1.9083568888e-01
2473EOS001 =  1.9681925209e+01
2474EOS101 = -4.2549998214e+01
2475EOS201 =  5.0774768218e+01
2476EOS301 = -3.0938076334e+01
2477EOS401 =  6.6051753097
2478EOS011 = -1.3336301113e+01
2479EOS111 = -4.4870114575
2480EOS211 =  5.0042598061
2481EOS311 = -6.5399043664e-01
2482EOS021 =  6.7080479603
2483EOS121 =  3.5063081279
2484EOS221 = -1.8795372996
2485EOS031 = -2.4649669534
2486EOS131 = -5.5077101279e-01
2487EOS041 =  5.5927935970e-01
2488EOS002 =  2.0660924175
2489EOS102 = -4.9527603989
2490EOS202 =  2.5019633244
2491EOS012 =  2.0564311499
2492EOS112 = -2.1311365518e-01
2493EOS022 = -1.2419983026
2494EOS003 = -2.3342758797e-02
2495EOS103 = -1.8507636718e-02
2496EOS013 =  3.7969820455e-01
2497
2498def rhop ( ptemp, psal ) :
2499    '''Returns potential density referenced to surface
2500
2501    Computation from NEMO code
2502    '''
2503    zt      = ptemp * R1_T0                                  # Temperature (°C)
2504    zs      = np.sqrt ( np.abs( psal + RDELTAS ) * R1_S0 )   # Square root of salinity (PSS)
2505    #
2506    prhop = (
2507      (((((EOS060*zt
2508         + EOS150*zs     + EOS050)*zt
2509         + (EOS240*zs    + EOS140)*zs + EOS040)*zt
2510         + ((EOS330*zs   + EOS230)*zs + EOS130)*zs + EOS030)*zt
2511         + (((EOS420*zs  + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt
2512         + ((((EOS510*zs + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt
2513         + (((((EOS600*zs+ EOS500)*zs + EOS400)*zs + EOS300)*zs + EOS200)*zs + EOS100)*zs + EOS000 )
2514    #
2515    return prhop
2516
2517def rho ( pdep, ptemp, psal ) :
2518    '''Returns in situ density
2519
2520    Computation from NEMO code
2521    '''
2522    zh      = pdep  * R1_Z0                                  # Depth (m)
2523    zt      = ptemp * R1_T0                                  # Temperature (°C)
2524    zs      = np.sqrt ( np.abs( psal + RDELTAS ) * R1_S0 )   # Square root salinity (PSS)
2525    #
2526    zn3 = EOS013*zt + EOS103*zs+EOS003
2527    #
2528    zn2 = (EOS022*zt + EOS112*zs+EOS012)*zt + (EOS202*zs+EOS102)*zs+EOS002
2529    #
2530    zn1 = (
2531      (((EOS041*zt
2532       + EOS131*zs   + EOS031)*zt
2533       + (EOS221*zs  + EOS121)*zs + EOS021)*zt
2534       + ((EOS311*zs  + EOS211)*zs + EOS111)*zs + EOS011)*zt
2535       + (((EOS401*zs + EOS301)*zs + EOS201)*zs + EOS101)*zs + EOS001 )
2536    #
2537    zn0 = (
2538      (((((EOS060*zt
2539         + EOS150*zs      + EOS050)*zt
2540         + (EOS240*zs     + EOS140)*zs + EOS040)*zt
2541         + ((EOS330*zs    + EOS230)*zs + EOS130)*zs + EOS030)*zt
2542         + (((EOS420*zs   + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt
2543         + ((((EOS510*zs  + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt
2544         + (((((EOS600*zs + EOS500)*zs + EOS400)*zs + EOS300)*zs +
2545                                       EOS200)*zs + EOS100)*zs + EOS000 )
2546    #
2547    prho  = ( ( zn3 * zh + zn2 ) * zh + zn1 ) * zh + zn0
2548    #
2549    return prho
2550
[6064]2551## ===========================================================================
2552##
2553##                               That's all folk's !!!
2554##
2555## ===========================================================================
[6190]2556
[6666]2557# def __is_orca_north_fold__ ( Xtest, cname_long='T' ) :
2558#     '''
2559#     Ported (pirated !!?) from Sosie
[6190]2560
[6666]2561#     Tell if there is a 2/point band overlaping folding at the north pole typical of the ORCA grid
[6190]2562
[6666]2563#     0 => not an orca grid (or unknown one)
2564#     4 => North fold T-point pivot (ex: ORCA2)
2565#     6 => North fold F-point pivot (ex: ORCA1)
[6190]2566
[6666]2567#     We need all this 'cname_long' stuff because with our method, there is a
2568#     confusion between "Grid_U with T-fold" and "Grid_V with F-fold"
2569#     => so knowing the name of the longitude array (as in namelist, and hence as
2570#     in netcdf file) might help taking the righ decision !!! UGLY!!!
2571#     => not implemented yet
2572#     '''
[6190]2573
[6666]2574#     ifld_nord =  0 ; cgrd_type = 'X'
2575#     ny, nx = Xtest.shape[-2:]
[6190]2576
[6666]2577#     if ny > 3 : # (case if called with a 1D array, ignoring...)
2578#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2579#           ifld_nord = 4 ; cgrd_type = 'T' # T-pivot, grid_T
[6190]2580
[6666]2581#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2582#             if cnlon == 'U' : ifld_nord = 4 ;  cgrd_type = 'U' # T-pivot, grid_T
2583#                 ## LOLO: PROBLEM == 6, V !!!
[6190]2584
[6666]2585#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2586#             ifld_nord = 4 ; cgrd_type = 'V' # T-pivot, grid_V
[6190]2587
[6666]2588#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-2, nx-1-1:nx-nx//2:-1] ).sum() == 0. :
2589#             ifld_nord = 6 ; cgrd_type = 'T'# F-pivot, grid_T
[6190]2590
[6666]2591#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-1, nx-1:nx-nx//2-1:-1] ).sum() == 0. :
2592#             ifld_nord = 6 ;  cgrd_type = 'U' # F-pivot, grid_U
[6190]2593
[6666]2594#         if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2595#             if cnlon == 'V' : ifld_nord = 6 ; cgrd_type = 'V' # F-pivot, grid_V
2596#                 ## LOLO: PROBLEM == 4, U !!!
2597
2598#     return ifld_nord, cgrd_type
Note: See TracBrowser for help on using the repository browser.