source: codes/icosagcm/devel/Python/test/py/partition.py @ 620

Last change on this file since 620 was 620, checked in by dubos, 6 years ago

devel/unstructured : mesh partitioning

File size: 7.5 KB
Line 
1print 'Starting'
2
3from mpi4py import MPI
4comm = MPI.COMM_WORLD
5mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
6print '%d/%d starting'%(mpi_rank,mpi_size)
7
8import sys
9import math as math
10import numpy as np
11import netCDF4 as cdf
12
13import matplotlib.pyplot as plt
14from matplotlib.patches import Polygon
15from matplotlib.collections import PatchCollection
16
17#from dynamico import partition
18from dynamico import parallel
19from dynamico import unstructured as unst
20from dynamico.unstructured import list_stencil, reindex
21
22print 'Done loading modules'
23
24sys.stdout.flush()
25
26#----------------- partition hand-written 15-cell mesh ------------------#
27
28if mpi_size<15:
29    send=np.random.randn(mpi_size)
30    recv=np.zeros(mpi_size)
31    comm.Alltoall(send,recv)
32    #time.sleep(mpi_rank)
33#    print mpi_rank, send, recv
34   
35    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
36              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
37              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
38    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
39   
40    nb_vert = len(xadj)-1
41    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
42    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
43   
44    idx_start = vtxdist[mpi_rank]
45    idx_end = vtxdist[mpi_rank+1]
46    nb_vert = idx_end - idx_start
47   
48    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
49    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
50    part = 0*xadj_loc[0:-1];
51
52    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
53
54#    for i in range(len(part)):
55#        print 'vertex', i+idx_start, 'proc', part[i]
56
57#-----------------------------------------------------------------------------#
58#---------------         partition and plot MPAS mesh       ------------------#
59#-----------------------------------------------------------------------------#
60
61# Helper functions to plot unstructured graph
62
63def patches(degree, bounds, lon, lat):
64    for i in range(degree.size):
65        nb_edge=degree[i]
66        bounds_cell = bounds[i,0:nb_edge]
67        lat_cell    = lat[bounds_cell]
68        lon_cell    = lon[bounds_cell]
69        orig=lon_cell[0]
70        lon_cell    = lon_cell-orig+180.
71        lon_cell    = np.mod(lon_cell,360.)
72        lon_cell    = lon_cell+orig-180.
73#        if np.abs(lon_cell-orig).max()>10. :
74#            print '%d patches :'%mpi_rank, lon_cell
75        lonlat_cell = np.zeros((nb_edge,2))
76        lonlat_cell[:,0],lonlat_cell[:,1] = lon_cell,lat_cell
77        polygon = Polygon(lonlat_cell, True)
78        yield polygon
79
80def plot_mesh(ax, clim, degree, bounds, lon, lat, data):
81    nb_vertex = lon.size # global
82    p = list(patches(degree, bounds, lon, lat))
83    print '%d : plot_mesh %d %d %d'%( mpi_rank, degree.size, len(p), len(data) ) 
84    p = PatchCollection(p, linewidth=0.01)
85    p.set_array(data) # set values at each polygon (cell)
86    p.set_clim(clim)
87    ax.add_collection(p)
88
89def local_mesh(get_mycells):
90    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
91    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
92    vertex_list = sorted(set(list_stencil(mydegree,mybounds))) 
93    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
94    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
95    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
96    vertex_dict = parallel.inverse_list(vertex_list)
97    reindex(vertex_dict, mydegree, mybounds)
98    return vertex_list, mydegree, mybounds, mylon, mylat
99
100#--------------- read MPAS grid file ---------------#
101
102#grid = 'x1.2562'
103grid = 'x1.10242'
104#grid = 'x4.163842'
105print 'Reading MPAS file %s ...'%grid
106sys.stdout.flush()
107
108nc = cdf.Dataset('grids/%s.grid.nc'%grid, "r")
109dim_cell, dim_edge, dim_vertex = [
110    parallel.PDim(nc.dimensions[name], comm) 
111    for name in 'nCells','nEdges','nVertices']
112edge_degree   = parallel.CstPArray1D(dim_edge, np.int32, 2)
113vertex_degree = parallel.CstPArray1D(dim_vertex, np.int32, 3)
114nEdgesOnCell, verticesOnCell, edgesOnCell, cellsOnCell, latCell = [
115    parallel.PArray(dim_cell, nc.variables[var])
116    for var in 'nEdgesOnCell', 'verticesOnCell', 'edgesOnCell', 'cellsOnCell', 'latCell' ]
117cellsOnVertex, edgesOnVertex, kiteAreasOnVertex, lonVertex, latVertex = [
118    parallel.PArray(dim_vertex, nc.variables[var])
119    for var in 'cellsOnVertex', 'edgesOnVertex', 'kiteAreasOnVertex', 'lonVertex', 'latVertex']
120nEdgesOnEdge, cellsOnEdge, edgesOnEdge, verticesOnEdge, weightsOnEdge = [
121    parallel.PArray(dim_edge, nc.variables[var])
122    for var in 'nEdgesOnEdge', 'cellsOnEdge', 'edgesOnEdge', 'verticesOnEdge', 'weightsOnEdge']
123
124# Indices start at 0 on the C/Python side and at 1 on the Fortran/MPAS side
125# hence an offset of 1 is added/substracted where needed.
126for x in (verticesOnCell, edgesOnCell, cellsOnCell, cellsOnVertex, edgesOnVertex,
127          cellsOnEdge, edgesOnEdge, verticesOnEdge) : x.data = x.data-1
128edge2cell, cell2edge, edge2vertex, vertex2edge, cell2cell, edge2edge = [
129    unst.Stencil_glob(a,b) for a,b in 
130    (edge_degree, cellsOnEdge), (nEdgesOnCell, edgesOnCell),
131    (edge_degree, verticesOnEdge), (vertex_degree, edgesOnVertex),
132    (nEdgesOnCell, cellsOnCell), (nEdgesOnEdge, edgesOnEdge) ]
133
134#---------------- partition edges and cells ------------------#
135
136print 'Partitioning ...'
137sys.stdout.flush()
138
139edge_owner = unst.partition_mesh(nEdgesOnEdge, edgesOnEdge, mpi_size)
140edge_owner = parallel.LocPArray1D(dim_edge, edge_owner)
141cell_owner = unst.partition_from_stencil(edge_owner, nEdgesOnCell, edgesOnCell)
142cell_owner = parallel.LocPArray1D(dim_cell, cell_owner)
143
144#--------------------- construct halos  -----------------------#
145
146print 'Constructing halos ...'
147sys.stdout.flush()
148
149def chain(start, links):
150    for link in links:
151        start = link(start).neigh_set
152        yield start
153
154edges_E0 = unst.find_my_cells(edge_owner)
155cells_C0, edges_E1, vertices_V1, edges_E2, cells_C1 = chain(
156    edges_E0, ( edge2cell, cell2edge, edge2vertex, vertex2edge, edge2cell) )
157
158edges_E0, edges_E1, edges_E2 = unst.progressive_list(edges_E0, edges_E1, edges_E2)
159cells_C0, cells_C1 = unst.progressive_list(cells_C0, cells_C1)
160
161print 'E2,E1,E0 ; C1,C0 : ', map(len, (edges_E2, edges_E1, edges_E0, cells_C1, cells_C0))
162sys.stdout.flush()
163
164#com_edges = parallel.Halo_Xchange(24, dim_edge, edges_E2, dim_edge.get(edges_E2, edge_owner))
165
166mycells, halo_cells = cells_C0, cells_C1
167get_mycells, get_halo_cells = dim_cell.getter(mycells), dim_cell.getter(halo_cells)
168com_cells = parallel.Halo_Xchange(42, dim_cell, halo_cells, get_halo_cells(cell_owner))
169
170local_num, total_num = np.zeros(1), np.zeros(1)
171local_num[0]=com_cells.own_len
172comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
173if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
174sys.stdout.flush()
175
176#---------------------------- plot -----------------------------#
177
178if True:
179    print 'Plotting ...'
180    sys.stdout.flush()
181
182    halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
183    buf = parallel.LocalArray1(com_cells)
184
185    fig, ax = plt.subplots()
186    buf.read_own(latCell) # reads only own values
187    buf.data = np.cos(10.*buf.data)
188    buf.update() # updates halo
189    plot_mesh(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
190    plt.xlim(-190.,190.)
191    plt.ylim(-90.,90.)
192    plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
193
194    fig, ax = plt.subplots()
195    buf.read_own(cell_owner)
196    buf.update()
197    plot_mesh(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
198    plt.xlim(-190.,190.)
199    plt.ylim(-90.,90.)
200    plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.