source: XIOS/dev/dev_olga/src/extern/remap/py/remap_orig.py @ 1022

Last change on this file since 1022 was 1022, checked in by mhnguyen, 7 years ago
File size: 9.8 KB
Line 
1import netCDF4 as nc
2import ctypes as ct
3import numpy as np
4import os
5import sys
6import math
7from mpi4py import MPI
8
9import time
10
11remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
12
13def from_mpas(filename):
14        # construct vortex bounds from Mpas grid structure
15        f = nc.Dataset(filename)
16        # in this case it is must faster to first read the whole file into memory
17        # before converting the data structure
18        print "read"
19        stime = time.time()
20        lon_vert = np.array(f.variables["lonVertex"])
21        lat_vert = np.array(f.variables["latVertex"])
22        vert_cell = np.array(f.variables["verticesOnCell"])
23        nvert_cell = np.array(f.variables["nEdgesOnCell"])
24        ncell, nvert = vert_cell.shape
25        assert(max(nvert_cell) <= nvert)
26        lon = np.zeros(vert_cell.shape)
27        lat = np.zeros(vert_cell.shape)
28        etime = time.time()
29        print "finished read, now convert", etime-stime
30        scal = 180.0/math.pi
31        for c in range(ncell):
32                lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
33                lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
34                # signal "last vertex" by netCDF convetion
35                lon[c,nvert_cell[c]] = lon[c,0]
36                lat[c,nvert_cell[c]] = lat[c,0]
37        print "convert end", time.time() - etime
38        return lon, lat
39
40grid_types = {
41        "dynamico:mesh": {
42                "lon_name": "bounds_lon_i",
43                "lat_name": "bounds_lat_i",
44                "pole": [0,0,0]
45        },
46        "dynamico:vort": {
47                "lon_name": "bounds_lon_v",
48                "lat_name": "bounds_lat_v",
49                "pole": [0,0,0]
50        },
51        "dynamico:restart": {
52                "lon_name": "lon_i_vertices",
53                "lat_name": "lat_i_vertices",
54                "pole": [0,0,0]
55        },
56        "test:polygon": {
57                "lon_name": "bounds_lon",
58                "lat_name": "bounds_lat",
59                "pole": [0,0,0]
60        },
61        "test:latlon": {
62                "lon_name": "bounds_lon",
63                "lat_name": "bounds_lat",
64                "pole": [0,0,1]
65        },
66        "mpas": {
67                "reader": from_mpas,
68                "pole": [0,0,0]
69        }
70}
71
72interp_types = {
73        "FV1": 1,
74        "FV2": 2
75}
76
77usage = """
78Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
79
80   interp: type of interpolation
81       choices:
82           FV1: first order conservative Finite Volume
83           FV2: second order conservative Finite Volume
84
85   srctype, dsttype: grid type of source and destination
86       choices: """ + " ".join(grid_types.keys()) + """
87
88   srcfile, dstfile: grid file names, should mostly be netCDF file
89
90   mode: modus of operation
91       choices:
92           weights: computes weight and stores them in outfile
93           remap:   computes the interpolated values on destination grid and stores them in outfile
94
95   outfile: output filename
96
97"""
98
99# parse command line arguments
100if not len(sys.argv) == 8:
101        print usage
102        sys.exit(2)
103
104interp = sys.argv[1]
105try:
106        srctype = grid_types[sys.argv[2]]
107except KeyError:
108        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
109        exit(2)
110srcfile = sys.argv[3]
111try:
112        dsttype = grid_types[sys.argv[4]]
113except KeyError:
114        print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
115        exit(2)
116dstfile = sys.argv[5]
117mode    = sys.argv[6]
118outfile = sys.argv[7]
119
120if not mode in ("weights", "remap"):
121        print "Error: mode must be of of the following: weights remap."
122        exit(2)
123
124remap.mpi_init()
125rank = remap.mpi_rank()
126size = remap.mpi_size()
127
128print rank, "/", size
129
130print "Reading grids from netCDF files."
131
132if "reader" in srctype:
133        src_lon, src_lat = srctype["reader"](srcfile)
134else:
135        src = nc.Dataset(srcfile)
136        # the following two lines do not perform the actual read
137        # the file is read later when assigning to the ctypes array
138        # -> no unnecessary array copying in memory
139        src_lon = src.variables[srctype["lon_name"]]
140        src_lat = src.variables[srctype["lat_name"]]
141
142if "reader" in dsttype:
143        dst_lon, dst_lat = dsttype["reader"](dstfile)
144else:
145        dst = nc.Dataset(dstfile)
146        dst_lon = dst.variables[dsttype["lon_name"]]
147        dst_lat = dst.variables[dsttype["lat_name"]]
148
149src_ncell, src_nvert = src_lon.shape
150dst_ncell, dst_nvert = dst_lon.shape
151
152def compute_distribution(ncell):
153        "Returns the local number and starting position in global array."
154        if rank < ncell % size:
155                return ncell//size + 1, \
156                       (ncell//size + 1)*rank
157        else:
158                return ncell//size, \
159                       (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
160
161src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
162dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
163
164print "src", src_ncell_loc, src_loc_start
165print "dst", dst_ncell_loc, dst_loc_start
166
167c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
168c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
169c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
170c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
171
172c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
173c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
174c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
175c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
176
177
178print "Calling remap library to compute weights."
179srcpole = (ct.c_double * (3))()
180dstpole = (ct.c_double * (3))()
181srcpole[:] = srctype["pole"]
182dstpole[:] = dsttype["pole"]
183
184c_src_ncell = ct.c_int(src_ncell_loc)
185c_src_nvert = ct.c_int(src_nvert)
186c_dst_ncell = ct.c_int(dst_ncell_loc)
187c_dst_nvert = ct.c_int(dst_nvert)
188order = ct.c_int(interp_types[interp])
189
190c_nweight = ct.c_int()
191
192print "src:", src_ncell, src_nvert
193print "dst:", dst_ncell, dst_nvert
194
195remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
196               c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
197               order, ct.byref(c_nweight))
198
199nwgt = c_nweight.value
200
201c_weights = (ct.c_double * nwgt)()
202c_dst_idx = (ct.c_int * nwgt)()
203c_src_idx = (ct.c_int * nwgt)()
204
205remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
206
207wgt_glo     = MPI.COMM_WORLD.gather(c_weights[:])
208src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
209dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
210
211
212if rank == 0 and mode == 'weights':
213        nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
214
215        print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
216        f = nc.Dataset(outfile,'w')
217        f.createDimension('n_src',    src_ncell)
218        f.createDimension('n_dst',    dst_ncell)
219        f.createDimension('n_weight', nwgt_glo)
220
221        var = f.createVariable('src_idx', 'i', ('n_weight'))
222        var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
223        var = f.createVariable('dst_idx', 'i', ('n_weight'))
224        var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
225        var = f.createVariable('weight',  'd', ('n_weight'))
226        var[:] = np.hstack(wgt_glo)
227        f.close()
228
229def test_fun(x, y, z):
230        return (1-x**2)*(1-y**2)*z
231
232def test_fun_ll(lat, lon):
233        #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
234        return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
235
236#UNUSED
237#def sphe2cart(lat, lon):
238#       phi   = math.pi/180*lon[:]
239#       theta = math.pi/2 - math.pi/180*lat[:]
240#       return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
241
242if mode == 'remap':
243        c_centre_lon = (ct.c_double * src_ncell_loc)()
244        c_centre_lat = (ct.c_double * src_ncell_loc)()
245        c_areas      = (ct.c_double * src_ncell_loc)()
246        remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
247                c_centre_lon, c_centre_lat, c_areas)
248        src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
249        src_val_glo = MPI.COMM_WORLD.gather(src_val_loc)
250
251        c_centre_lon = (ct.c_double * dst_ncell_loc)()
252        c_centre_lat = (ct.c_double * dst_ncell_loc)()
253        c_areas      = (ct.c_double * dst_ncell_loc)()
254        remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
255                c_centre_lon, c_centre_lat, c_areas)
256        dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
257
258        dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
259        dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
260        dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
261        dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
262
263
264if rank == 0 and mode == 'remap':
265        from scipy import sparse
266        A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
267
268        src_val = np.hstack(src_val_glo)
269        dst_ref = np.hstack(dst_val_glo)
270        dst_areas = np.hstack(dst_areas_glo)
271        dst_centre_lon = np.hstack(dst_centre_lon_glo)
272        dst_centre_lat = np.hstack(dst_centre_lat_glo)
273
274        print "source:", src_val.shape
275        print "destin:", dst_ref.shape
276        dst_val = A*src_val
277        err = dst_val - dst_ref
278        print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
279        print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
280        print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
281
282        f = nc.Dataset(outfile,'w')
283        f.createDimension('n_vert', dst_nvert)
284        f.createDimension('n_cell', dst_ncell)
285
286        var = f.createVariable('lat', 'd', ('n_cell'))
287        var.setncattr("long_name", "latitude")
288        var.setncattr("units", "degrees_north")
289        var.setncattr("bounds", "bounds_lat")
290        var[:] = dst_centre_lat
291        var = f.createVariable('lon', 'd', ('n_cell'))
292        var.setncattr("long_name", "longitude")
293        var.setncattr("units", "degrees_east")
294        var.setncattr("bounds", "bounds_lon")
295        var[:] = dst_centre_lon
296
297        var = f.createVariable('bounds_lon', 'd', ('n_cell','n_vert'))
298        var[:] = dst_lon
299        var = f.createVariable('bounds_lat', 'd', ('n_cell','n_vert'))
300        var[:] = dst_lat
301        var = f.createVariable('val', 'd', ('n_cell'))
302        var.setncattr("coordinates", "lon lat")
303        var[:] = dst_val
304        var = f.createVariable('val_ref', 'd', ('n_cell'))
305        var.setncattr("coordinates", "lon lat")
306        var[:] = dst_ref
307        var = f.createVariable('err', 'd', ('n_cell'))
308        var.setncattr("coordinates", "lon lat")
309        var[:] = err
310        var = f.createVariable('area', 'd', ('n_cell'))
311        var.setncattr("coordinates", "lon lat")
312        var[:] = dst_areas[:] # dest
313        f.close()
314
315if not "reader" in srctype:
316        src.close()
317if not "reader" in dsttype:
318        dst.close()
319
Note: See TracBrowser for help on using the repository browser.