source: XIOS/dev/dev_olga/src/extern/blitz/include/blitz/array/cgsolve.h @ 1022

Last change on this file since 1022 was 1022, checked in by mhnguyen, 7 years ago
File size: 4.2 KB
Line 
1// -*- C++ -*-
2/***************************************************************************
3 * blitz/array/cgsolve.h  Basic conjugate gradient solver for linear systems
4 *
5 * $Id$
6 *
7 * Copyright (C) 1997-2011 Todd Veldhuizen <tveldhui@acm.org>
8 *
9 * This file is a part of Blitz.
10 *
11 * Blitz is free software: you can redistribute it and/or modify
12 * it under the terms of the GNU Lesser General Public License
13 * as published by the Free Software Foundation, either version 3
14 * of the License, or (at your option) any later version.
15 *
16 * Blitz is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 * GNU Lesser General Public License for more details.
20 *
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with Blitz.  If not, see <http://www.gnu.org/licenses/>.
23 *
24 * Suggestions:          blitz-devel@lists.sourceforge.net
25 * Bugs:                 blitz-support@lists.sourceforge.net   
26 *
27 * For more information, please see the Blitz++ Home Page:
28 *    https://sourceforge.net/projects/blitz/
29 *
30 ****************************************************************************/
31#ifndef BZ_CGSOLVE_H
32#define BZ_CGSOLVE_H
33
34BZ_NAMESPACE(blitz)
35
36template<typename T_numtype>
37void dump(const char* name, Array<T_numtype,3>& A)
38{
39    T_numtype normA = 0;
40
41    for (int i=A.lbound(0); i <= A.ubound(0); ++i)
42    {
43      for (int j=A.lbound(1); j <= A.ubound(1); ++j)
44      {
45        for (int k=A.lbound(2); k <= A.ubound(2); ++k)
46        {
47            T_numtype tmp = A(i,j,k);
48            normA += BZ_MATHFN_SCOPE(fabs)(tmp);
49        }
50      }
51    }
52
53    normA /= A.numElements();
54    cout << "Average magnitude of " << name << " is " << normA << endl;
55}
56
57template<typename T_stencil, typename T_numtype, int N_rank, typename T_BCs>
58int conjugateGradientSolver(T_stencil stencil,
59    Array<T_numtype,N_rank>& x,
60    Array<T_numtype,N_rank>& rhs, double haltrho, 
61    const T_BCs& boundaryConditions)
62{
63    // NEEDS_WORK: only apply CG updates over interior; need to handle
64    // BCs separately.
65
66    // x = unknowns being solved for (initial guess assumed)
67    // r = residual
68    // p = descent direction for x
69    // q = descent direction for r
70
71    RectDomain<N_rank> interior = interiorDomain(stencil, x, rhs);
72
73cout << "Interior: " << interior.lbound() << ", " << interior.ubound()
74     << endl;
75
76    // Calculate initial residual
77    Array<T_numtype,N_rank> r = rhs.copy();
78    r *= -1.0;
79
80    boundaryConditions.applyBCs(x);
81
82    applyStencil(stencil, r, x);
83
84 dump("r after stencil", r);
85 cout << "Slice through r: " << endl << r(23,17,Range::all()) << endl;
86 cout << "Slice through x: " << endl << x(23,17,Range::all()) << endl;
87 cout << "Slice through rhs: " << endl << rhs(23,17,Range::all()) << endl;
88
89    r *= -1.0;
90
91 dump("r", r);
92
93    // Allocate the descent direction arrays
94    Array<T_numtype,N_rank> p, q;
95    allocateArrays(x.shape(), p, q);
96
97    int iteration = 0;
98    int converged = 0;
99    T_numtype rho = 0.;
100    T_numtype oldrho = 0.;
101
102    const int maxIterations = 1000;
103
104    // Get views of interior of arrays (without boundaries)
105    Array<T_numtype,N_rank> rint = r(interior);
106    Array<T_numtype,N_rank> pint = p(interior);
107    Array<T_numtype,N_rank> qint = q(interior);
108    Array<T_numtype,N_rank> xint = x(interior);
109
110    while (iteration < maxIterations)
111    {
112        rho = sum(r * r);
113
114        if ((iteration % 20) == 0)
115            cout << "CG: Iter " << iteration << "\t rho = " << rho << endl;
116
117        // Check halting condition
118        if (rho < haltrho)
119        {
120            converged = 1;
121            break;
122        }
123
124        if (iteration == 0)
125        {
126            p = r;
127        }
128        else {
129            T_numtype beta = rho / oldrho;
130            p = beta * p + r;
131        }
132
133        q = 0.;
134//        boundaryConditions.applyBCs(p);
135        applyStencil(stencil, q, p);
136
137        T_numtype pq = sum(p*q);
138
139        T_numtype alpha = rho / pq;
140
141        x += alpha * p;
142        r -= alpha * q;
143
144        oldrho = rho;
145        ++iteration;
146    }
147
148    if (!converged)
149        cout << "Warning: CG solver did not converge" << endl;
150
151    return iteration;
152}
153
154BZ_NAMESPACE_END
155
156#endif // BZ_CGSOLVE_H
Note: See TracBrowser for help on using the repository browser.