source: codes/icosagcm/devel/Python/test/py/partition.py

Last change on this file was 977, checked in by dubos, 5 years ago

devel/Python : now only dynamico.dev modules require link to DYNAMICO/XIOS shared objects

File size: 4.2 KB
Line 
1print 'Starting'
2
3from dynamico import meshes
4from dynamico import parallel
5from dynamico import maps
6from dynamico import partition
7
8from mpi4py import MPI
9comm = MPI.COMM_WORLD
10mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
11print '%d/%d starting'%(mpi_rank,mpi_size)
12
13import math as math
14import numpy as np
15import netCDF4 as cdf
16
17import matplotlib.pyplot as plt
18from matplotlib.patches import Polygon
19from matplotlib.collections import PatchCollection
20
21print 'Done loading modules'
22
23#----------------- partition hand-written 15-cell mesh ------------------#
24
25if mpi_size<15:
26    send=np.random.randn(mpi_size)
27    recv=np.zeros(mpi_size)
28    comm.Alltoall(send,recv)
29    #time.sleep(mpi_rank)
30#    print mpi_rank, send, recv
31   
32    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
33              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
34              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
35    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
36   
37    nb_vert = len(xadj)-1
38    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
39    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
40   
41    idx_start = vtxdist[mpi_rank]
42    idx_end = vtxdist[mpi_rank+1]
43    nb_vert = idx_end - idx_start
44   
45    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
46    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
47    part = 0*xadj_loc[0:-1];
48
49    partition.partition_graph(comm, vtxdist, xadj_loc, adjncy_loc, part, nparts=4)
50
51    for i in range(len(part)):
52        print 'vertex', i+idx_start, 'proc', part[i]
53
54#-----------------------------------------------------------------------------#
55#---------------         partition and plot MPAS mesh       ------------------#
56#-----------------------------------------------------------------------------#
57
58# Helper functions to plot unstructured graph
59
60def local_mesh(get_mycells):
61    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
62    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
63    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
64    vertex_list = sorted(set(partition.list_stencil(mydegree,mybounds))) 
65    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
66    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
67    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
68    vertex_dict = parallel.inverse_list(vertex_list)
69    meshes.reindex(vertex_dict, mydegree, mybounds)
70    return vertex_list, mydegree, mybounds, mylon, mylat
71
72def members(struct, *names): return [struct.__dict__ [name] for name in names]
73
74#--------------- read MPAS grid file ---------------#
75
76grid = 'x1.2562'
77#grid = 'x1.10242'
78#grid = 'x4.163842'
79print 'Reading MPAS file %s ...'%grid
80
81meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
82pmesh = meshes.Unstructured_PMesh(comm, meshfile)
83pmesh.partition_metis()
84
85def coriolis(lon,lat): return 0.*lat
86llm, nqdyn, radius = 1,1,1.
87planet = maps.SphereMap(radius, 0.)
88lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, planet)
89
90(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
91 lonVertex, latVertex, lonCell, latCell) = members(
92    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
93    'lon_v', 'lat_v', 'lon_i', 'lat_i')
94
95local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
96local_num[0]=com_cells.own_len
97comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
98if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
99
100#---------------------------- plot -----------------------------#
101
102print 'Plotting ...'
103
104halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
105buf = parallel.LocalArray1(com_cells)
106
107fig, ax = plt.subplots()
108buf.read_own(latCell) # reads only own values
109buf.data = np.cos(10.*buf.data)
110buf.update() # updates halo
111lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
112plt.xlim(-190.,190.)
113plt.ylim(-90.,90.)
114plt.savefig('fig_partition/A%03d.png'%mpi_rank, dpi=160)
115
116fig, ax = plt.subplots()
117buf.read_own(cell_owner)
118buf.update()
119lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
120plt.xlim(-190.,190.)
121plt.ylim(-90.,90.)
122plt.savefig('fig_partition/B%03d.png'%mpi_rank, dpi=160)
Note: See TracBrowser for help on using the repository browser.