source: TOOLS/WATER_BUDGET/nemo.py @ 6652

Last change on this file since 6652 was 6647, checked in by omamce, 8 months ago

O.M. :

New version of WATER_BUDGET

  • Conservation in NEMO are OK
  • Conservation in Sechiba are OK, for both ICO and latlon grids
  • Problems in atmosphere, LIC surface and ocean/atmosphere coherence
  • Property svn:keywords set to Date Revision HeadURL Author
File size: 86.0 KB
Line 
1# -*- coding: utf-8 -*-
2## ===========================================================================
3##
4##  This software is governed by the CeCILL  license under French law and
5##  abiding by the rules of distribution of free software.  You can  use,
6##  modify and/ or redistribute the software under the terms of the CeCILL
7##  license as circulated by CEA, CNRS and INRIA at the following URL
8##  "http://www.cecill.info".
9##
10##  Warning, to install, configure, run, use any of Olivier Marti's
11##  software or to read the associated documentation you'll need at least
12##  one (1) brain in a reasonably working order. Lack of this implement
13##  will void any warranties (either express or implied).
14##  O. Marti assumes no responsability for errors, omissions,
15##  data loss, or any other consequences caused directly or indirectly by
16##  the usage of his software by incorrectly or partially configured
17##  personal.
18##
19## ===========================================================================
20'''
21Utilities to plot NEMO ORCA fields
22Periodicity and other stuff
23
24- Lots of tests for xarray object
25- Not much testerd for numpy objects
26
27olivier.marti@lsce.ipsl.fr
28
29## SVN information
30__Author__   = "$Author$"
31__Date__     = "$Date$"
32__Revision__ = "$Revision$"
33__Id__       = "$Id: $"
34__HeadURL    = "$HeadURL$"
35'''
36
37import numpy as np
38try    : import xarray as xr
39except ImportError : pass
40
41#try    : import f90nml
42#except : pass
43
44#try : from sklearn.impute import SimpleImputer
45#except : pass
46
47rpi = np.pi ; rad = np.deg2rad (1.0) ; dar = np.rad2deg (1.0)
48
49nperio_valid_range = [0, 1, 4, 4.2, 5, 6, 6.2]
50
51rday   = 24.*60.*60.     # Day length [s]
52rsiyea = 365.25 * rday * 2. * rpi / 6.283076 # Sideral year length [s]
53rsiday = rday / (1. + rday / rsiyea)
54raamo  =  12.        # Number of months in one year
55rjjhh  =  24.        # Number of hours in one day
56rhhmm  =  60.        # Number of minutes in one hour
57rmmss  =  60.        # Number of seconds in one minute
58omega  = 2. * rpi / rsiday # Earth rotation parameter [s-1]
59ra     = 6371229.    # Earth radius [m]
60grav   = 9.80665     # Gravity [m/s2]
61repsi  = np.finfo (1.0).eps
62
63## Default names of dimensions
64dim_names = {'x':'xx', 'y':'yy', 'z':'olevel', 't':None}
65
66## All possibles name of dimensions in Nemo files
67xName = [ 'x', 'X', 'X1', 'xx', 'XX', 'x_grid_T', 'x_grid_U', 'x_grid_V', 'x_grid_F', 'x_grid_W', 'lon', 'nav_lon', 'longitude', 'X1', 'x_c', 'x_f', ]
68yName = [ 'y', 'Y', 'Y1', 'yy', 'YY', 'y_grid_T', 'y_grid_U', 'y_grid_V', 'y_grid_F', 'y_grid_W', 'lat', 'nav_lat', 'latitude' , 'Y1', 'y_c', 'y_f', ]
69zName = [ 'z', 'Z', 'Z1', 'zz', 'ZZ', 'depth', 'tdepth', 'udepth', 'vdepth', 'wdepth', 'fdepth', 'deptht', 'depthu', 'depthv', 'depthw', 'depthf', 'olevel', 'z_c', 'z_f', ]
70tName = [ 't', 'T', 'tt', 'TT', 'time', 'time_counter', 'time_centered', ]
71
72## All possibles name of units of dimensions in Nemo files
73xUnit = [ 'degrees_east', ]
74yUnit = [ 'degrees_north', ]
75zUnit = [ 'm', 'meter', ]
76tUnit = [ 'second', 'minute', 'hour', 'day', 'month', 'year', ]
77
78## All possibles size of dimensions in Orca files
79xLength = [ 180, 182, 360, 362 ]
80yLength = [ 148, 149, 331, 332 ]
81zLength = [31, 75]
82
83## ===========================================================================
84def __mmath__ (tab, default=None) :
85    '''
86    Determines the type of tab : xarray or numpy object ?
87    '''
88    mmath = default
89    try    :
90        if type (tab) == xr.core.dataarray.DataArray : mmath = xr
91    except : pass
92
93    try    :
94        if type (tab) == np.ndarray : mmath = np
95    except : pass
96           
97    return mmath
98
99def __guessNperio__ (jpj, jpi, nperio=None, out='nperio') :
100    '''
101    Tries to guess the value of nperio (periodicity parameter. See NEMO documentation for details)
102   
103    Inputs
104    jpj    : number of latitudes
105    jpi    : number of longitudes
106    nperio : periodicity parameter
107    '''
108    if nperio == None :
109        nperio = __guessConfig__ (jpj, jpi, nperio=None, out='nperio')
110   
111    return nperio
112
113def __guessConfig__ (jpj, jpi, nperio=None, config=None, out='nperio') :
114    '''
115    Tries to guess the value of nperio (periodicity parameter. See NEMO documentation for details)
116
117    Inputs
118    jpj    : number of latitudes
119    jpi    : number of longitudes
120    nperio : periodicity parameter
121    '''
122    print ( jpi, jpj)
123    if nperio == None :
124        ## Values for NEMO version < 4.2
125        if (jpj ==  149 and jpi == 182) or (jpj == None and jpi == 182) or (jpj == 149 or jpi == None) :
126            config = 'ORCA2.3'
127            nperio = 4   # ORCA2. We choose legacy orca2.
128            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'T'
129        if (jpj == 332 and jpi == 362) or (jpj == None and jpi == 362) or (jpj ==  332 and jpi == None) : # eORCA1.
130            config = 'eORCA1.2'
131            nperio = 6 
132            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
133        if jpi == 1442 :  # ORCA025.
134            config = 'ORCA025'
135            nperio = 6 
136            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
137        if jpj ==  294 : # ORCA1
138            config = 'ORCA1'
139            nperio = 6
140            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
141           
142        ## Values for NEMO version >= 4.2. No more halo points
143        if (jpj == 148 and jpi == 180) or (jpj == None and jpi == 180) or (jpj == 148 and jpi == None) :
144            config = 'ORCA2.4'
145            nperio = 4.2 # ORCA2. We choose legacy orca2.
146            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
147        if (jpj == 331 and jpi == 360) or (jpj == None and jpi == 360) or (jpj == 331 and jpi == None) : # eORCA1.
148            config = 'eORCA1.4'
149            nperio = 6.2
150            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
151        if jpi == 1440 : # ORCA025.
152            config = 'ORCA025'
153            nperio = 6.2
154            Iperio = 1 ; Jperio = 0 ; NFold = 1 ; NFtype = 'F'
155           
156        if nperio == None :
157            raise Exception  ('in nemo module : nperio not found, and cannot by guessed')
158        else :
159            if nperio in nperio_valid_range :
160                print ( f'nperio set as {nperio} (deduced from {jpj=} and {jpi=})' )
161            else : 
162                raise ValueError ( f'nperio set as {nperio} (deduced from {jpi=} and {jpj=}) : nemo.py is not ready for this value' )
163
164    if out == 'nperio' : return nperio
165    if out == 'config' : return config
166    if out == 'perio'  : return Iperio, Jperio, NFold, NFtype
167    if out in ['full', 'all'] : return {'nperio':nperio, 'Iperio':Iperio, 'Jperio':Jperio, 'NFold':NFold, 'NFtype':NFtype}
168       
169def __guessPoint__ (ptab) :
170    '''
171    Tries to guess the grid point (periodicity parameter. See NEMO documentation for details)
172   
173    For array conforments with xgcm requirements
174
175    Inputs
176         ptab : xarray array
177
178    Credits : who is the original author ?
179    '''
180   
181    gP = None
182    mmath = __mmath__ (ptab)
183    if mmath == xr :
184        if 'x_c' in ptab.dims and 'y_c' in ptab.dims                        : gP = 'T'
185        if 'x_f' in ptab.dims and 'y_c' in ptab.dims                        : gP = 'U'
186        if 'x_c' in ptab.dims and 'y_f' in ptab.dims                        : gP = 'V'
187        if 'x_f' in ptab.dims and 'y_f' in ptab.dims                        : gP = 'F'
188        if 'x_c' in ptab.dims and 'y_c' in ptab.dims and 'z_c' in ptab.dims : gP = 'T'
189        if 'x_c' in ptab.dims and 'y_c' in ptab.dims and 'z_f' in ptab.dims : gP = 'W'
190        if 'x_f' in ptab.dims and 'y_c' in ptab.dims and 'z_f' in ptab.dims : gP = 'U'
191        if 'x_c' in ptab.dims and 'y_f' in ptab.dims and 'z_f' in ptab.dims : gP = 'V'
192        if 'x_f' in ptab.dims and 'y_f' in ptab.dims and 'z_f' in ptab.dims : gP = 'F'
193             
194        if gP == None :
195            raise Exception ('in nemo module : cd_type not found, and cannot by guessed')
196        else :
197            print ( f'Grid set as {gP} deduced from dims {ptab.dims}' )
198            return gP
199    else :
200         raise Exception  ('in nemo module : cd_type not found, input is not an xarray data')
201
202def get_shape ( ptab ) :
203    '''
204    Get shape of ptab :
205    shape main contain X, Y, Z or T
206    Y is missing for a latitudinal slice
207    X is missing for on longitudinal slice
208    etc ...
209    '''
210   
211    get_shape = ''
212    ix, ax = __findAxis__ (ptab, 'x')
213    jy, ay = __findAxis__ (ptab, 'y')
214    kz, az = __findAxis__ (ptab, 'z')
215    lt, at = __findAxis__ (ptab, 't')
216    if ax : get_shape = 'X'
217    if ay : get_shape = 'Y' + get_shape
218    if az : get_shape = 'Z' + get_shape
219    if at : get_shape = 'T' + get_shape
220    return get_shape
221     
222def lbc_diag (nperio) :
223    lperio = nperio ; aperio = False
224    if nperio == 4.2 :
225        lperio = 4 ; aperio = True
226    if nperio == 6.2 :
227        lperio = 6 ; aperio = True
228       
229    return lperio, aperio
230
231def __findAxis__ (tab, axis='z') :
232    '''
233    Find order and name of the requested axis
234    '''
235    mmath = __mmath__ (tab)
236    ix = None ; ax = None
237
238    if axis in xName : axName = xName ; unList = xUnit ; Length = xLength
239    if axis in yName : axName = yName ; unList = yUnit ; Length = yLength
240    if axis in zName : axName = zName ; unList = zUnit ; Length = zLength
241    if axis in tName : axName = tName ; unList = tUnit ; Length = None
242   
243    if mmath == xr :
244        for Name in axName :
245            try    :
246                ix = tab.dims.index (Name)
247                ax = Name
248            except : pass
249
250        for i, dim in enumerate (tab.dims) :
251            if 'units' in tab.coords[dim].attrs.keys() :
252                for name in unList :
253                    if name in tab.coords[dim].attrs['units'] :
254                        ix = i ; ax = dim
255    else :
256        #if axis in xName : ix=-1
257        #if axis in yName :
258        #    if len(tab.shape) >= 2 : ix=-2
259        #if axis in zName :
260        #    if len(tab.shape) >= 3 : ix=-3
261        #if axis in tName :
262        #    if len(tab.shape) >=3  : ix=-3
263        #    if len(tab.shape) >=4  : ix=-4
264
265        l_shape = tab.shape
266        for nn in np.arange ( len(l_shape)) :
267            if l_shape[nn] in Length : ix = nn
268       
269    return ix, ax
270
271def findAxis ( tab, axis= 'z' ) :
272  ix, xx = __findAxis__ (tab, axis)
273  return xx
274
275def fixed_lon (lon, center_lon=0.0) :
276    '''
277    Returns corrected longitudes for nicer plots
278
279    lon        : longitudes of the grid. At least 2D.
280    center_lon : center longitude. Default=0.
281
282    Designed by Phil Pelson. See https://gist.github.com/pelson/79cf31ef324774c97ae7
283    '''
284    mmath = __mmath__ (lon)
285   
286    fixed_lon = lon.copy ()
287       
288    fixed_lon = mmath.where (fixed_lon > center_lon+180., fixed_lon-360.0, fixed_lon)
289    fixed_lon = mmath.where (fixed_lon < center_lon-180., fixed_lon+360.0, fixed_lon)
290   
291    for i, start in enumerate (np.argmax (np.abs (np.diff (fixed_lon, axis=-1)) > 180., axis=-1)) :
292        fixed_lon [..., i, start+1:] += 360.
293
294    # Special case for eORCA025
295    if fixed_lon.shape [-1] == 1442 : fixed_lon [..., -2, :] = fixed_lon [..., -3, :]
296    if fixed_lon.shape [-1] == 1440 : fixed_lon [..., -1, :] = fixed_lon [..., -2, :]
297
298    if fixed_lon.min () > center_lon : fixed_lon += -360.0
299    if fixed_lon.max () < center_lon : fixed_lon +=  360.0
300       
301    if fixed_lon.min () < center_lon-360.0 : fixed_lon +=  360.0
302    if fixed_lon.max () > center_lon+360.0 : fixed_lon += -360.0
303               
304    return fixed_lon
305
306def bounds_clolon ( bounds_lon, lon, rad=False, deg=True) :
307    '''Choose closest to lon0 longitude, adding or substacting 360° if needed'''
308
309    if rad : lon_range = 2.0*np.pi
310    if deg : lon_range = 360.0
311    bounds_clolon = bounds_lon.copy ()
312
313    bounds_clolon = xr.where ( bounds_clolon < lon-lon_range/2., bounds_clolon+lon_range, bounds_clolon )
314    bounds_clolon = xr.where ( bounds_clolon > lon+lon_range/2., bounds_clolon-lon_range, bounds_clolon )
315
316    return bounds_clolon
317
318def UnifyDims ( dd, udims=dim_names, verbose=False ) :
319    '''
320    Rename dimensions to unify them between NEMO versions
321    '''
322   
323    if udims['x'] :
324        for xx in xName :
325            if xx in dd.dims and xx != udims['x'] :
326                if verbose : print ( f"{xx} renamed to {udims['x']}" )
327                dd = dd.rename ( {xx:udims['x']})
328    if udims['y'] :
329        for yy in yName :
330            if yy in dd.dims and yy != udims['y']  :
331                if verbose : print ( f"{yy} renamed to {udims['y']}" )
332                dd = dd.rename ( {yy:udims['y']} )
333    if udims['z'] :
334        for zz in zName :
335            if zz in dd.dims and zz != udims['z'] :
336                if verbose : print ( f"{zz} renamed to {udims['z']}" )
337                dd = dd.rename ( {zz:udims['z']} )
338    if udims['t'] :
339        for tt in tName  :
340            if tt in dd.dims and tt != udims['t'] :
341                if verbose : print ( f"{tt} renamed to {udims['t']}" )
342                dd = dd.rename ( {tt:udims['t']} )
343
344    return dd
345
346def fill_empty (ztab, sval=np.nan, transpose=False) :
347    '''
348    Fill values
349
350    Useful when NEMO has run with no wet points options :
351    some parts of the domain, with no ocean points, have no
352    values
353    '''
354    from sklearn.impute import SimpleImputer
355    mmath = __mmath__ (ztab)
356
357    imp = SimpleImputer (missing_values=sval, strategy='mean')
358    if transpose :
359        imp.fit (ztab.T)
360        ptab = imp.transform (ztab.T).T
361    else : 
362        imp.fit (ztab)
363        ptab = imp.transform (ztab)
364   
365    if mmath == xr :
366        ptab = xr.DataArray (ptab, dims=ztab.dims, coords=ztab.coords)
367        ptab.attrs = ztab.attrs
368       
369    return ptab
370
371def fill_lonlat (lon, lat, sval=-1) :
372    '''
373    Fill longitude/latitude values
374
375    Useful when NEMO has run with no wet points options :
376    some parts of the domain, with no ocean points, have no
377    lon/lat values
378    '''
379    from sklearn.impute import SimpleImputer
380    mmath = __mmath__ (lon)
381
382    imp = SimpleImputer (missing_values=sval, strategy='mean')
383    imp.fit (lon)
384    plon = imp.transform (lon)
385    imp.fit (lat.T)
386    plat = imp.transform (lat.T).T
387
388    if mmath == xr :
389        plon = xr.DataArray (plon, dims=lon.dims, coords=lon.coords)
390        plat = xr.DataArray (plat, dims=lat.dims, coords=lat.coords)
391        plon.attrs = lon.attrs ; plat.attrs = lat.attrs
392       
393    plon = fixed_lon (plon)
394   
395    return plon, plat
396
397def fill_bounds_lonlat (bounds_lon, bounds_lat, sval=-1) :
398    '''
399    Fill longitude/latitude bounds values
400
401    Useful when NEMO has run with no wet points options :
402    some parts of the domain, with no ocean points, as no
403    lon/lat values
404    '''
405    mmath = __mmath__ (bounds_lon)
406
407    p_bounds_lon = np.empty ( bounds_lon.shape )
408    p_bounds_lat = np.empty ( bounds_lat.shape )
409
410    imp = SimpleImputer (missing_values=sval, strategy='mean')
411   
412    for n in np.arange (4) : 
413        imp.fit (bounds_lon[:,:,n])
414        p_bounds_lon[:,:,n] = imp.transform (bounds_lon[:,:,n])
415        imp.fit (bounds_lat[:,:,n].T)
416        p_bounds_lat[:,:,n] = imp.transform (bounds_lat[:,:,n].T).T
417       
418    if mmath == xr :
419        p_bounds_lon = xr.DataArray (bounds_lon, dims=bounds_lon.dims, coords=bounds_lon.coords)
420        p_bounds_lat = xr.DataArray (bounds_lat, dims=bounds_lat.dims, coords=bounds_lat.coords)
421        p_bounds_lon.attrs = bounds_lat.attrs ; p_bounds_lat.attrs = bounds_lat.attrs
422       
423    return p_bounds_lon, p_bounds_lat
424
425def jeq (lat) :
426    '''
427    Returns j index of equator in the grid
428   
429    lat : latitudes of the grid. At least 2D.
430    '''
431    mmath = __mmath__ (lat)
432    ix, ax = __findAxis__ (lat, 'x')
433    jy, ay = __findAxis__ (lat, 'y')
434
435    if mmath == xr :
436        jeq = int ( np.mean ( np.argmin (np.abs (np.float64 (lat)), axis=jy) ) )
437    else : 
438        jeq = np.argmin (np.abs (np.float64 (lat[...,:, 0])))
439    return jeq
440
441def lon1D (lon, lat=None) :
442    '''
443    Returns 1D longitude for simple plots.
444   
445    lon : longitudes of the grid
446    lat (optionnal) : latitudes of the grid
447    '''
448    mmath = __mmath__ (lon)
449    jpj, jpi  = lon.shape [-2:]
450    if np.max (lat) :
451        je    = jeq (lat)
452        #lon1D = lon.copy() [..., je, :]
453        lon0 = lon [..., je, 0].copy()
454        dlon = lon [..., je, 1].copy() - lon [..., je, 0].copy()
455        lon1D = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
456    else :
457        lon0 = lon [..., jpj//3, 0].copy()
458        dlon = lon [..., jpj//3, 1].copy() - lon [..., jpj//3, 0].copy()
459        lon1D = np.linspace ( start=lon0, stop=lon0+360.+2*dlon, num=jpi )
460
461    #start = np.argmax (np.abs (np.diff (lon1D, axis=-1)) > 180.0, axis=-1)
462    #lon1D [..., start+1:] += 360
463
464    if mmath == xr :
465        lon1D = xr.DataArray( lon1D, dims=('lon',), coords=(lon1D,))
466        lon1D.attrs = lon.attrs
467        lon1D.attrs['units']         = 'degrees_east'
468        lon1D.attrs['standard_name'] = 'longitude'
469        lon1D.attrs['long_name :']   = 'Longitude'
470       
471    return lon1D
472
473def latreg (lat, diff=0.1) :
474    '''
475    Returns maximum j index where gridlines are along latitudes in the northern hemisphere
476   
477    lat : latitudes of the grid (2D)
478    diff [optional] : tolerance
479    '''
480    mmath = __mmath__ (lat)
481    if diff == None :
482        dy   = np.float64 (np.mean (np.abs (lat - np.roll (lat,shift=1,axis=-2, roll_coords=False))))
483        print ( f'{dy=}' )
484        diff = dy/100.
485
486    je     = jeq (lat)
487    jreg   = np.where (lat[...,je:,:].max(axis=-1) - lat[...,je:,:].min(axis=-1)< diff)[-1][-1] + je
488    latreg = np.float64 (lat[...,jreg,:].mean(axis=-1))
489    JREG   = jreg
490
491    return jreg, latreg
492
493def lat1D (lat) :
494    '''
495    Returns 1D latitudes for zonal means and simple plots.
496
497    lat : latitudes of the grid (2D)
498    '''
499    mmath = __mmath__ (lat)
500    jpj, jpi = lat.shape[-2:]
501
502    dy     = np.float64 (np.mean (np.abs (lat - np.roll (lat, shift=1,axis=-2))))
503    je     = jeq (lat)
504    lat_eq = np.float64 (lat[...,je,:].mean(axis=-1))
505     
506    jreg, lat_reg = latreg (lat)
507    lat_ave = np.mean (lat, axis=-1)
508
509    #print ( f'{dy=} {jpj=} {je=} {lat_eq=} {jreg=} ' )
510   
511    if (np.abs (lat_eq) < dy/100.) : # T, U or W grid
512        if jpj-1 > jreg : dys = (90.-lat_reg) / (jpj-jreg-1)*0.5
513        else            : dys = (90.-lat_reg) / 2.0
514        yrange = (90.-dys-lat_reg)
515    else                           :  # V or F grid
516        yrange = 90.-lat_reg
517
518    if jpj-1 > jreg :
519        lat1D = mmath.where (lat_ave<lat_reg, lat_ave, lat_reg + yrange * (np.arange(jpj)-jreg)/(jpj-jreg-1) )
520    else :
521        lat1D = lat_ave
522    lat1D[-1] = 90.0
523
524    if mmath == xr :
525        lat1D = xr.DataArray( lat1D.values, dims=('lat',), coords=(lat1D,))
526        lat1D.attrs = lat.attrs
527        lat1D.attrs ['units']         = 'degrees_north'
528        lat1D.attrs ['standard_name'] = 'latitude'
529        lat1D.attrs ['long_name :']   = 'Latitude'
530       
531    return lat1D
532
533def latlon1D (lat, lon) :
534    '''
535    Returns simple latitude and longitude (1D) for simple plots.
536
537    lat, lon : latitudes and longitudes of the grid (2D)
538    '''
539    return lat1D (lat),  lon1D (lon, lat)
540
541def mask_lonlat (ptab, x0, x1, y0, y1, lon, lat, sval=np.nan) :
542    mmath = __mmath__ (ptab)
543    try :
544        lon = lon.copy().to_masked_array()
545        lat = lat.copy().to_masked_array()
546    except : pass
547           
548    mask = np.logical_and (np.logical_and(lat>y0, lat<y1), 
549            np.logical_or (np.logical_or (np.logical_and(lon>x0, lon<x1), np.logical_and(lon+360>x0, lon+360<x1)),
550                                      np.logical_and(lon-360>x0, lon-360<x1)))
551    tab = mmath.where (mask, ptab, np.nan)
552   
553    return tab
554
555def extend (tab, Lon=False, jplus=25, jpi=None, nperio=4) :
556    '''
557    Returns extended field eastward to have better plots, and box average crossing the boundary
558    Works only for xarray and numpy data (?)
559
560    Useful for vertical sections in OCE and ATM.
561
562    tab : field to extend.
563    Lon : (optional, default=False) : if True, add 360 in the extended parts of the field
564    jpi : normal longitude dimension of the field. exrtend does nothing it the actual
565        size of the field != jpi (avoid to extend several times)
566    jplus (optional, default=25) : number of points added on the east side of the field
567   
568    '''
569    mmath = __mmath__ (tab)
570   
571    if tab.shape[-1] == 1 : extend = tab
572
573    else :
574        if jpi == None : jpi = tab.shape[-1]
575
576        if Lon : xplus = -360.0
577        else   : xplus =    0.0
578
579        if tab.shape[-1] > jpi :
580            extend = tab
581        else :
582            if nperio == 0 or nperio == 4.2 :
583                istart = 0 ; le=jpi+1 ; la=0
584            if nperio == 1 :
585                istart = 0 ; le=jpi+1 ; la=0
586            if nperio == 4 or nperio == 6 : # OPA case with two halo points for periodicity
587                istart = 1 ; le=jpi-2 ; la=1  # Perfect, except at the pole that should be masked by lbc_plot
588           
589            if mmath == xr :
590                extend = np.concatenate ((tab.values[..., istart   :istart+le+1    ] + xplus,
591                                          tab.values[..., istart+la:istart+la+jplus]         ), axis=-1)
592                lon    = tab.dims[-1]
593                new_coords = []
594                for coord in tab.dims :
595                    if coord == lon : new_coords.append ( np.arange( extend.shape[-1]))
596                    else            : new_coords.append ( tab.coords[coord].values)
597                extend = xr.DataArray ( extend, dims=tab.dims, coords=new_coords )
598            else : 
599                extend = np.concatenate ((tab [..., istart   :istart+le+1    ] + xplus,
600                                          tab [..., istart+la:istart+la+jplus]          ), axis=-1)
601    return extend
602
603def orca2reg (ff, lat_name='nav_lat', lon_name='nav_lon', y_name='y', x_name='x') :
604    '''
605    Assign an ORCA dataset on a regular grid.
606    For use in the tropical region.
607   
608    Inputs :
609      ff : xarray dataset
610      lat_name, lon_name : name of latitude and longitude 2D field in ff
611      y_name, x_name     : namex of dimensions in ff
612     
613      Returns : xarray dataset with rectangular grid. Incorrect above 20°N
614    '''
615    # Compute 1D longitude and latitude
616    (lat, lon) = latlon1D (ff[lat_name], ff[lon_name])
617
618    # Assign lon and lat as dimensions of the dataset
619    if y_name in ff.dims : 
620        lat = xr.DataArray (lat, coords=[lat,], dims=['lat',])     
621        ff  = ff.rename_dims ({y_name: "lat",}).assign_coords (lat=lat)
622    if x_name in ff.dims :
623        lon = xr.DataArray (lon, coords=[lon,], dims=['lon',])
624        ff  = ff.rename_dims ({x_name: "lon",}).assign_coords (lon=lon)
625    # Force dimensions to be in the right order
626    coord_order = ['lat', 'lon']
627    for dim in [ 'depthw', 'depthv', 'depthu', 'deptht', 'depth', 'z',
628                 'time_counter', 'time', 'tbnds', 
629                 'bnds', 'axis_nbounds', 'two2', 'two1', 'two', 'four',] :
630        if dim in ff.dims : coord_order.insert (0, dim)
631       
632    ff = ff.transpose (*coord_order)
633    return ff
634
635def lbc_init (ptab, nperio=None) :
636    '''
637    Prepare for all lbc calls
638   
639    Set periodicity on input field
640    nperio    : Type of periodicity
641      0       : No periodicity
642      1, 4, 6 : Cyclic on i dimension (generaly longitudes) with 2 points halo
643      2       : Obsolete (was symmetric condition at southern boundary ?)
644      3, 4    : North fold T-point pivot (legacy ORCA2)
645      5, 6    : North fold F-point pivot (ORCA1, ORCA025, ORCA2 with new grid for paleo)
646    cd_type   : Grid specification : T, U, V or F
647
648    See NEMO documentation for further details
649    '''
650    jpi = None ; jpj = None
651    ix, ax = __findAxis__ (ptab, 'x')
652    jy, ay = __findAxis__ (ptab, 'y')
653    if ax : jpi = ptab.shape[ix]
654    if ay : jpj = ptab.shape[jy]
655       
656    if nperio == None : nperio = __guessNperio__ (jpj, jpi, nperio)
657   
658    if nperio not in nperio_valid_range :
659        raise Exception ( f'{nperio=} is not in the valid range {nperio_valid_range}' )
660
661    return jpj, jpi, nperio
662       
663def lbc (ptab, nperio=None, cd_type='T', psgn=1.0, nemo_4U_bug=False) :
664    '''
665    Set periodicity on input field
666    ptab      : Input array (works for rank 2 at least : ptab[...., lat, lon])
667    nperio    : Type of periodicity
668    cd_type   : Grid specification : T, U, V or F
669    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
670   
671    See NEMO documentation for further details
672    '''
673    jpj, jpi, nperio = lbc_init (ptab, nperio)
674    ix, ax = __findAxis__ (ptab, 'x')
675    jy, ay = __findAxis__ (ptab, 'y')
676    psgn   = ptab.dtype.type (psgn)
677    mmath = __mmath__ (ptab)
678   
679    if mmath == xr : ztab = ptab.values.copy ()
680    else           : ztab = ptab.copy ()
681       
682    if ax : 
683        #
684        #> East-West boundary conditions
685        # ------------------------------
686        if nperio in [1, 4, 6] :
687        # ... cyclic
688            ztab [...,  0] = ztab [..., -2]
689            ztab [..., -1] = ztab [...,  1]
690
691        if ay : 
692            #
693            #> North-South boundary conditions
694            # --------------------------------
695            if nperio in [3, 4] :  # North fold T-point pivot
696                if cd_type in [ 'T', 'W' ] : # T-, W-point
697                    ztab [..., -1, 1:       ] = psgn * ztab [..., -3, -1:0:-1      ]
698                    ztab [..., -1, 0        ] = psgn * ztab [..., -3, 2            ]
699                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2:0:-1  ]
700                   
701                if cd_type == 'U' :
702                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]       
703                    ztab [..., -1,  0       ] = psgn * ztab [..., -3,  1           ]
704                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, -2           ]
705                   
706                    if nemo_4U_bug :
707                        ztab [..., -2, jpi//2+1:-1] = psgn * ztab [..., -2, jpi//2-2:0:-1]
708                        ztab [..., -2, jpi//2-1   ] = psgn * ztab [..., -2, jpi//2       ]
709                    else :
710                        ztab [..., -2, jpi//2-1:-1] = psgn * ztab [..., -2, jpi//2:0:-1]
711                       
712                if cd_type == 'V' : 
713                    ztab [..., -2, 1:       ] = psgn * ztab [..., -3, jpi-1:0:-1   ]
714                    ztab [..., -1, 1:       ] = psgn * ztab [..., -4, -1:0:-1      ]   
715                    ztab [..., -1, 0        ] = psgn * ztab [..., -4, 2            ]
716                   
717                if cd_type == 'F' :
718                    ztab [..., -2, 0:-1     ] = psgn * ztab [..., -3, -1:0:-1      ]
719                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -4, -1:0:-1      ]
720                    ztab [..., -1,  0       ] = psgn * ztab [..., -4,  1           ]
721                    ztab [..., -1, -1       ] = psgn * ztab [..., -4, -2           ]
722               
723            if nperio in [4.2] :  # North fold T-point pivot
724                if cd_type in [ 'T', 'W' ] : # T-, W-point
725                    ztab [..., -1, jpi//2:  ] = psgn * ztab [..., -1, jpi//2:0:-1  ]
726                   
727                if cd_type == 'U' :
728                    ztab [..., -1, jpi//2-1:-1] = psgn * ztab [..., -1, jpi//2:0:-1]
729                   
730                if cd_type == 'V' : 
731                    ztab [..., -1, 1:       ] = psgn * ztab [..., -2, jpi-1:0:-1   ]
732                   
733                if cd_type == 'F' :
734                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -1:0:-1      ]
735               
736            if nperio in [5, 6] :            #  North fold F-point pivot 
737                if cd_type in ['T', 'W']  :
738                    ztab [..., -1, 0:       ] = psgn * ztab [..., -2, -1::-1       ]
739                   
740                if cd_type == 'U' :
741                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -2, -2::-1       ]       
742                    ztab [..., -1, -1       ] = psgn * ztab [..., -2, 0            ] # Bug ?
743                   
744                if cd_type == 'V' :
745                    ztab [..., -1, 0:       ] = psgn * ztab [..., -3, -1::-1       ]
746                    ztab [..., -2, jpi//2:  ] = psgn * ztab [..., -2, jpi//2-1::-1 ]
747                   
748                if cd_type == 'F' :
749                    ztab [..., -1, 0:-1     ] = psgn * ztab [..., -3, -2::-1       ]
750                    ztab [..., -1, -1       ] = psgn * ztab [..., -3, 0            ]
751                    ztab [..., -2, jpi//2:-1] = psgn * ztab [..., -2, jpi//2-2::-1 ]
752                   
753            #
754            #> East-West boundary conditions
755            # ------------------------------
756            if nperio in [1, 4, 6] :
757                # ... cyclic
758                ztab [...,  0] = ztab [..., -2]
759                ztab [..., -1] = ztab [...,  1]
760
761    if mmath == xr :
762        ztab = xr.DataArray ( ztab, dims=ptab.dims, coords=ptab.coords )
763        ztab.attrs = ptab.attrs
764       
765    return ztab
766
767def lbc_mask (ptab, nperio=None, cd_type='T', sval=np.nan) :
768    #
769    '''
770    Mask fields on duplicated points
771    ptab      : Input array. Rank 2 at least : ptab [...., lat, lon]
772    nperio    : Type of periodicity
773    cd_type   : Grid specification : T, U, V or F
774   
775    See NEMO documentation for further details
776    '''
777    jpj, jpi, nperio = lbc_init (ptab, nperio)
778    ix, ax = __findAxis__ (ptab, 'x')
779    jy, ay = __findAxis__ (ptab, 'y')
780    ztab = ptab.copy ()
781
782    if ax : 
783        #
784        #> East-West boundary conditions
785        # ------------------------------
786        if nperio in [1, 4, 6] :
787        # ... cyclic
788            ztab [...,  0] = sval
789            ztab [..., -1] = sval
790
791        if ay : 
792            #
793            #> South (in which nperio cases ?)
794            # --------------------------------
795            if nperio in [1, 3, 4, 5, 6] :
796                ztab [..., 0, :] = sval
797       
798            #
799            #> North-South boundary conditions
800            # --------------------------------
801            if nperio in [3, 4] :  # North fold T-point pivot
802                if cd_type in [ 'T', 'W' ] : # T-, W-point
803                    ztab [..., -1,  :         ] = sval
804                    ztab [..., -2, :jpi//2  ] = sval
805               
806                if cd_type == 'U' :
807                    ztab [..., -1,  :         ] = sval 
808                    ztab [..., -2, jpi//2+1:  ] = sval
809               
810                if cd_type == 'V' :
811                    ztab [..., -2, :       ] = sval
812                    ztab [..., -1, :       ] = sval   
813
814                if cd_type == 'F' :
815                    ztab [..., -2, :       ] = sval
816                    ztab [..., -1, :       ] = sval
817
818            if nperio in [4.2] :  # North fold T-point pivot
819                if cd_type in [ 'T', 'W' ] : # T-, W-point
820                    ztab [..., -1, jpi//2  :  ] = sval
821
822                if cd_type == 'U' :
823                    ztab [..., -1, jpi//2-1:-1] = sval
824
825                if cd_type == 'V' : 
826                    ztab [..., -1, 1:       ] = sval
827
828                if cd_type == 'F' :
829                    ztab [..., -1, 0:-1     ] = sval
830
831            if nperio in [5, 6] :            #  North fold F-point pivot
832                if cd_type in ['T', 'W']  :
833                    ztab [..., -1, 0:       ] = sval
834
835                if cd_type == 'U' :
836                    ztab [..., -1, 0:-1     ] = sval       
837                    ztab [..., -1, -1       ] = sval
838
839                if cd_type == 'V' :
840                    ztab [..., -1, 0:       ] = sval
841                    ztab [..., -2, jpi//2:  ] = sval
842
843                if cd_type == 'F' :
844                    ztab [..., -1, 0:-1       ] = sval
845                    ztab [..., -1, -1         ] = sval
846                    ztab [..., -2, jpi//2+1:-1] = sval
847
848    return ztab
849
850def lbc_plot (ptab, nperio=None, cd_type='T', psgn=1.0, sval=np.nan) :
851    '''
852    Set periodicity on input field, adapted for plotting for any cartopy projection
853    ptab      : Input array. Rank 2 at least : ptab[...., lat, lon]
854    nperio    : Type of periodicity
855    cd_type   : Grid specification : T, U, V or F
856    psgn      : For change of sign for vector components (1 for scalars, -1 for vector components)
857   
858    See NEMO documentation for further details
859    '''
860    jpj, jpi, nperio = lbc_init (ptab, nperio)
861    ix, ax = __findAxis__ (ptab, 'x')
862    jy, ay = __findAxis__ (ptab, 'y')
863    psgn   = ptab.dtype.type (psgn)
864    ztab   = ptab.copy ()
865
866    if ax : 
867        #
868        #> East-West boundary conditions
869        # ------------------------------
870        if nperio in [1, 4, 6] :
871            # ... cyclic
872            ztab [..., :,  0] = ztab [..., :, -2]
873            ztab [..., :, -1] = ztab [..., :,  1]
874
875        if ay : 
876            #> Masks south
877            # ------------
878            if nperio in [4, 6] : ztab [..., 0, : ] = sval
879
880            #
881            #> North-South boundary conditions
882            # --------------------------------
883            if nperio in [3, 4] :  # North fold T-point pivot
884                if cd_type in [ 'T', 'W' ] : # T-, W-point
885                    ztab [..., -1,  :      ] = sval
886                    #ztab [..., -2, jpi//2: ] = sval
887                    ztab [..., -2, :jpi//2 ] = sval # Give better plots than above
888                if cd_type == 'U' :
889                    ztab [..., -1, : ] = sval
890
891                if cd_type == 'V' : 
892                    ztab [..., -2, : ] = sval
893                    ztab [..., -1, : ] = sval
894
895                if cd_type == 'F' :
896                    ztab [..., -2, : ] = sval
897                    ztab [..., -1, : ] = sval
898
899            if nperio in [4.2] :  # North fold T-point pivot
900                if cd_type in [ 'T', 'W' ] : # T-, W-point
901                    ztab [..., -1, jpi//2:  ] = sval
902
903                if cd_type == 'U' :
904                    ztab [..., -1, jpi//2-1:-1] = sval
905
906                if cd_type == 'V' : 
907                    ztab [..., -1, 1:       ] = sval
908
909                if cd_type == 'F' :
910                    ztab [..., -1, 0:-1     ] = sval
911
912            if nperio in [5, 6] :            #  North fold F-point pivot 
913                if cd_type in ['T', 'W']  :
914                    ztab [..., -1, : ] = sval
915
916                if cd_type == 'U' :
917                    ztab [..., -1, : ] = sval     
918
919                if cd_type == 'V' :
920                    ztab [..., -1, :        ] = sval
921                    ztab [..., -2, jpi//2:  ] = sval
922
923                if cd_type == 'F' :
924                    ztab [..., -1, :          ] = sval
925                    ztab [..., -2, jpi//2+1:-1] = sval
926
927    return ztab
928
929def lbc_add (ptab, nperio=None, cd_type=None, psgn=1, sval=None) :
930    '''
931    Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
932      Peridodicity halo has been removed
933    This routine adds the halos if needed
934
935    ptab      : Input array (works
936      rank 2 at least : ptab[...., lat, lon]
937    nperio    : Type of periodicity
938 
939    See NEMO documentation for further details
940    '''
941    mmath = __mmath__ (ptab) 
942    jpj, jpi, nperio = lbc_init (ptab, nperio)
943    lshape = get_shape (ptab)
944    ix, ax = __findAxis__ (ptab, 'x')
945    jy, ay = __findAxis__ (ptab, 'y')
946
947    t_shape = np.array (ptab.shape)
948
949    if nperio == 4.2 or nperio == 6.2 :
950     
951        ext_shape = t_shape.copy()
952        if 'X' in lshape : ext_shape[ix] = ext_shape[ix] + 2
953        if 'Y' in lshape : ext_shape[jy] = ext_shape[jy] + 1
954
955        if mmath == xr :
956            ptab_ext = xr.DataArray (np.zeros (ext_shape), dims=ptab.dims)
957            if 'X' in lshape and 'Y' in lshape :
958                ptab_ext.values[..., :-1, 1:-1] = ptab.values.copy ()
959            else :
960                if 'X' in lshape     : ptab_ext.values[...,      1:-1] = ptab.values.copy ()
961                if 'Y' in lshape     : ptab_ext.values[..., :-1      ] = ptab.values.copy ()
962        else           :
963            ptab_ext =               np.zeros (ext_shape)
964            if 'X' in lshape and 'Y' in lshape : ptab_ext       [..., :-1, 1:-1] = ptab.copy ()
965            else :
966                if 'X' in lshape     : ptab_ext       [...,      1:-1] = ptab.copy ()
967                if 'Y' in lshape     : ptab_ext       [..., :-1      ] = ptab.copy ()           
968
969        if nperio == 4.2 : ptab_ext = lbc (ptab_ext, nperio=4, cd_type=cd_type, psgn=psgn)
970        if nperio == 6.2 : ptab_ext = lbc (ptab_ext, nperio=6, cd_type=cd_type, psgn=psgn)
971       
972        if mmath == xr :
973            ptab_ext.attrs = ptab.attrs
974            kz, az = __findAxis__ (ptab, 'z')
975            it, at = __findAxis__ (ptab, 't')
976            if az : ptab_ext = ptab_ext.assign_coords ( {az:ptab.coords[az]} )
977            if at : ptab_ext = ptab_ext.assign_coords ( {at:ptab.coords[at]} )
978
979    else : ptab_ext = lbc (ptab, nperio=nperio, cd_type=cd_type, psgn=psgn)
980       
981    return ptab_ext
982
983def lbc_del (ptab, nperio=None, cd_type='T', psgn=1) :
984    '''
985    Handles NEMO domain changes between NEMO 4.0 to NEMO 4.2
986      Periodicity halo has been removed
987    This routine removes the halos if needed
988
989    ptab      : Input array (works
990      rank 2 at least : ptab[...., lat, lon]
991    nperio    : Type of periodicity
992 
993    See NEMO documentation for further details
994    '''
995    jpj, jpi, nperio = lbc_init (ptab, nperio)
996    lshape = get_shape (ptab)
997    ix, ax = __findAxis__ (ptab, 'x')
998    jy, ay = __findAxis__ (ptab, 'y')
999
1000    if nperio == 4.2 or nperio == 6.2 :
1001        if ax or ay : 
1002            if ax and ay : 
1003                return lbc (ptab[..., :-1, 1:-1], nperio=nperio, cd_type=cd_type, psgn=psgn)
1004            else : 
1005                if ax :
1006                    return lbc (ptab[...,      1:-1], nperio=nperio, cd_type=cd_type, psgn=psgn)
1007                if ay :
1008                    return lbc (ptab[..., -1], nperio=nperio, cd_type=cd_type, psgn=psgn)
1009        else :
1010            return ptab
1011    else :
1012        return ptab
1013
1014def lbc_index (jj, ii, jpj, jpi, nperio=None, cd_type='T') :
1015    '''
1016    For indexes of a NEMO point, give the corresponding point inside the util domain
1017    jj, ii    : indexes
1018    jpi, jpi  : size of domain
1019    nperio    : type of periodicity
1020    cd_type   : grid specification : T, U, V or F
1021   
1022    See NEMO documentation for further details
1023    '''
1024
1025    if nperio == None : nperio = __guessNperio__ (jpj, jpi, nperio)
1026   
1027    ## For the sake of simplicity, switch to the convention of original lbc Fortran routine from NEMO
1028    ## : starts indexes at 1
1029    jy = jj + 1 ; ix = ii + 1
1030
1031    mmath = __mmath__ (jj)
1032    if mmath == None : mmath=np
1033
1034    #
1035    #> East-West boundary conditions
1036    # ------------------------------
1037    if nperio in [1, 4, 6] :
1038        #... cyclic
1039        ix = mmath.where (ix==jpi, 2   , ix)
1040        ix = mmath.where (ix== 1 ,jpi-1, ix)
1041
1042    #
1043    def modIJ (cond, jy_new, ix_new) :
1044        jy_r = mmath.where (cond, jy_new, jy)
1045        ix_r = mmath.where (cond, ix_new, ix)
1046        return jy_r, ix_r
1047    #
1048    #> North-South boundary conditions
1049    # --------------------------------
1050    if nperio in [ 3 , 4 ]  :
1051        if cd_type in  [ 'T' , 'W' ] :
1052            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix>=2       ), jpj-2, jpi-ix+2)
1053            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1       ), jpj-1, 3       )   
1054            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+2) 
1055
1056        if cd_type in [ 'U' ] :
1057            (jy, ix) = modIJ (np.logical_and (jy==jpj  , np.logical_and (ix>=1, ix <= jpi-1)   ), jy   , jpi-ix+1)
1058            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1  )                               , jpj-2, 2       )
1059            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi)                               , jpj-2, jpi-1   )
1060            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, np.logical_and (ix>=jpi//2, ix<=jpi-1)), jy   , jpi-ix+1)
1061         
1062        if cd_type in [ 'V' ] :
1063            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=2  ), jpj-2, jpi-ix+2)
1064            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix>=2  ), jpj-3, jpi-ix+2)
1065            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1  ), jpj-3,  3      )
1066           
1067        if cd_type in [ 'F' ] :
1068            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix<=jpi-1), jpj-2, jpi-ix+1)
1069            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1), jpj-3, jpi-ix+1)
1070            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==1    ), jpj-3, 2       )
1071            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi  ), jpj-3, jpi-1   )
1072
1073    if nperio in [ 5 , 6 ] :
1074        if cd_type in [ 'T' , 'W' ] :                        # T-, W-point
1075             (jy, ix) = modIJ (jy==jpj, jpj-1, jpi-ix+1)
1076 
1077        if cd_type in [ 'U' ] :                              # U-point
1078            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-1, jpi-ix  )
1079            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix==jpi     ), jpi-1, 1       )
1080           
1081        if cd_type in [ 'V' ] :    # V-point
1082            (jy, ix) = modIJ (jy==jpj                                 , jy   , jpi-ix+1)
1083            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix+1)
1084           
1085        if cd_type in [ 'F' ] :                              # F-point
1086            (jy, ix) = modIJ (np.logical_and (jy==jpj  , ix<=jpi-1   ), jpj-2, jpi-ix  )
1087            (jy, ix) = modIJ (np.logical_and (ix==jpj  , ix==jpi     ), jpj-2, 1       )
1088            (jy, ix) = modIJ (np.logical_and (jy==jpj-1, ix>=jpi//2+1), jy   , jpi-ix  )
1089
1090    ## Restore convention to Python/C : indexes start at 0
1091    jy += -1 ; ix += -1
1092
1093    if isinstance (jj, int) : jy = jy.item ()
1094    if isinstance (ii, int) : ix = ix.item ()
1095
1096    return jy, ix
1097   
1098def findJI (lat_data, lon_data, lat_grid, lon_grid, mask=1.0, verbose=False, out=None) :
1099    '''
1100    Description: seeks J,I indices of the grid point which is the closest of a given point
1101    Usage: go FindJI  <data latitude> <data longitude> <grid latitudes> <grid longitudes> [mask]
1102    <longitude fields> <latitude field> are 2D fields on J/I (Y/X) dimensions
1103    mask : if given, seek only non masked grid points (i.e with mask=1)
1104   
1105    Example : findIJ (40, -20, nav_lat, nav_lon, mask=1.0)
1106
1107    Note : all longitudes and latitudes in degrees
1108       
1109    Note : may work with 1D lon/lat (?)
1110    '''
1111    # Get grid dimensions
1112    if len (lon_grid.shape) == 2 : (jpj, jpi) = lon_grid.shape
1113    else                         : jpj = len(lat_grid) ; jpi=len(lon_grid)
1114
1115    mmath = __mmath__ (lat_grid)
1116       
1117    # Compute distance from the point to all grid points (in radian)
1118    arg      = np.sin (rad*lat_data) * np.sin (rad*lat_grid) \
1119             + np.cos (rad*lat_data) * np.cos (rad*lat_grid) * np.cos(rad*(lon_data-lon_grid))
1120    distance = np.arccos (arg) + 4.0*rpi*(1.0-mask) # Send masked points to 'infinite'
1121
1122    # Truncates to alleviate some precision problem with some grids
1123    prec = int (1E7)
1124    distance = (distance*prec).astype(int) / prec
1125
1126    # Compute minimum of distance, and index of minimum
1127    #
1128    distance_min = distance.min    ()
1129    jimin        = int (distance.argmin ())
1130   
1131    # Compute 2D indices
1132    jmin = jimin // jpi ; imin = jimin - jmin*jpi
1133   
1134    # Result
1135    if verbose :
1136        # Compute distance achieved
1137        mindist = distance [jmin, imin]
1138       
1139        # Compute azimuth
1140        dlon = lon_data-lon_grid[jmin,imin]
1141        arg  = np.sin (rad*dlon) /  (np.cos(rad*lat_data)*np.tan(rad*lat_grid[jmin,imin]) - np.sin(rad*lat_data)*np.cos(rad*dlon))
1142        azimuth = dar*np.arctan (arg)
1143        print ( f'I={imin:d} J={jmin:d} - Data:{lat_data:5.1f}°N {lon_data:5.1f}°E - Grid:{lat_grid[jmin,imin]:4.1f}°N '   \
1144            +   f'{lon_grid[jmin,imin]:4.1f}°E - Dist: {ra*distance[jmin,imin]:6.1f}km {dar*distance[jmin,imin]:5.2f}° ' \
1145            +   f'- Azimuth: {rad*azimuth:3.2f}rad - {azimuth:5.1f}°' )
1146
1147    if   out=='dict'                               : return {'x':imin, 'y':jmin}
1148    elif out=='array' or out=='numpy'  or out=='np': return np.array ( [jmin, imin] )
1149    elif out=='xarray' or out=='xr'                : return xr.DataArray ( [jmin, imin] )
1150    elif out=='list'                               : return [jmin, imin]
1151    elif out=='tuple'                              : return jmin, imin
1152    else                                           : return jmin, imin
1153
1154def geo2en (pxx, pyy, pzz, glam, gphi) : 
1155    '''
1156    Change vector from geocentric to east/north
1157
1158    Inputs :
1159        pxx, pyy, pzz : components on the geocentric system
1160        glam, gphi : longitude and latitude of the points
1161    '''
1162
1163    gsinlon = np.sin (rad * glam)
1164    gcoslon = np.cos (rad * glam)
1165    gsinlat = np.sin (rad * gphi)
1166    gcoslat = np.cos (rad * gphi)
1167         
1168    pte = - pxx * gsinlon            + pyy * gcoslon
1169    ptn = - pxx * gcoslon * gsinlat  - pyy * gsinlon * gsinlat + pzz * gcoslat
1170
1171    return pte, ptn
1172
1173def en2geo (pte, ptn, glam, gphi) :
1174    '''
1175    Change vector from east/north to geocentric
1176
1177    Inputs :
1178        pte, ptn   : eastward/northward components
1179        glam, gphi : longitude and latitude of the points
1180    '''
1181   
1182    gsinlon = np.sin (rad * glam)
1183    gcoslon = np.cos (rad * glam)
1184    gsinlat = np.sin (rad * gphi)
1185    gcoslat = np.cos (rad * gphi)
1186
1187    pxx = - pte * gsinlon - ptn * gcoslon * gsinlat
1188    pyy =   pte * gcoslon - ptn * gsinlon * gsinlat
1189    pzz =   ptn * gcoslat
1190   
1191    return pxx, pyy, pzz
1192
1193
1194def clo_lon (lon, lon0=0., rad=False, deg=True) :
1195    '''Choose closest to lon0 longitude, adding or substacting 360° if needed'''
1196    mmath = __mmath__ (lon, np)
1197    if rad : lon_range = 2.*np.pi
1198    if deg : lon_range = 360.
1199    clo_lon = lon
1200    clo_lon = mmath.where (clo_lon > lon0 + lon_range*0.5, clo_lon-lon_range, clo_lon)
1201    clo_lon = mmath.where (clo_lon < lon0 - lon_range*0.5, clo_lon+lon_range, clo_lon)
1202    clo_lon = mmath.where (clo_lon > lon0 + lon_range*0.5, clo_lon-lon_range, clo_lon)
1203    clo_lon = mmath.where (clo_lon < lon0 - lon_range*0.5, clo_lon+lon_range, clo_lon)
1204    if clo_lon.shape == () : clo_lon = clo_lon.item ()
1205    if mmath == xr :
1206        try : 
1207            for attr in lon.attrs : clo_lon.attrs[attr] = lon.attrs[attr]
1208        except :
1209            pass
1210    return clo_lon
1211
1212
1213def index2depth (pk, gdept_0) :
1214    '''
1215    From index (real, continuous), get depth
1216    '''
1217    jpk = gdept_0.shape[0]
1218    kk = xr.DataArray(pk)
1219    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1220    k0 = np.floor (k).astype (int)
1221    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1222    zz = k - k0
1223    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1224    return gz.values
1225
1226def depth2index (pz, gdept_0) :
1227    '''
1228    From depth, get index (real, continuous)
1229    '''
1230    jpk  = gdept_0.shape[0]
1231    if type (pz) == xr.core.dataarray.DataArray :
1232        zz   = xr.DataArray (pz.values, dims=('zz',))
1233    elif type (pz) == np.ndarray :
1234        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1235    else :
1236        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1237    zz   = np.minimum (gdept_0[-1], np.maximum (0, zz))
1238   
1239    idk1 = np.minimum ( (gdept_0-zz), 0.).argmax (axis=0).astype(int)
1240    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1241    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
1242   
1243    ff = (zz - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1244    idk = ff*idk1 + (1.0-ff)*idk2
1245    idk = xr.where ( np.isnan(idk), idk1, idk)
1246    return idk.values
1247
1248def index2depth_panels (pk, gdept_0, depth0, fact) :
1249    '''
1250    From  index (real, continuous), get depth, with bottom part compressed
1251    '''
1252    jpk = gdept_0.shape[0]
1253    kk = xr.DataArray (pk)
1254    k  = np.maximum (0, np.minimum (jpk-1, kk    ))
1255    k0 = np.floor (k).astype (int)
1256    k1 = np.maximum (0, np.minimum (jpk-1,  k0+1))
1257    zz = k - k0
1258    gz = (1.0-zz)*gdept_0[k0]+ zz*gdept_0[k1]
1259    gz = xr.where ( gz<depth0, gz, depth0 + (gz-depth0)*fact)
1260    return gz.values
1261
1262def depth2index_panels (pz, gdept_0, depth0, fact) :
1263    '''
1264    From  index (real, continuous), get depth, with bottom part compressed
1265    '''
1266    jpk = gdept_0.shape[0]
1267    if type (pz) == xr.core.dataarray.DataArray :
1268        zz   = xr.DataArray (pz.values , dims=('zz',))
1269    elif type (pz) == np.ndarray :
1270        zz   = xr.DataArray (pz.ravel(), dims=('zz',))
1271    else : 
1272        zz   = xr.DataArray (np.array([pz]).ravel(), dims=('zz',))
1273    zz         = np.minimum (gdept_0[-1], np.maximum (0, zz))
1274    gdept_comp = xr.where ( gdept_0>depth0, (gdept_0-depth0)*fact+depth0, gdept_0)
1275    zz_comp    = xr.where ( zz     >depth0, (zz     -depth0)*fact+depth0, zz     )
1276    zz_comp    = np.minimum (gdept_comp[-1], np.maximum (0, zz_comp))
1277
1278    idk1 = np.minimum ( (gdept_0-zz_comp), 0.).argmax (axis=0).astype(int)
1279    idk1 = np.maximum (0, np.minimum (jpk-1,  idk1  ))
1280    idk2 = np.maximum (0, np.minimum (jpk-1,  idk1-1))
1281     
1282    ff = (zz_comp - gdept_0[idk2])/(gdept_0[idk1]-gdept_0[idk2])
1283    idk = ff*idk1 + (1.0-ff)*idk2
1284    idk = xr.where ( np.isnan(idk), idk1, idk)
1285    return idk.values
1286
1287def depth2comp (pz, depth0, fact ) :
1288    '''
1289    Form depth, get compressed depth, with bottom part compressed
1290    '''
1291    #print ('start depth2comp')
1292    if type (pz) == xr.core.dataarray.DataArray :
1293        zz   = pz.values
1294    elif type(pz) == list :
1295        zz = np.array (pz)
1296    else : 
1297        zz   = pz
1298    gz = np.where ( zz>depth0, (zz-depth0)*fact+depth0, zz)
1299    #print ( f'depth2comp : {gz=}' )
1300    if type (pz) in [int, float] : return gz.item()
1301    else : return gz
1302    #print ('end depth2comp')
1303
1304def comp2depth (pz, depth0, fact ) :
1305    '''
1306    Form compressed depth, get depth, with bottom part compressed
1307    '''
1308    if type (pz) == xr.core.dataarray.DataArray :
1309        zz   = pz.values
1310    elif type(pz) == list :
1311        zz = np.array (pz)
1312    else : 
1313        zz   = pz
1314    gz = np.where ( zz>depth0, (zz-depth0)/fact+depth0, zz)
1315    if type (pz) in [int, float] : return gz.item()
1316    else : return gz
1317
1318def index2lon (pi, lon1D) :
1319    '''
1320    From index (real, continuous), get longitude
1321    '''
1322    jpi = lon1D.shape[0]
1323    ii = xr.DataArray (pi)
1324    i =  np.maximum (0, np.minimum (jpi-1, ii    ))
1325    i0 = np.floor (i).astype (int)
1326    i1 = np.maximum (0, np.minimum (jpi-1,  i0+1))
1327    xx = i - i0
1328    gx = (1.0-xx)*lon1D[i0]+ xx*lon1D[i1]
1329    return gx.values
1330
1331def lon2index (px, lon1D) :
1332    '''
1333    From longitude, get index (real, continuous)
1334    '''
1335    jpi  = lon1D.shape[0]
1336    if type (px) == xr.core.dataarray.DataArray :
1337        xx   = xr.DataArray (px.values , dims=('xx',))
1338    elif type (px) == np.ndarray :
1339        xx   = xr.DataArray (px.ravel(), dims=('xx',))
1340    else : 
1341        xx   = xr.DataArray (np.array([px]).ravel(), dims=('xx',))
1342    xx   = xr.where ( xx>lon1D.max(), xx-360.0, xx)
1343    xx   = xr.where ( xx<lon1D.min(), xx+360.0, xx)
1344    xx   = np.minimum (lon1D.max(), np.maximum(xx, lon1D.min() ))
1345    idi1 = np.minimum ( (lon1D-xx), 0.).argmax (axis=0).astype(int)
1346    idi1 = np.maximum (0, np.minimum (jpi-1,  idi1  ))
1347    idi2 = np.maximum (0, np.minimum (jpi-1,  idi1-1))
1348   
1349    ff = (xx - lon1D[idi2])/(lon1D[idi1]-lon1D[idi2])
1350    idi = ff*idi1 + (1.0-ff)*idi2
1351    idi = xr.where ( np.isnan(idi), idi1, idi)
1352    return idi.values
1353
1354def index2lat (pj, lat1D) :
1355    '''
1356    From index (real, continuous), get latitude
1357    '''
1358    jpj = lat1D.shape[0]
1359    jj  = xr.DataArray (pj)
1360    j   = np.maximum (0, np.minimum (jpj-1, jj    ))
1361    j0  = np.floor (j).astype (int)
1362    j1  = np.maximum (0, np.minimum (jpj-1,  j0+1))
1363    yy  = j - j0
1364    gy  = (1.0-yy)*lat1D[j0]+ yy*lat1D[j1]
1365    return gy.values
1366
1367def lat2index (py, lat1D) :
1368    '''
1369    From latitude, get index (real, continuous)
1370    '''
1371    jpj = lat1D.shape[0]
1372    if type (py) == xr.core.dataarray.DataArray :
1373        yy   = xr.DataArray (py.values , dims=('yy',))
1374    elif type (py) == np.ndarray :
1375        yy   = xr.DataArray (py.ravel(), dims=('yy',))
1376    else : 
1377        yy   = xr.DataArray (np.array([py]).ravel(), dims=('yy',))
1378    yy   = np.minimum (lat1D.max(), np.minimum(yy, lat1D.max() ))
1379    idj1 = np.minimum ( (lat1D-yy), 0.).argmax (axis=0).astype(int)
1380    idj1 = np.maximum (0, np.minimum (jpj-1,  idj1  ))
1381    idj2 = np.maximum (0, np.minimum (jpj-1,  idj1-1))
1382   
1383    ff = (yy - lat1D[idj2])/(lat1D[idj1]-lat1D[idj2])
1384    idj = ff*idj1 + (1.0-ff)*idj2
1385    idj = xr.where ( np.isnan(idj), idj1, idj)
1386    return idj.values
1387
1388def angle_full (glamt, gphit, glamu, gphiu, glamv, gphiv, glamf, gphif, nperio=None) :
1389    '''Compute sinus and cosinus of model line direction with respect to east'''
1390    mmath = __mmath__ (glamt)
1391
1392    zlamt = lbc_add (glamt, nperio, 'T', 1.)
1393    zphit = lbc_add (gphit, nperio, 'T', 1.)
1394    zlamu = lbc_add (glamu, nperio, 'U', 1.)
1395    zphiu = lbc_add (gphiu, nperio, 'U', 1.)
1396    zlamv = lbc_add (glamv, nperio, 'V', 1.)
1397    zphiv = lbc_add (gphiv, nperio, 'V', 1.)
1398    zlamf = lbc_add (glamf, nperio, 'F', 1.)
1399    zphif = lbc_add (gphif, nperio, 'F', 1.)
1400   
1401    # north pole direction & modulous (at T-point)
1402    zxnpt = 0. - 2.0 * np.cos (rad*zlamt) * np.tan (rpi/4.0 - rad*zphit/2.0)
1403    zynpt = 0. - 2.0 * np.sin (rad*zlamt) * np.tan (rpi/4.0 - rad*zphit/2.0)
1404    znnpt = zxnpt*zxnpt + zynpt*zynpt
1405   
1406    # north pole direction & modulous (at U-point)
1407    zxnpu = 0. - 2.0 * np.cos (rad*zlamu) * np.tan (rpi/4.0 - rad*zphiu/2.0)
1408    zynpu = 0. - 2.0 * np.sin (rad*zlamu) * np.tan (rpi/4.0 - rad*zphiu/2.0)
1409    znnpu = zxnpu*zxnpu + zynpu*zynpu
1410   
1411    # north pole direction & modulous (at V-point)
1412    zxnpv = 0. - 2.0 * np.cos (rad*zlamv) * np.tan (rpi/4.0 - rad*zphiv/2.0)
1413    zynpv = 0. - 2.0 * np.sin (rad*zlamv) * np.tan (rpi/4.0 - rad*zphiv/2.0)
1414    znnpv = zxnpv*zxnpv + zynpv*zynpv
1415
1416    # north pole direction & modulous (at F-point)
1417    zxnpf = 0. - 2.0 * np.cos( rad*zlamf ) * np.tan ( rpi/4. - rad*zphif/2. )
1418    zynpf = 0. - 2.0 * np.sin( rad*zlamf ) * np.tan ( rpi/4. - rad*zphif/2. )
1419    znnpf = zxnpf*zxnpf + zynpf*zynpf
1420
1421    # j-direction: v-point segment direction (around T-point)
1422    zlam = zlamv 
1423    zphi = zphiv
1424    zlan = np.roll ( zlamv, axis=-2, shift=1)  # glamv (ji,jj-1)
1425    zphh = np.roll ( zphiv, axis=-2, shift=1)  # gphiv (ji,jj-1)
1426    zxvvt =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1427          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1428    zyvvt =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1429          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1430    znvvt = np.sqrt ( znnpt * ( zxvvt*zxvvt + zyvvt*zyvvt )  )
1431
1432    # j-direction: f-point segment direction (around u-point)
1433    zlam = zlamf
1434    zphi = zphif
1435    zlan = np.roll (zlamf, axis=-2, shift=1) # glamf (ji,jj-1)
1436    zphh = np.roll (zphif, axis=-2, shift=1) # gphif (ji,jj-1)
1437    zxffu =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1438          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1439    zyffu =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1440          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1441    znffu = np.sqrt ( znnpu * ( zxffu*zxffu + zyffu*zyffu )  )
1442
1443    # i-direction: f-point segment direction (around v-point)
1444    zlam = zlamf 
1445    zphi = zphif
1446    zlan = np.roll (zlamf, axis=-1, shift=1) # glamf (ji-1,jj)
1447    zphh = np.roll (zphif, axis=-1, shift=1) # gphif (ji-1,jj)
1448    zxffv =  2.0 * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1449          -  2.0 * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1450    zyffv =  2.0 * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1451          -  2.0 * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1452    znffv = np.sqrt ( znnpv * ( zxffv*zxffv + zyffv*zyffv )  )
1453
1454    # j-direction: u-point segment direction (around f-point)
1455    zlam = np.roll (zlamu, axis=-2, shift=-1) # glamu (ji,jj+1)
1456    zphi = np.roll (zphiu, axis=-2, shift=-1) # gphiu (ji,jj+1)
1457    zlan = zlamu
1458    zphh = zphiu
1459    zxuuf =  2. * np.cos ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1460          -  2. * np.cos ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1461    zyuuf =  2. * np.sin ( rad*zlam ) * np.tan ( rpi/4. - rad*zphi/2. )   \
1462          -  2. * np.sin ( rad*zlan ) * np.tan ( rpi/4. - rad*zphh/2. )
1463    znuuf = np.sqrt ( znnpf * ( zxuuf*zxuuf + zyuuf*zyuuf )  )
1464
1465   
1466    # cosinus and sinus using scalar and vectorial products
1467    gsint = ( zxnpt*zyvvt - zynpt*zxvvt ) / znvvt
1468    gcost = ( zxnpt*zxvvt + zynpt*zyvvt ) / znvvt
1469   
1470    gsinu = ( zxnpu*zyffu - zynpu*zxffu ) / znffu
1471    gcosu = ( zxnpu*zxffu + zynpu*zyffu ) / znffu
1472   
1473    gsinf = ( zxnpf*zyuuf - zynpf*zxuuf ) / znuuf
1474    gcosf = ( zxnpf*zxuuf + zynpf*zyuuf ) / znuuf
1475   
1476    gsinv = ( zxnpv*zxffv + zynpv*zyffv ) / znffv
1477    gcosv =-( zxnpv*zyffv - zynpv*zxffv ) / znffv  # (caution, rotation of 90 degres)
1478   
1479    #gsint = lbc (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1480    #gcost = lbc (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1481    #gsinu = lbc (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1482    #gcosu = lbc (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1483    #gsinv = lbc (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1484    #gcosv = lbc (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1485    #gsinf = lbc (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1486    #gcosf = lbc (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1487
1488    gsint = lbc_del (gsint, cd_type='T', nperio=nperio, psgn=-1.)
1489    gcost = lbc_del (gcost, cd_type='T', nperio=nperio, psgn=-1.)
1490    gsinu = lbc_del (gsinu, cd_type='U', nperio=nperio, psgn=-1.)
1491    gcosu = lbc_del (gcosu, cd_type='U', nperio=nperio, psgn=-1.)
1492    gsinv = lbc_del (gsinv, cd_type='V', nperio=nperio, psgn=-1.)
1493    gcosv = lbc_del (gcosv, cd_type='V', nperio=nperio, psgn=-1.)
1494    gsinf = lbc_del (gsinf, cd_type='F', nperio=nperio, psgn=-1.)
1495    gcosf = lbc_del (gcosf, cd_type='F', nperio=nperio, psgn=-1.)
1496
1497    if mmath == xr :
1498        gsint = gsint.assign_coords ( glamt.coords )
1499        gcost = gcost.assign_coords ( glamt.coords )
1500        gsinu = gsinu.assign_coords ( glamu.coords )
1501        gcosu = gcosu.assign_coords ( glamu.coords )
1502        gsinv = gsinv.assign_coords ( glamv.coords )
1503        gcosv = gcosv.assign_coords ( glamv.coords )
1504        gsinf = gsinf.assign_coords ( glamf.coords )
1505        gcosf = gcosf.assign_coords ( glamf.coords )
1506
1507    return gsint, gcost, gsinu, gcosu, gsinv, gcosv, gsinf, gcosf
1508
1509def angle (glam, gphi, nperio, cd_type='T') :
1510    '''Compute sinus and cosinus of model line direction with respect to east'''
1511    mmath = __mmath__ (glam)
1512
1513    zlam = lbc_add (glam, nperio, cd_type, 1.)
1514    zphi = lbc_add (gphi, nperio, cd_type, 1.)
1515   
1516    # north pole direction & modulous
1517    zxnp = 0. - 2.0 * np.cos (rad*zlam) * np.tan (rpi/4.0 - rad*zphi/2.0)
1518    zynp = 0. - 2.0 * np.sin (rad*zlam) * np.tan (rpi/4.0 - rad*zphi/2.0)
1519    znnp = zxnp*zxnp + zynp*zynp
1520
1521    # j-direction: segment direction (around point)
1522    zlan_n = np.roll (zlam, axis=-2, shift=-1) # glam [jj+1, ji]
1523    zphh_n = np.roll (zphi, axis=-2, shift=-1) # gphi [jj+1, ji]
1524    zlan_s = np.roll (zlam, axis=-2, shift= 1) # glam [jj-1, ji]
1525    zphh_s = np.roll (zphi, axis=-2, shift= 1) # gphi [jj-1, ji]
1526   
1527    zxff = 2.0 * np.cos (rad*zlan_n) * np.tan (rpi/4.0 - rad*zphh_n/2.0) \
1528        -  2.0 * np.cos (rad*zlan_s) * np.tan (rpi/4.0 - rad*zphh_s/2.0)
1529    zyff = 2.0 * np.sin (rad*zlan_n) * np.tan (rpi/4.0 - rad*zphh_n/2.0) \
1530        -  2.0 * np.sin (rad*zlan_s) * np.tan (rpi/4.0 - rad*zphh_s/2.0)
1531    znff = np.sqrt (znnp * (zxff*zxff + zyff*zyff) )
1532 
1533    gsin = (zxnp*zyff - zynp*zxff) / znff
1534    gcos = (zxnp*zxff + zynp*zyff) / znff
1535
1536    gsin = lbc_del (gsin, cd_type=cd_type, nperio=nperio, psgn=-1.)
1537    gcos = lbc_del (gcos, cd_type=cd_type, nperio=nperio, psgn=-1.)
1538
1539    if mmath == xr :
1540        gsin = gsin.assign_coords ( glam.coords )
1541        gcos = gcos.assign_coords ( glam.coords )
1542       
1543    return gsin, gcos
1544
1545def rot_en2ij ( u_e, v_n, gsin, gcos, nperio, cd_type ) :
1546    '''
1547    ** Purpose :   Rotate the Repere: Change vector componantes between
1548    geographic grid --> stretched coordinates grid.
1549    All components are on the same grid (T, U, V or F)
1550    '''
1551
1552    u_i = + u_e * gcos + v_n * gsin
1553    v_j = - u_e * gsin + v_n * gcos
1554   
1555    u_i = lbc (u_i, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1556    v_j = lbc (v_j, nperio=nperio, cd_type=cd_type, psgn=-1.0)
1557   
1558    return u_i, v_j
1559
1560def rot_ij2en ( u_i, v_j, gsin, gcos, nperio, cd_type='T' ) :
1561    '''
1562    ** Purpose :   Rotate the Repere: Change vector componantes from
1563    stretched coordinates grid --> geographic grid
1564    All components are on the same grid (T, U, V or F)
1565    '''
1566    u_e = + u_i * gcos - v_j * gsin
1567    v_n = + u_i * gsin + v_j * gcos
1568   
1569    u_e = lbc (u_e, nperio=nperio, cd_type=cd_type, psgn=1.0)
1570    v_n = lbc (v_n, nperio=nperio, cd_type=cd_type, psgn=1.0)
1571   
1572    return u_e, v_n
1573
1574def rot_uv2en ( uo, vo, gsint, gcost, nperio, zdim=None ) :
1575    '''
1576    ** Purpose :   Rotate the Repere: Change vector componantes from
1577    stretched coordinates grid --> geographic grid
1578    uo is on the U grid point, vo is on the V grid point
1579    east-north components on the T grid point   
1580    '''
1581    mmath = __mmath__ (uo)
1582   
1583    ut = U2T (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1584    vt = V2T (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1585   
1586    u_e = + ut * gcost - vt * gsint
1587    v_n = + ut * gsint + vt * gcost
1588
1589    u_e = lbc (u_e, nperio=nperio, cd_type='T', psgn=1.0)
1590    v_n = lbc (v_n, nperio=nperio, cd_type='T', psgn=1.0)
1591   
1592    return u_e, v_n
1593
1594def rot_uv2enF ( uo, vo, gsinf, gcosf, nperio, zdim=None ) :
1595    '''
1596    ** Purpose : Rotate the Repere: Change vector componantes from
1597    stretched coordinates grid --> geographic grid
1598    uo is on the U grid point, vo is on the V grid point
1599    east-north components on the T grid point   
1600    '''
1601    mmath = __mmath__ (uo)
1602
1603    uf = U2F (uo, nperio=nperio, psgn=-1.0, zdim=zdim)
1604    vf = V2F (vo, nperio=nperio, psgn=-1.0, zdim=zdim)
1605   
1606    u_e = + uf * gcosf - vf * gsinf
1607    v_n = + uf * gsinf + vf * gcosf
1608
1609    u_e = lbc (u_e, nperio=nperio, cd_type='F', psgn= 1.0)
1610    v_n = lbc (v_n, nperio=nperio, cd_type='F', psgn= 1.0)
1611   
1612    return u_e, v_n
1613
1614def U2T (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1615    '''Interpolate an array from U grid to T grid i-mean)'''
1616    mmath = __mmath__ (utab)
1617    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
1618    lperio, aperio = lbc_diag (nperio)
1619    utab_0 = lbc_add (utab_0, nperio=nperio, cd_type='U', psgn=psgn)
1620    ix, ax = __findAxis__ (utab_0, 'x')
1621    kz, az = __findAxis__ (utab_0, 'z')
1622
1623    if ax : 
1624        if action == 'ave' : ttab = 0.5 *      (utab_0 + np.roll (utab_0, axis=ix, shift=1))
1625        if action == 'min' : ttab = np.minimum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1626        if action == 'max' : ttab = np.maximum (utab_0 , np.roll (utab_0, axis=ix, shift=1))
1627        if action == 'mult': ttab =             utab_0 * np.roll (utab_0, axis=ix, shift=1)
1628        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1629    else : 
1630        ttab = lbc_del (utab_0, nperio=nperio, cd_type='T', psgn=psgn)
1631       
1632    if mmath == xr :
1633        if ax :
1634            ttab = ttab.assign_coords({ax:np.arange (ttab.shape[ix])+1.})
1635        if zdim and az :
1636            if az != zdim : ttab = ttab.rename( {az:zdim}) 
1637    return ttab
1638
1639def V2T (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1640    '''Interpolate an array from V grid to T grid (j-mean)'''
1641    mmath = __mmath__ (vtab)
1642    lperio, aperio = lbc_diag (nperio)
1643    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1644    vtab_0 = lbc_add (vtab_0, nperio=nperio, cd_type='V', psgn=psgn)
1645    jy, ay = __findAxis__ (vtab_0, 'y')
1646    kz, az = __findAxis__ (vtab_0, 'z')
1647    if ay : 
1648        if action == 'ave'  : ttab = 0.5 *      (vtab_0 + np.roll (vtab_0, axis=jy, shift=1))
1649        if action == 'min'  : ttab = np.minimum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1650        if action == 'max'  : ttab = np.maximum (vtab_0 , np.roll (vtab_0, axis=jy, shift=1))
1651        if action == 'mult' : ttab =             vtab_0 * np.roll (vtab_0, axis=jy, shift=1)
1652        ttab = lbc_del (ttab  , nperio=nperio, cd_type='T', psgn=psgn)
1653    else :
1654        ttab = lbc_del (vtab_0, nperio=nperio, cd_type='T', psgn=psgn)
1655       
1656    if mmath == xr :
1657        if ay :
1658            ttab = ttab.assign_coords({ay:np.arange(ttab.shape[jy])+1.})
1659        if zdim and az :
1660            if az != zdim : ttab = ttab.rename( {az:zdim}) 
1661    return ttab
1662
1663def F2T (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1664    '''Interpolate an array from F grid to T grid (i- and j- means)'''
1665    mmath = __mmath__ (ftab)
1666    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1667    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1668    ttab = V2T (F2V (ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1669    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1670
1671def T2U (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1672    '''Interpolate an array from T grid to U grid (i-mean)'''
1673    mmath = __mmath__ (ttab)
1674    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1675    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1676    ix, ax = __findAxis__ (ttab_0, 'x')
1677    kz, az = __findAxis__ (ttab_0, 'z')
1678    if ix : 
1679        if action == 'ave'  : utab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=ix, shift=-1))
1680        if action == 'min'  : utab = np.minimum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1681        if action == 'max'  : utab = np.maximum (ttab_0 , np.roll (ttab_0, axis=ix, shift=-1))
1682        if action == 'mult' : utab =             ttab_0 * np.roll (ttab_0, axis=ix, shift=-1)
1683        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
1684    else :
1685        utab = lbc_del (ttab_0, nperio=nperio, cd_type='U', psgn=psgn)
1686       
1687    if mmath == xr :   
1688        if ax : 
1689            utab = ttab.assign_coords({ax:np.arange(utab.shape[ix])+1.})
1690        if zdim and az :
1691            if az != zdim : utab = utab.rename( {az:zdim}) 
1692    return utab
1693
1694def T2V (ttab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1695    '''Interpolate an array from T grid to V grid (j-mean)'''
1696    mmath = __mmath__ (ttab)
1697    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1698    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1699    jy, ay = __findAxis__ (ttab_0, 'y')
1700    kz, az = __findAxis__ (ttab_0, 'z')
1701    if jy : 
1702        if action == 'ave'  : vtab = 0.5 *      (ttab_0 + np.roll (ttab_0, axis=jy, shift=-1))
1703        if action == 'min'  : vtab = np.minimum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1704        if action == 'max'  : vtab = np.maximum (ttab_0 , np.roll (ttab_0, axis=jy, shift=-1))
1705        if action == 'mult' : vtab =             ttab_0 * np.roll (ttab_0, axis=jy, shift=-1)
1706        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
1707    else :
1708        vtab = lbc_del (ttab_0, nperio=nperio, cd_type='V', psgn=psgn)
1709
1710    if mmath == xr :
1711        if ay : 
1712            vtab = vtab.assign_coords({ay:np.arange(vtab.shape[jy])+1.})
1713        if zdim and az :
1714            if az != zdim : vtab = vtab.rename( {az:zdim}) 
1715    return vtab
1716
1717def V2F (vtab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1718    '''Interpolate an array from V grid to F grid (i-mean)'''
1719    mmath = __mmath__ (vtab)
1720    vtab_0 = mmath.where ( np.isnan(vtab), 0., vtab)
1721    vtab_0 = lbc_add (vtab_0 , nperio=nperio, cd_type='V', psgn=psgn)
1722    ix, ax = __findAxis__ (vtab_0, 'x')
1723    kz, az = __findAxis__ (vtab_0, 'z')
1724    if ix : 
1725        if action == 'ave'  : 0.5 *      (vtab_0 + np.roll (vtab_0, axis=ix, shift=-1))
1726        if action == 'min'  : np.minimum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
1727        if action == 'max'  : np.maximum (vtab_0 , np.roll (vtab_0, axis=ix, shift=-1))
1728        if action == 'mult' :             vtab_0 * np.roll (vtab_0, axis=ix, shift=-1)
1729        ftab = lbc_del (ftab  , nperio=nperio, cd_type='F', psgn=psgn)
1730    else :
1731         ftab = lbc_del (vtab_0, nperio=nperio, cd_type='F', psgn=psgn)
1732   
1733    if mmath == xr :
1734        if ax : 
1735            ftab = ftab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
1736        if zdim and az :
1737            if az != zdim : ftab = ftab.rename( {az:zdim}) 
1738    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1739
1740def U2F (utab, nperio=None, psgn=-1.0, zdim=None, action='ave') :
1741    '''Interpolate an array from U grid to F grid i-mean)'''
1742    mmath = __mmath__ (utab)
1743    utab_0 = mmath.where ( np.isnan(utab), 0., utab)
1744    utab_0 = lbc_add (utab_0 , nperio=nperio, cd_type='U', psgn=psgn)
1745    jy, ay = __findAxis__ (utab_0, 'y')
1746    kz, az = __findAxis__ (utab_0, 'z')
1747    if jy : 
1748        if action == 'ave'  :    ftab = 0.5 *      (utab_0 + np.roll (utab_0, axis=jy, shift=-1))
1749        if action == 'min'  :    ftab = np.minimum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
1750        if action == 'max'  :    ftab = np.maximum (utab_0 , np.roll (utab_0, axis=jy, shift=-1))
1751        if action == 'mult' :    ftab =             utab_0 * np.roll (utab_0, axis=jy, shift=-1)
1752        ftab = lbc_del (ftab  , nperio=nperio, cd_type='F', psgn=psgn)
1753    else :
1754        ftab = lbc_del (utab_0, nperio=nperio, cd_type='F', psgn=psgn)
1755 
1756    if mmath == xr :
1757        if ay : 
1758            ftab = ftab.assign_coords({'y':np.arange(ftab.shape[jy])+1.})
1759        if zdim and az :
1760            if az != zdim : ftab = ftab.rename( {az:zdim}) 
1761    return ftab
1762
1763def F2T (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1764    '''Interpolate an array on F grid to T grid (i- and j- means)'''
1765    mmath = __mmath__ (ftab)
1766    ftab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1767    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1768    ttab = U2T(F2U(ftab_0, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1769    return lbc_del (ttab, nperio=nperio, cd_type='T', psgn=psgn)
1770
1771def T2F (ttab, nperio=None, psgn=1.0, zdim=None, action='mean') :
1772    '''Interpolate an array on T grid to F grid (i- and j- means)'''
1773    mmath = __mmath__ (ttab)
1774    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1775    ttab_0 = lbc_add (ttab_0 , nperio=nperio, cd_type='T', psgn=psgn)
1776    ftab = T2U (U2F (ttab, nperio=nperio, psgn=psgn, zdim=zdim, action=action), nperio=nperio, psgn=psgn, zdim=zdim, action=action)
1777   
1778    return lbc_del (ftab, nperio=nperio, cd_type='F', psgn=psgn)
1779
1780def F2U (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1781    '''Interpolate an array on F grid to FUgrid (i-mean)'''
1782    mmath = __mmath__ (ftab)
1783    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1784    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1785    jy, ay = __findAxis__ (ftab_0, 'y')
1786    kz, az = __findAxis__ (ftab_0, 'z')
1787    if jy : 
1788        if action == 'ave'  : utab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=jy, shift=-1))
1789        if action == 'min'  : utab = np.minimum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
1790        if action == 'max'  : utab = np.maximum (ftab_0 , np.roll (ftab_0, axis=jy, shift=-1))
1791        if action == 'mult' : utab =             ftab_0 * np.roll (ftab_0, axis=jy, shift=-1)
1792        utab = lbc_del (utab  , nperio=nperio, cd_type='U', psgn=psgn)
1793    else :
1794        utab = lbc_del (ftab_0, nperio=nperio, cd_type='U', psgn=psgn)
1795
1796    if mmath == xr :
1797        utab = utab.assign_coords({ay:np.arange(ftab.shape[jy])+1.})
1798        if zdim and zz :
1799            if az != zdim : utab = utab.rename( {az:zdim}) 
1800    return utab
1801
1802def F2V (ftab, nperio=None, psgn=1.0, zdim=None, action='ave') :
1803    '''Interpolate an array from F grid to V grid (i-mean)'''
1804    mmath = __mmath__ (ftab)
1805    ftab_0 = mmath.where ( np.isnan(ftab), 0., ftab)
1806    ftab_0 = lbc_add (ftab_0 , nperio=nperio, cd_type='F', psgn=psgn)
1807    ix, ax = __findAxis__ (ftab_0, 'x')
1808    kz, az = __findAxis__ (ftab_0, 'z')
1809    if ix : 
1810        if action == 'ave'  : vtab = 0.5 *      (ftab_0 + np.roll (ftab_0, axis=ix, shift=-1))
1811        if action == 'min'  : vtab = np.minimum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
1812        if action == 'max'  : vtab = np.maximum (ftab_0 , np.roll (ftab_0, axis=ix, shift=-1))
1813        if action == 'mult' : vtab =             ftab_0 * np.roll (ftab_0, axis=ix, shift=-1)
1814        vtab = lbc_del (vtab  , nperio=nperio, cd_type='V', psgn=psgn)
1815    else : 
1816        vtab = lbc_del (ftab_0, nperio=nperio, cd_type='V', psgn=psgn)
1817
1818    if mmath == xr :
1819        vtab = vtab.assign_coords({ax:np.arange(ftab.shape[ix])+1.})
1820        if zdim and az :
1821            if az != zdim : vtab = vtab.rename( {az:zdim}) 
1822    return vtab
1823
1824def W2T (wtab, zcoord=None, zdim=None, sval=np.nan) :
1825    '''
1826    Interpolate an array on W grid to T grid (k-mean)
1827    sval is the bottom value
1828    '''
1829    mmath = __mmath__ (wtab)
1830    wtab_0 = mmath.where ( np.isnan(wtab), 0., wtab)
1831
1832    kz, az = __findAxis__ (wtab_0, 'z')
1833
1834    if kz : 
1835        ttab = 0.5 * ( wtab_0 + np.roll (wtab_0, axis=kz, shift=-1) )
1836    else :
1837        ttab = wtab_0
1838
1839    if mmath == xr :
1840        ttab[{az:kz}] = sval
1841        if zdim and az :
1842            if az != zdim : ttab = ttab.rename ( {az:zdim} )
1843        if zcoord is not None :
1844            ttab = ttab.assign_coords ( {zdim:zcoord} )
1845    else :
1846        ttab[..., -1, :, :] = sval
1847
1848    return ttab
1849
1850def T2W (ttab, zcoord=None, zdim=None, sval=np.nan, extrap_surf=False) :
1851    '''Interpolate an array from T grid to W grid (k-mean)
1852    sval is the surface value
1853    if extrap_surf==True, surface value is taken from 1st level value.
1854    '''
1855    mmath = __mmath__ (ttab)
1856    ttab_0 = mmath.where ( np.isnan(ttab), 0., ttab)
1857    kz, az = __findAxis__ (ttab_0, 'z')
1858    wtab = 0.5 * ( ttab_0 + np.roll (ttab_0, axis=kz, shift=1) )
1859
1860    if mmath == xr :
1861        if extrap_surf : wtab[{az:0}] = ttabb[{az:0}]
1862        else           : wtab[{az:0}] = sval
1863    else : 
1864        if extrap_surf : wtab[..., 0, :, :] = ttab[..., 0, :, :]
1865        else           : wtab[..., 0, :, :] = sval
1866
1867    if mmath == xr :
1868        if zdim and az :
1869            if az != zdim : wtab = wtab.rename ( {az:zdim})
1870        if zcoord is not None :
1871            wtab = wtab.assign_coords ( {zdim:zcoord})
1872        else :
1873            ztab = wtab.assign_coords ( {zdim:np.arange(ttab.shape[kz])+1.} )
1874    return wtab
1875
1876def fill (ptab, nperio, cd_type='T', npass=1, sval=0.) :
1877    '''
1878    Fill sval values with mean of neighbours
1879   
1880    Inputs :
1881       ptab : input field to fill
1882       nperio, cd_type : periodicity characteristics
1883    '''       
1884
1885    mmath = __mmath__ (ptab)
1886
1887    DoPerio = False ; lperio = nperio
1888    if nperio == 4.2 :
1889        DoPerio = True ; lperio = 4
1890    if nperio == 6.2 :
1891        DoPerio = True ; lperio = 6
1892       
1893    if DoPerio :
1894        ztab = lbc_add (ptab, nperio=nperio, sval=sval)
1895    else : 
1896        ztab = ptab
1897       
1898    if np.isnan (sval) : 
1899        ztab   = mmath.where (np.isnan(ztab), np.nan, ztab)
1900    else :
1901        ztab   = mmath.where (ztab==sval    , np.nan, ztab)
1902   
1903    for nn in np.arange (npass) : 
1904        zmask = mmath.where ( np.isnan(ztab), 0., 1.   )
1905        ztab0 = mmath.where ( np.isnan(ztab), 0., ztab )
1906        # Compte du nombre de voisins
1907        zcount = 1./6. * ( zmask \
1908          + np.roll(zmask, shift=1, axis=-1) + np.roll(zmask, shift=-1, axis=-1) \
1909          + np.roll(zmask, shift=1, axis=-2) + np.roll(zmask, shift=-1, axis=-2) \
1910          + 0.5 * ( \
1911                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift= 1, axis=-1) \
1912                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift= 1, axis=-1) \
1913                + np.roll(np.roll(zmask, shift= 1, axis=-2), shift=-1, axis=-1) \
1914                + np.roll(np.roll(zmask, shift=-1, axis=-2), shift=-1, axis=-1) ) )
1915
1916        znew =1./6. * ( ztab0 \
1917           + np.roll(ztab0, shift=1, axis=-1) + np.roll(ztab0, shift=-1, axis=-1) \
1918           + np.roll(ztab0, shift=1, axis=-2) + np.roll(ztab0, shift=-1, axis=-2) \
1919           + 0.5 * ( \
1920                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift= 1, axis=-1) \
1921                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift= 1, axis=-1) \
1922                + np.roll(np.roll(ztab0 , shift= 1, axis=-2), shift=-1, axis=-1) \
1923                + np.roll(np.roll(ztab0 , shift=-1, axis=-2), shift=-1, axis=-1) ) )
1924
1925        zcount = lbc (zcount, nperio=lperio, cd_type=cd_type)
1926        znew   = lbc (znew  , nperio=lperio, cd_type=cd_type)
1927       
1928        ztab = mmath.where (np.logical_and (zmask==0., zcount>0), znew/zcount, ztab)
1929
1930    ztab = mmath.where (zcount==0, sval, ztab)
1931    if DoPerio : ztab = lbc_del (ztab, nperio=lperio)
1932
1933    return ztab
1934
1935def correct_uv (u, v, lat) :
1936    '''
1937    Correct a Cartopy bug in orthographic projection
1938
1939    See https://github.com/SciTools/cartopy/issues/1179
1940
1941    The correction is needed with cartopy <= 0.20
1942    It seems that version 0.21 will correct the bug (https://github.com/SciTools/cartopy/pull/1926)
1943    Later note : the bug is still present in Cartopy 0.22
1944
1945    Inputs :
1946       u, v : eastward/northward components
1947       lat  : latitude of the point (degrees north)
1948
1949    Outputs :
1950       modified eastward/nothward components to have correct polar projections in cartopy
1951    '''
1952    uv = np.sqrt (u*u + v*v)           # Original modulus
1953    zu = u
1954    zv = v * np.cos (rad*lat)
1955    zz = np.sqrt ( zu*zu + zv*zv )     # Corrected modulus
1956    uc = zu*uv/zz ; vc = zv*uv/zz      # Final corrected values
1957    return uc, vc
1958
1959def norm_uv (u, v) :
1960    '''
1961    Return norm of a 2 components vector
1962    '''
1963    return np.sqrt (u*u + v*v)
1964
1965def normalize_uv (u, v) :
1966    '''
1967    Normalize 2 components vector
1968    '''
1969    uv = norm_uv (u, v)
1970    return u/uv, v/uv
1971
1972def msf (vv, e1v_e3v, lat1d, depthw) :
1973    '''
1974    Computes the meridonal stream function
1975    First input is meridional_velocity*e1v*e3v
1976    '''
1977 
1978    v_e1v_e3v = vv * e1v_e3v
1979    v_e1v_e3v.attrs = vv.attrs
1980   
1981    ix, ax = __findAxis__ (v_e1v_e3v, 'x')
1982    kz, az = __findAxis__ (v_e1v_e3v, 'z')
1983    if az == 'olevel' : new_az = 'olevel'
1984    else              : new_az = 'depthw'
1985
1986    zomsf = -v_e1v_e3v.cumsum ( dim=az, keep_attrs=True).sum (dim=ax, keep_attrs=True)*1.E-6
1987    zomsf = zomsf - zomsf.isel ( { az:-1} )
1988   
1989    jy, ay = __findAxis__ (zomsf, 'y' )
1990    zomsf = zomsf.assign_coords ( {az:depthw.values, ay:lat1d.values})
1991   
1992    zomsf = zomsf.rename ( {ay:'lat'})
1993    if az != new_az : zomsf = zomsf.rename ( {az:new_az} )
1994    zomsf.attrs['standard_name'] = 'Meridional stream function'
1995    zomsf.attrs['long_name'] = 'Meridional stream function'
1996    zomsf.attrs['units'] = 'Sv'
1997    zomsf[new_az].attrs  = depthw.attrs
1998    zomsf.lat.attrs=lat1d.attrs
1999       
2000    return zomsf
2001
2002def bsf (uu, e2u_e3u, mask, nperio=None, bsf0=None ) :
2003    '''
2004    Computes the barotropic stream function
2005    First input is zonal_velocity*e2u*e3u
2006    bsf0 is the point with bsf=0
2007    (ex: bsf0={'x':3, 'y':120} for orca2,
2008         bsf0={'x':5, 'y':300} for eeORCA1
2009    '''
2010    u_e2u_e3u       = uu * e2u_e3u
2011    u_e2u_e3u.attrs = uu.attrs
2012
2013    iy, ay = __findAxis__ (u_e2u_e3u, 'y')
2014    kz, az = __findAxis__ (u_e2u_e3u, 'z')
2015   
2016    bsf = -u_e2u_e3u.cumsum ( dim=ay, keep_attrs=True )
2017    bsf = bsf.sum (dim=az, keep_attrs=True)*1.E-6
2018       
2019    if bsf0 :
2020        bsf = bsf - bsf.isel (bsf0)
2021       
2022    bsf = bsf.where (mask !=0, np.nan)
2023    for attr in uu.attrs :
2024        bsf.attrs[attr] = uu.attrs[attr]
2025    bsf.attrs['standard_name'] = 'barotropic_stream_function'
2026    bsf.attrs['long_name']     = 'Barotropic stream function'
2027    bsf.attrs['units']         = 'Sv'
2028    bsf = lbc (bsf, nperio=nperio, cd_type='F')
2029       
2030    return bsf
2031
2032def namelist_read (ref=None, cfg=None, out='dict', flat=False, verbose=False) :
2033    '''
2034    Read NEMO namelist(s) and return either a dictionnary or an xarray dataset
2035
2036    ref : file with reference namelist, or a f90nml.namelist.Namelist object
2037    cfg : file with config namelist, or a f90nml.namelist.Namelist object
2038    At least one namelist neaded
2039
2040    out:
2041        'dict' to return a dictonnary
2042        'xr'   to return an xarray dataset
2043    flat : only for dict output. Output a flat dictionnary with all values.
2044   
2045    '''
2046   
2047    import f90nml
2048    if ref :
2049        if isinstance (ref, str) : nml_ref = f90nml.read (ref)
2050        if isinstance (ref, f90nml.namelist.Namelist) : nml_ref = ref
2051       
2052    if cfg :
2053        if isinstance (cfg, str) : nml_cfg = f90nml.read (cfg)
2054        if isinstance (cfg, f90nml.namelist.Namelist) : nml_cfg = cfg
2055   
2056    if out == 'dict' : dict_namelist = {}
2057    if out == 'xr'   : xr_namelist = xr.Dataset ()
2058
2059    list_nml = [] ; list_comment = []
2060
2061    if ref : list_nml.append (nml_ref) ; list_comment.append ('ref')
2062    if cfg : list_nml.append (nml_cfg) ; list_comment.append ('cfg')
2063
2064    for nml, comment in zip (list_nml, list_comment) :
2065        if verbose : print (comment)
2066        if flat and out =='dict' :
2067            for nam in nml.keys () :
2068                if verbose : print (nam)
2069                for value in nml[nam] :
2070                     if out == 'dict' : dict_namelist[value] = nml[nam][value]
2071                     if verbose : print (nam, ':', value, ':', nml[nam][value])
2072        else :
2073            for nam in nml.keys () :
2074                if verbose : print (nam)
2075                if out == 'dict' :
2076                    if nam not in dict_namelist.keys () : dict_namelist[nam] = {}
2077                for value in nml[nam] :
2078                    if out == 'dict' : dict_namelist[nam][value] = nml[nam][value]
2079                    if out == 'xr'   : xr_namelist[value] = nml[nam][value]
2080                    if verbose : print (nam, ':', value, ':', nml[nam][value])
2081
2082    if out == 'dict' : return dict_namelist
2083    if out == 'xr'   : return xr_namelist
2084
2085def fill_closed_seas (imask, nperio=None,  cd_type='T') :
2086    '''Fill closed seas with image processing library
2087    imask : mask, 1 on ocean, 0 on land
2088    '''
2089    from scipy import ndimage
2090
2091    imask_filled = ndimage.binary_fill_holes ( lbc (imask, nperio=nperio, cd_type=cd_type))
2092    imask_filled = lbc ( imask_filled, nperio=nperio, cd_type=cd_type)
2093
2094    return imask_filled
2095
2096'''
2097Sea water state function parameters from NEMO code
2098'''
2099rdeltaS = 32. ; r1_S0  = 0.875/35.16504 ; r1_T0  = 1./40. ; r1_Z0  = 1.e-4
2100
2101EOS000 =  8.0189615746e+02 ; EOS100 =  8.6672408165e+02 ; EOS200 = -1.7864682637e+03 ; EOS300 =  2.0375295546e+03 ; EOS400 = -1.2849161071e+03 ; EOS500 =  4.3227585684e+02 ; EOS600 = -6.0579916612e+01
2102EOS010 =  2.6010145068e+01 ; EOS110 = -6.5281885265e+01 ; EOS210 =  8.1770425108e+01 ; EOS310 = -5.6888046321e+01 ; EOS410 =  1.7681814114e+01 ; EOS510 = -1.9193502195
2103EOS020 = -3.7074170417e+01 ; EOS120 =  6.1548258127e+01 ; EOS220 = -6.0362551501e+01 ; EOS320 =  2.9130021253e+01 ; EOS420 = -5.4723692739     ; EOS030 =  2.1661789529e+01 
2104EOS130 = -3.3449108469e+01 ; EOS230 =  1.9717078466e+01 ; EOS330 = -3.1742946532
2105EOS040 = -8.3627885467     ; EOS140 =  1.1311538584e+01 ; EOS240 = -5.3563304045
2106EOS050 =  5.4048723791e-01 ; EOS150 =  4.8169980163e-01
2107EOS060 = -1.9083568888e-01
2108EOS001 =  1.9681925209e+01 ; EOS101 = -4.2549998214e+01 ; EOS201 =  5.0774768218e+01 ; EOS301 = -3.0938076334e+01 ; EOS401 =   6.6051753097    ; EOS011 = -1.3336301113e+01
2109EOS111 = -4.4870114575     ; EOS211 =  5.0042598061     ; EOS311 = -6.5399043664e-01 ; EOS021 =  6.7080479603     ; EOS121 =   3.5063081279
2110EOS221 = -1.8795372996     ; EOS031 = -2.4649669534     ; EOS131 = -5.5077101279e-01 ; EOS041 =  5.5927935970e-01
2111EOS002 =  2.0660924175     ; EOS102 = -4.9527603989     ; EOS202 =  2.5019633244     ; EOS012 =  2.0564311499     ; EOS112 = -2.1311365518e-01 ; EOS022 = -1.2419983026
2112EOS003 = -2.3342758797e-02 ; EOS103 = -1.8507636718e-02 ; EOS013 =  3.7969820455e-01 
2113
2114def rhop ( ptemp, psal ) :
2115    '''
2116    Potential density referenced to surface
2117    Computation from NEMO code
2118    '''
2119    zt  = ptemp * r1_T0                                  # Temperature (°C)
2120    zs  = np.sqrt ( np.abs( psal + rdeltaS ) * r1_S0 )   # Square root of salinity (PSS)
2121    #
2122    prhop = (((((EOS060*zt   \
2123             + EOS150*zs     + EOS050)*zt   \
2124             + (EOS240*zs    + EOS140)*zs + EOS040)*zt   \
2125             + ((EOS330*zs   + EOS230)*zs + EOS130)*zs + EOS030)*zt   \
2126             + (((EOS420*zs  + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt   \
2127             + ((((EOS510*zs + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt   \
2128             + (((((EOS600*zs+ EOS500)*zs + EOS400)*zs + EOS300)*zs + EOS200)*zs + EOS100)*zs + EOS000
2129    #
2130    return prhop
2131
2132def rho ( pdep, ptemp, psal ) :
2133    '''
2134    In situ density
2135    Computation from NEMO code
2136    '''
2137    zh  = pdep  * r1_Z0                                  # Depth (m)
2138    zt  = ptemp * r1_T0                                  # Temperature (°C)
2139    zs  = np.sqrt ( np.abs( psal + rdeltaS ) * r1_S0 )   # Square root salinity (PSS)
2140    #
2141    zn3 = EOS013*zt + EOS103*zs+EOS003
2142    #
2143    zn2 = (EOS022*zt + EOS112*zs+EOS012)*zt + (EOS202*zs+EOS102)*zs+EOS002
2144    #
2145    zn1 = (((EOS041*zt   \
2146         + EOS131*zs   + EOS031)*zt   \
2147         + (EOS221*zs  + EOS121)*zs + EOS021)*zt   \
2148        + ((EOS311*zs  + EOS211)*zs + EOS111)*zs + EOS011)*zt   \
2149        + (((EOS401*zs + EOS301)*zs + EOS201)*zs + EOS101)*zs + EOS001
2150    #
2151    zn0 = (((((EOS060*zt   \
2152             + EOS150*zs      + EOS050)*zt   \
2153             + (EOS240*zs     + EOS140)*zs + EOS040)*zt   \
2154             + ((EOS330*zs    + EOS230)*zs + EOS130)*zs + EOS030)*zt   \
2155             + (((EOS420*zs   + EOS320)*zs + EOS220)*zs + EOS120)*zs + EOS020)*zt   \
2156             + ((((EOS510*zs  + EOS410)*zs + EOS310)*zs + EOS210)*zs + EOS110)*zs + EOS010)*zt   \
2157             + (((((EOS600*zs + EOS500)*zs + EOS400)*zs + EOS300)*zs + EOS200)*zs + EOS100)*zs + EOS000
2158    #
2159    prho  = ( ( zn3 * zh + zn2 ) * zh + zn1 ) * zh + zn0
2160    #
2161    return prho
2162
2163## ===========================================================================
2164##
2165##                               That's all folk's !!!
2166##
2167## ===========================================================================
2168
2169def __is_orca_north_fold__ ( Xtest, cname_long='T' ) :
2170    '''
2171    Ported (pirated !!?) from Sosie
2172
2173    Tell if there is a 2/point band overlaping folding at the north pole typical of the ORCA grid
2174
2175    0 => not an orca grid (or unknown one)
2176    4 => North fold T-point pivot (ex: ORCA2)
2177    6 => North fold F-point pivot (ex: ORCA1)
2178
2179    We need all this 'cname_long' stuff because with our method, there is a
2180    confusion between "Grid_U with T-fold" and "Grid_V with F-fold"
2181    => so knowing the name of the longitude array (as in namelist, and hence as
2182    in netcdf file) might help taking the righ decision !!! UGLY!!!
2183    => not implemented yet
2184    '''
2185   
2186    ifld_nord =  0 ; cgrd_type = 'X'
2187    ny, nx = Xtest.shape[-2:]
2188
2189    if ny > 3 : # (case if called with a 1D array, ignoring...)
2190        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2191          ifld_nord = 4 ; cgrd_type = 'T' # T-pivot, grid_T     
2192
2193        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2194            if cnlon == 'U' : ifld_nord = 4 ;  cgrd_type = 'U' # T-pivot, grid_T
2195                ## LOLO: PROBLEM == 6, V !!!
2196
2197        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-1:nx-nx//2+1:-1] ).sum() == 0. :
2198            ifld_nord = 4 ; cgrd_type = 'V' # T-pivot, grid_V
2199
2200        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-2, nx-1-1:nx-nx//2:-1] ).sum() == 0. :
2201            ifld_nord = 6 ; cgrd_type = 'T'# F-pivot, grid_T
2202
2203        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-1, nx-1:nx-nx//2-1:-1] ).sum() == 0. :
2204            ifld_nord = 6 ;  cgrd_type = 'U' # F-pivot, grid_U
2205
2206        if ( Xtest [ny-1, 1:nx//2-1] - Xtest [ny-3, nx-2:nx-nx//2  :-1] ).sum() == 0. :
2207            if cnlon == 'V' : ifld_nord = 6 ; cgrd_type = 'V' # F-pivot, grid_V
2208                ## LOLO: PROBLEM == 4, U !!!
2209
2210    return ifld_nord, cgrd_type
Note: See TracBrowser for help on using the repository browser.