source: XIOS/dev/dev_olga/src/extern/remap/py/remap_ECDYN.py @ 1022

Last change on this file since 1022 was 1022, checked in by mhnguyen, 7 years ago
File size: 12.7 KB
Line 
1import netCDF4 as nc
2import ctypes as ct
3import numpy as np
4import os
5import sys
6import math
7from mpi4py import MPI
8
9import time
10
11remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
12
13def from_mpas(filename):
14        # construct vortex bounds from Mpas grid structure
15        f = nc.Dataset(filename)
16        # in this case it is must faster to first read the whole file into memory
17        # before converting the data structure
18        print "read"
19        stime = time.time()
20        lon_vert = np.array(f.variables["lonVertex"])
21        lat_vert = np.array(f.variables["latVertex"])
22        vert_cell = np.array(f.variables["verticesOnCell"])
23        nvert_cell = np.array(f.variables["nEdgesOnCell"])
24        ncell, nvert = vert_cell.shape
25        assert(max(nvert_cell) <= nvert)
26        lon = np.zeros(vert_cell.shape)
27        lat = np.zeros(vert_cell.shape)
28        etime = time.time()
29        print "finished read, now convert", etime-stime
30        scal = 180.0/math.pi
31        for c in range(ncell):
32                lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
33                lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
34                # signal "last vertex" by netCDF convetion
35                lon[c,nvert_cell[c]] = lon[c,0]
36                lat[c,nvert_cell[c]] = lat[c,0]
37        print "convert end", time.time() - etime
38        return lon, lat
39
40grid_types = {
41        "dynamico:mesh": {
42                "lon_name": "bounds_lon_i",
43                "lat_name": "bounds_lat_i",
44                "pole": [0,0,0]
45        },
46        "dynamico:vort": {
47                "lon_name": "bounds_lon_v",
48                "lat_name": "bounds_lat_v",
49                "pole": [0,0,0]
50        },
51        "dynamico:restart": {
52                "lon_name": "lon_i_vertices",
53                "lat_name": "lat_i_vertices",
54                "pole": [0,0,0]
55        },
56        "test:polygon": {
57                "lon_name": "bounds_lon",
58                "lat_name": "bounds_lat",
59                "pole": [0,0,0]
60        },
61        "test:latlon": {
62                "lon_name": "bounds_lon",
63                "lat_name": "bounds_lat",
64                "pole": [0,0,1]
65        },
66        "mpas": {
67                "reader": from_mpas,
68                "pole": [0,0,0]
69        }
70}
71
72interp_types = {
73        "FV1": 1,
74        "FV2": 2
75}
76
77usage = """
78Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
79
80   interp: type of interpolation
81       choices:
82           FV1: first order conservative Finite Volume
83           FV2: second order conservative Finite Volume
84
85   srctype, dsttype: grid type of source and destination
86       choices: """ + " ".join(grid_types.keys()) + """
87
88   srcfile, dstfile: grid file names, should mostly be netCDF file
89
90   mode: modus of operation
91       choices:
92           weights: computes weight and stores them in outfile
93           remap:   computes the interpolated values on destination grid and stores them in outfile
94
95   outfile: output filename
96
97"""
98
99# parse command line arguments
100if not len(sys.argv) == 8:
101        print usage
102        sys.exit(2)
103
104interp = sys.argv[1]
105try:
106        srctype = grid_types[sys.argv[2]]
107except KeyError:
108        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
109        exit(2)
110srcfile = sys.argv[3]
111try:
112        dsttype = grid_types[sys.argv[4]]
113except KeyError:
114        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
115        exit(2)
116dstfile = sys.argv[5]
117mode    = sys.argv[6]
118outfile = sys.argv[7]
119
120if not mode in ("weights", "remap"):
121        print "Error: mode must be of of the following: weights remap."
122        exit(2)
123
124remap.mpi_init()
125rank = remap.mpi_rank()
126size = remap.mpi_size()
127
128print rank, "/", size
129
130print "Reading grids from netCDF files."
131
132if "reader" in srctype:
133        src_lon, src_lat = srctype["reader"](srcfile)
134else:
135        src = nc.Dataset(srcfile)
136        # the following two lines do not perform the actual read
137        # the file is read later when assigning to the ctypes array
138        # -> no unnecessary array copying in memory
139        src_lon = src.variables[srctype["lon_name"]]
140        src_lat = src.variables[srctype["lat_name"]]
141
142if "reader" in dsttype:
143        dst_lon, dst_lat = dsttype["reader"](dstfile)
144else:
145        dst = nc.Dataset(dstfile)
146        dst_lon = dst.variables[dsttype["lon_name"]]
147        dst_lat = dst.variables[dsttype["lat_name"]]
148
149src_ncell, src_nvert = src_lon.shape
150dst_ncell, dst_nvert = dst_lon.shape
151
152def compute_distribution(ncell):
153        "Returns the local number and starting position in global array."
154        if rank < ncell % size:
155                return ncell//size + 1, \
156                       (ncell//size + 1)*rank
157        else:
158                return ncell//size, \
159                       (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
160
161src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
162dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
163
164print "src", src_ncell_loc, src_loc_start
165print "dst", dst_ncell_loc, dst_loc_start
166
167c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
168c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
169c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
170c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
171
172c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
173c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
174c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
175c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
176
177
178print "Calling remap library to compute weights."
179srcpole = (ct.c_double * (3))()
180dstpole = (ct.c_double * (3))()
181srcpole[:] = srctype["pole"]
182dstpole[:] = dsttype["pole"]
183
184c_src_ncell = ct.c_int(src_ncell_loc)
185c_src_nvert = ct.c_int(src_nvert)
186c_dst_ncell = ct.c_int(dst_ncell_loc)
187c_dst_nvert = ct.c_int(dst_nvert)
188order = ct.c_int(interp_types[interp])
189
190c_nweight = ct.c_int()
191
192print "src:", src_ncell, src_nvert
193print "dst:", dst_ncell, dst_nvert
194
195remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
196               c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
197               order, ct.byref(c_nweight))
198
199nwgt = c_nweight.value
200
201c_weights = (ct.c_double * nwgt)()
202c_dst_idx = (ct.c_int * nwgt)()
203c_src_idx = (ct.c_int * nwgt)()
204
205remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
206
207wgt_glo     = MPI.COMM_WORLD.gather(c_weights[:])
208src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
209dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
210
211
212if rank == 0 and mode == 'weights':
213        nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
214
215        print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
216        f = nc.Dataset(outfile,'w')
217        f.createDimension('n_src',    src_ncell)
218        f.createDimension('n_dst',    dst_ncell)
219        f.createDimension('n_weight', nwgt_glo)
220
221        var = f.createVariable('src_idx', 'i', ('n_weight'))
222        var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
223        var = f.createVariable('dst_idx', 'i', ('n_weight'))
224        var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
225        var = f.createVariable('weight',  'd', ('n_weight'))
226        var[:] = np.hstack(wgt_glo)
227        f.close()
228
229def test_fun(x, y, z):
230        return (1-x**2)*(1-y**2)*z
231
232def test_fun_ll(lat, lon):
233        #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
234        return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
235
236#UNUSED
237#def sphe2cart(lat, lon):
238#       phi   = math.pi/180*lon[:]
239#       theta = math.pi/2 - math.pi/180*lat[:]
240#       return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
241
242if mode == 'remap':
243        c_centre_lon = (ct.c_double * src_ncell_loc)()
244        c_centre_lat = (ct.c_double * src_ncell_loc)()
245        c_areas      = (ct.c_double * src_ncell_loc)()
246        remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
247                c_centre_lon, c_centre_lat, c_areas)
248#       src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
249#       src_val_loc = src.variables["ps"]
250#       src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
251#        src_val_glo = src_val_loc
252
253        c_centre_lon = (ct.c_double * dst_ncell_loc)()
254        c_centre_lat = (ct.c_double * dst_ncell_loc)()
255        c_areas      = (ct.c_double * dst_ncell_loc)()
256        remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
257                c_centre_lon, c_centre_lat, c_areas)
258#       dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
259
260#       dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
261        dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
262        dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
263        dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
264
265
266if rank == 0 and mode == 'remap':
267        from scipy import sparse
268        A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
269
270#       src_val = np.hstack(src_val_glo)
271#       dst_ref = np.hstack(dst_val_glo)
272        dst_areas = np.hstack(dst_areas_glo)
273        dst_centre_lon = np.hstack(dst_centre_lon_glo)
274        dst_centre_lat = np.hstack(dst_centre_lat_glo)
275
276#       print "source:", src_val.shape
277#       print "destin:", dst_ref.shape
278#       dst_val = A*src_val
279#       err = dst_val - dst_ref
280#       print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
281#       print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
282#       print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
283
284        lev=src.dimensions['lev']
285        f = nc.Dataset(outfile,'w')
286        f.createDimension('nvert', dst_nvert)
287        f.createDimension('cell', dst_ncell)
288        f.createDimension('lev', len(lev))
289
290        var = f.createVariable('lat', 'd', ('cell'))
291        var.setncattr("long_name", "latitude")
292        var.setncattr("units", "degrees_north")
293        var.setncattr("bounds", "bounds_lat")
294        var[:] = dst_centre_lat
295        var = f.createVariable('lon', 'd', ('cell'))
296        var.setncattr("long_name", "longitude")
297        var.setncattr("units", "degrees_east")
298        var.setncattr("bounds", "bounds_lon")
299        var[:] = dst_centre_lon
300
301        var = f.createVariable('bounds_lon', 'd', ('cell','nvert'))
302        var[:] = dst_lon
303        var = f.createVariable('bounds_lat', 'd', ('cell','nvert'))
304        var[:] = dst_lat
305
306        var = f.createVariable('lev', 'd', ('lev'))
307        var[:] = src.variables['lev']
308        var.setncattr('axis', 'Z')
309        var.setncattr('units', 'Pa')
310        var.setncattr('positive', 'down')
311        var[:] = src.variables['lev']
312
313        U = f.createVariable('U', 'd', ('lev','cell'))
314        U.setncattr("coordinates", "lev lon lat")
315 
316        V = f.createVariable('V', 'd', ('lev','cell'))
317        V.setncattr("coordinates", "lev lon lat")
318
319        TEMP = f.createVariable('TEMP', 'd', ('lev','cell'))
320        TEMP.setncattr("coordinates", "lev lon lat")
321
322        R = f.createVariable('R', 'd', ('lev','cell'))
323        R.setncattr("coordinates", "lev lon lat")
324
325        Z = f.createVariable('Z', 'd', ('cell'))
326        Z.setncattr("coordinates", "lon lat")
327
328        ST = f.createVariable('ST', 'd', ('cell'))
329        ST.setncattr("coordinates", "lon lat")
330
331        CDSW = f.createVariable('CDSW', 'd', ('cell'))
332        CDSW.setncattr("coordinates", "lon lat")
333       
334        SP = f.createVariable('SP', 'd', ('cell'))
335        SP.setncattr("coordinates", "lon lat")
336       
337
338
339#for U
340if mode == 'remap':
341        src_val_loc = src.variables['U']
342
343for l in range(0, len(lev)):
344        if mode == 'remap':
345                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
346
347        if rank == 0 and mode == 'remap':
348                src_val = np.hstack(src_val_glo)
349                dst_val = A*src_val
350                U[l,:] = dst_val
351
352#for V
353if mode == 'remap':
354        src_val_loc = src.variables['V']
355
356for l in range(0, len(lev)):
357        if mode == 'remap':
358                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
359
360        if rank == 0 and mode == 'remap':
361                src_val = np.hstack(src_val_glo)
362                dst_val = A*src_val
363                V[l,:] = dst_val
364
365
366#for TEMP
367if mode == 'remap':
368        src_val_loc = src.variables['TEMP']
369
370for l in range(0, len(lev)):
371        if mode == 'remap':
372                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
373
374        if rank == 0 and mode == 'remap':
375                src_val = np.hstack(src_val_glo)
376                dst_val = A*src_val
377                TEMP[l,:] = dst_val
378
379#for R
380if mode == 'remap':
381        src_val_loc = src.variables['R']
382
383for l in range(0, len(lev)):
384        if mode == 'remap':
385                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
386
387        if rank == 0 and mode == 'remap':
388                src_val = np.hstack(src_val_glo)
389                dst_val = A*src_val
390                R[l,:] = dst_val
391
392
393#for Z
394if mode == 'remap':
395        src_val_loc = src.variables['Z']
396        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
397
398if rank == 0 and mode == 'remap':
399        src_val = np.hstack(src_val_glo)
400        dst_val = A*src_val
401        Z[:] = dst_val
402
403#for ST
404if mode == 'remap':
405        src_val_loc = src.variables['ST']
406        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
407
408if rank == 0 and mode == 'remap':
409        src_val = np.hstack(src_val_glo)
410        dst_val = A*src_val
411        ST[:] = dst_val
412
413
414#for CDSW
415if mode == 'remap':
416        src_val_loc = src.variables['CDSW']
417        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
418
419if rank == 0 and mode == 'remap':
420        src_val = np.hstack(src_val_glo)
421        dst_val = A*src_val
422        CDSW[:] = dst_val
423
424#for SP
425if mode == 'remap':
426        src_val_loc = src.variables['SP']
427        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
428
429if rank == 0 and mode == 'remap':
430        src_val = np.hstack(src_val_glo)
431        dst_val = A*src_val
432        SP[:] = dst_val
433
434
435
436if mode == 'remap':
437        f.close()
438       
439
440if not "reader" in srctype:
441        src.close()
442if not "reader" in dsttype:
443        dst.close()
444
Note: See TracBrowser for help on using the repository browser.