source: XIOS/dev/dev_olga/src/extern/remap/py/remap.py @ 1022

Last change on this file since 1022 was 1022, checked in by mhnguyen, 7 years ago
File size: 12.3 KB
Line 
1import netCDF4 as nc
2import ctypes as ct
3import numpy as np
4import os
5import sys
6import math
7from mpi4py import MPI
8
9import time
10
11remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
12
13def from_mpas(filename):
14        # construct vortex bounds from Mpas grid structure
15        f = nc.Dataset(filename)
16        # in this case it is must faster to first read the whole file into memory
17        # before converting the data structure
18        print "read"
19        stime = time.time()
20        lon_vert = np.array(f.variables["lonVertex"])
21        lat_vert = np.array(f.variables["latVertex"])
22        vert_cell = np.array(f.variables["verticesOnCell"])
23        nvert_cell = np.array(f.variables["nEdgesOnCell"])
24        ncell, nvert = vert_cell.shape
25        assert(max(nvert_cell) <= nvert)
26        lon = np.zeros(vert_cell.shape)
27        lat = np.zeros(vert_cell.shape)
28        etime = time.time()
29        print "finished read, now convert", etime-stime
30        scal = 180.0/math.pi
31        for c in range(ncell):
32                lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
33                lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
34                # signal "last vertex" by netCDF convetion
35                lon[c,nvert_cell[c]] = lon[c,0]
36                lat[c,nvert_cell[c]] = lat[c,0]
37        print "convert end", time.time() - etime
38        return lon, lat
39
40grid_types = {
41        "dynamico:mesh": {
42                "lon_name": "bounds_lon_i",
43                "lat_name": "bounds_lat_i",
44                "pole": [0,0,0]
45        },
46        "dynamico:vort": {
47                "lon_name": "bounds_lon_v",
48                "lat_name": "bounds_lat_v",
49                "pole": [0,0,0]
50        },
51        "dynamico:restart": {
52                "lon_name": "lon_i_vertices",
53                "lat_name": "lat_i_vertices",
54                "pole": [0,0,0]
55        },
56        "test:polygon": {
57                "lon_name": "bounds_lon",
58                "lat_name": "bounds_lat",
59                "pole": [0,0,0]
60        },
61        "test:latlon": {
62                "lon_name": "bounds_lon",
63                "lat_name": "bounds_lat",
64                "pole": [0,0,1]
65        },
66        "mpas": {
67                "reader": from_mpas,
68                "pole": [0,0,0]
69        }
70}
71
72interp_types = {
73        "FV1": 1,
74        "FV2": 2
75}
76
77usage = """
78Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
79
80   interp: type of interpolation
81       choices:
82           FV1: first order conservative Finite Volume
83           FV2: second order conservative Finite Volume
84
85   srctype, dsttype: grid type of source and destination
86       choices: """ + " ".join(grid_types.keys()) + """
87
88   srcfile, dstfile: grid file names, should mostly be netCDF file
89
90   mode: modus of operation
91       choices:
92           weights: computes weight and stores them in outfile
93           remap:   computes the interpolated values on destination grid and stores them in outfile
94
95   outfile: output filename
96
97"""
98
99# parse command line arguments
100if not len(sys.argv) == 8:
101        print usage
102        sys.exit(2)
103
104interp = sys.argv[1]
105try:
106        srctype = grid_types[sys.argv[2]]
107except KeyError:
108        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
109        exit(2)
110srcfile = sys.argv[3]
111try:
112        dsttype = grid_types[sys.argv[4]]
113except KeyError:
114        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
115        exit(2)
116dstfile = sys.argv[5]
117mode    = sys.argv[6]
118outfile = sys.argv[7]
119
120if not mode in ("weights", "remap"):
121        print "Error: mode must be of of the following: weights remap."
122        exit(2)
123
124remap.mpi_init()
125rank = remap.mpi_rank()
126size = remap.mpi_size()
127
128print rank, "/", size
129
130print "Reading grids from netCDF files."
131
132if "reader" in srctype:
133        src_lon, src_lat = srctype["reader"](srcfile)
134else:
135        src = nc.Dataset(srcfile)
136        # the following two lines do not perform the actual read
137        # the file is read later when assigning to the ctypes array
138        # -> no unnecessary array copying in memory
139        src_lon = src.variables[srctype["lon_name"]]
140        src_lat = src.variables[srctype["lat_name"]]
141
142if "reader" in dsttype:
143        dst_lon, dst_lat = dsttype["reader"](dstfile)
144else:
145        dst = nc.Dataset(dstfile)
146        dst_lon = dst.variables[dsttype["lon_name"]]
147        dst_lat = dst.variables[dsttype["lat_name"]]
148
149src_ncell, src_nvert = src_lon.shape
150dst_ncell, dst_nvert = dst_lon.shape
151
152def compute_distribution(ncell):
153        "Returns the local number and starting position in global array."
154        if rank < ncell % size:
155                return ncell//size + 1, \
156                       (ncell//size + 1)*rank
157        else:
158                return ncell//size, \
159                       (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
160
161src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
162dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
163
164print "src", src_ncell_loc, src_loc_start
165print "dst", dst_ncell_loc, dst_loc_start
166
167c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
168c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
169c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
170c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
171
172c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
173c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
174c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
175c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
176
177
178print "Calling remap library to compute weights."
179srcpole = (ct.c_double * (3))()
180dstpole = (ct.c_double * (3))()
181srcpole[:] = srctype["pole"]
182dstpole[:] = dsttype["pole"]
183
184c_src_ncell = ct.c_int(src_ncell_loc)
185c_src_nvert = ct.c_int(src_nvert)
186c_dst_ncell = ct.c_int(dst_ncell_loc)
187c_dst_nvert = ct.c_int(dst_nvert)
188order = ct.c_int(interp_types[interp])
189
190c_nweight = ct.c_int()
191
192print "src:", src_ncell, src_nvert
193print "dst:", dst_ncell, dst_nvert
194
195remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
196               c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
197               order, ct.byref(c_nweight))
198
199nwgt = c_nweight.value
200
201c_weights = (ct.c_double * nwgt)()
202c_dst_idx = (ct.c_int * nwgt)()
203c_src_idx = (ct.c_int * nwgt)()
204
205remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
206
207wgt_glo     = MPI.COMM_WORLD.gather(c_weights[:])
208src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
209dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
210
211
212if rank == 0 and mode == 'weights':
213        nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
214
215        print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
216        f = nc.Dataset(outfile,'w')
217        f.createDimension('n_src',    src_ncell)
218        f.createDimension('n_dst',    dst_ncell)
219        f.createDimension('n_weight', nwgt_glo)
220
221        var = f.createVariable('src_idx', 'i', ('n_weight'))
222        var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
223        var = f.createVariable('dst_idx', 'i', ('n_weight'))
224        var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
225        var = f.createVariable('weight',  'd', ('n_weight'))
226        var[:] = np.hstack(wgt_glo)
227        f.close()
228
229def test_fun(x, y, z):
230        return (1-x**2)*(1-y**2)*z
231
232def test_fun_ll(lat, lon):
233        #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
234        return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
235
236#UNUSED
237#def sphe2cart(lat, lon):
238#       phi   = math.pi/180*lon[:]
239#       theta = math.pi/2 - math.pi/180*lat[:]
240#       return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
241
242if mode == 'remap':
243        c_centre_lon = (ct.c_double * src_ncell_loc)()
244        c_centre_lat = (ct.c_double * src_ncell_loc)()
245        c_areas      = (ct.c_double * src_ncell_loc)()
246        remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
247                c_centre_lon, c_centre_lat, c_areas)
248#       src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
249#       src_val_loc = src.variables["ps"]
250#       src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
251#        src_val_glo = src_val_loc
252
253        c_centre_lon = (ct.c_double * dst_ncell_loc)()
254        c_centre_lat = (ct.c_double * dst_ncell_loc)()
255        c_areas      = (ct.c_double * dst_ncell_loc)()
256        remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
257                c_centre_lon, c_centre_lat, c_areas)
258#       dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
259
260#       dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
261        dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
262        dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
263        dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
264
265
266if rank == 0 and mode == 'remap':
267        from scipy import sparse
268        A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
269
270#       src_val = np.hstack(src_val_glo)
271#       dst_ref = np.hstack(dst_val_glo)
272        dst_areas = np.hstack(dst_areas_glo)
273        dst_centre_lon = np.hstack(dst_centre_lon_glo)
274        dst_centre_lat = np.hstack(dst_centre_lat_glo)
275
276#       print "source:", src_val.shape
277#       print "destin:", dst_ref.shape
278#       dst_val = A*src_val
279#       err = dst_val - dst_ref
280#       print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
281#       print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
282#       print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
283
284        lev=src.dimensions['lev']
285        nq=src.dimensions['nq']
286        f = nc.Dataset(outfile,'w')
287        f.createDimension('nvert', dst_nvert)
288        f.createDimension('cell', dst_ncell)
289        f.createDimension('lev', len(lev))
290        f.createDimension('nq', len(nq))
291
292        var = f.createVariable('lat', 'd', ('cell'))
293        var.setncattr("long_name", "latitude")
294        var.setncattr("units", "degrees_north")
295        var.setncattr("bounds", "bounds_lat")
296        var[:] = dst_centre_lat
297        var = f.createVariable('lon', 'd', ('cell'))
298        var.setncattr("long_name", "longitude")
299        var.setncattr("units", "degrees_east")
300        var.setncattr("bounds", "bounds_lon")
301        var[:] = dst_centre_lon
302
303        var = f.createVariable('bounds_lon', 'd', ('cell','nvert'))
304        var[:] = dst_lon
305        var = f.createVariable('bounds_lat', 'd', ('cell','nvert'))
306        var[:] = dst_lat
307
308        var = f.createVariable('lev', 'd', ('lev'))
309        var[:] = src.variables['lev']
310        var.setncattr('axis', 'Z')
311        var.setncattr('units', 'Pa')
312        var.setncattr('positive', 'down')
313        var[:] = src.variables['lev']
314
315        ps = f.createVariable('ps', 'd', ('cell'))
316        ps.setncattr("coordinates", "lon lat")
317
318        phis = f.createVariable('phis', 'd', ('cell'))
319        phis.setncattr("coordinates", "lon lat")
320
321        theta_rhodz = f.createVariable('theta_rhodz', 'd', ('lev','cell'))
322        theta_rhodz.setncattr("coordinates", "lev lon lat")
323
324        ulon = f.createVariable('ulon', 'd', ('lev','cell'))
325        ulon.setncattr("coordinates", "lev lon lat")
326
327        ulat = f.createVariable('ulat', 'd', ('lev','cell'))
328        ulat.setncattr("coordinates", "lev lon lat")
329
330        q = f.createVariable('q', 'd', ('nq','lev','cell'))
331        q.setncattr("coordinates", "nq lev lon lat")
332       
333       
334
335#for ps
336if mode == 'remap':
337        src_val_loc = src.variables['ps']
338        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
339
340if rank == 0 and mode == 'remap':
341#        print(src_val_glo)
342        src_val = np.hstack(src_val_glo)
343#        print src_val
344        print A.shape
345        print src_val.shape
346        dst_val = A*src_val
347        ps[:] = dst_val
348
349#for phis
350if mode == 'remap':
351        src_val_loc = src.variables['phis']
352        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
353
354if rank == 0 and mode == 'remap':
355        src_val = np.hstack(src_val_glo)
356        dst_val = A*src_val
357        phis[:] = dst_val
358
359
360#for theta_rhodz
361if mode == 'remap':
362        src_val_loc = src.variables['theta_rhodz']
363
364for l in range(0, len(lev)):
365        if mode == 'remap':
366                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
367
368        if rank == 0 and mode == 'remap':
369                src_val = np.hstack(src_val_glo)
370                dst_val = A*src_val
371                theta_rhodz[l,:] = dst_val
372
373#for ulon
374if mode == 'remap':
375        src_val_loc = src.variables['ulon']
376
377for l in range(0, len(lev)):
378        if mode == 'remap':
379                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
380
381        if rank == 0 and mode == 'remap':
382                src_val = np.hstack(src_val_glo)
383                dst_val = A*src_val
384                ulon[l,:] = dst_val
385
386#for ulat
387if mode == 'remap':
388        src_val_loc = src.variables['ulat']
389
390for l in range(0, len(lev)):
391        if mode == 'remap':
392                src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
393
394        if rank == 0 and mode == 'remap':
395                src_val = np.hstack(src_val_glo)
396                dst_val = A*src_val
397                ulat[l,:] = dst_val
398
399#for q
400if mode == 'remap':
401        src_val_loc = src.variables['q']
402
403for n in range(0, len(nq)):
404        for l in range(0, len(lev)):
405                if mode == 'remap':
406                        src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[n,l,:]))
407
408                if rank == 0 and mode == 'remap':
409                        src_val = np.hstack(src_val_glo)
410                        dst_val = A*src_val
411                        q[n,l,:] = dst_val
412
413if mode == 'remap':
414        f.close()
415       
416
417if not "reader" in srctype:
418        src.close()
419if not "reader" in dsttype:
420        dst.close()
421
Note: See TracBrowser for help on using the repository browser.