New URL for NEMO forge!   http://forge.nemo-ocean.eu

Since March 2022 along with NEMO 4.2 release, the code development moved to a self-hosted GitLab.
This present forge is now archived and remained online for history.
cuda_fortran.F90 in NEMO/branches/2021/dev_r13747_HPC-11_mcastril_HPDAonline_DiagGPU/tests/DIA_GPU/MY_SRC – NEMO

source: NEMO/branches/2021/dev_r13747_HPC-11_mcastril_HPDAonline_DiagGPU/tests/DIA_GPU/MY_SRC/cuda_fortran.F90 @ 15063

Last change on this file since 15063 was 14091, checked in by mcastril, 4 years ago

Add DIA_GPU test case

File size: 5.6 KB
Line 
1#ifdef key_gpu
2MODULE cuda_fortran
3    USE cudafor
4    USE dom_oce
5    USE lib_mpp
6    CONTAINS
7      ATTRIBUTES(global) &
8      SUBROUTINE dia_hsb_kernel(surf, e3t, surf_ini, e3t_ini, ts, hc_loc_ini, sc_loc_ini,     &
9                            & tmask, tmask_ini, zwrkv, zwrkh, zwrks, zwrk, jpi, jpj, jpk, jpt, Kmm)
10            !
11            INTEGER, VALUE :: jpi, jpj, jpk, jpt                                                                         ! MPI process sub-domain and stream tile
12            INTEGER, VALUE :: Kmm                                                                                   !ocean time level indice
13            REAL(8)        :: surf(jpi, jpj), e3t(jpi, jpj, jpk, jpt), surf_ini(jpi, jpj), e3t_ini(jpi, jpj, jpk), ts(jpi, jpj, jpk, 2, jpt)
14            REAL(8)        :: hc_loc_ini(jpi, jpj, jpk), sc_loc_ini(jpi, jpj, jpk), tmask(jpi, jpj, jpk), tmask_ini(jpi, jpj, jpk)
15            REAL(8)        :: zwrkv(jpi, jpj, jpk), zwrkh(jpi, jpj, jpk), zwrks(jpi, jpj, jpk), zwrk(jpi, jpj, jpk)
16            !
17
18            INTEGER        :: i, j, k                       ! dummy indexes
19
20            !
21            i = blockDim%x * (blockIdx%x -1) + threadIdx%x
22            j = blockDim%y * (blockIdx%y -1) + threadIdx%y
23            k = blockDim%z * (blockIdx%z -1) + threadIdx%z
24            !
25            IF ( (i .le. jpi ) .AND. (j .le. jpj) .AND. (k .le. jpk-1) ) THEN
26                zwrkv(i, j, k) =   surf    (i,j) * e3t    (i,j,k,Kmm) * tmask    (i,j,k)      &
27                &                - surf_ini(i,j) * e3t_ini(i,j,k)     * tmask_ini(i,j,k)
28               
29                zwrkh(i, j, k) = ( surf(i, j) * e3t(i, j, k, Kmm) * ts(i, j, k, 1, Kmm) - surf_ini(i, j) * hc_loc_ini(i, j, k) )
30               
31                zwrks(i, j, k) = ( surf(i, j) * e3t(i, j, k, Kmm) * ts(i, j, k, 2, Kmm) - surf_ini(i, j) * sc_loc_ini(i, j, k) )
32               
33                zwrk(i, j, k)  =   surf(i, j) * e3t(i, j, k, Kmm) * tmask(i, j, k)
34            END IF
35      END SUBROUTINE dia_hsb_kernel
36
37!     ATTRIBUTES(global) &
38!     SUBROUTINE dia_hsb_kernel1d(surf, e3t_n, surf_ini, e3t_ini, &
39!        &  tsn, hc_loc_ini, sc_loc_ini, tmask, zwrkv, zwrkh, zwrks, zwrk, jpi, jpj, jpk)
40!        IMPLICIT NONE
41!        REAL(kind=8) :: surf(:, :), e3t_n(:, :, :), surf_ini(:, :), e3t_ini(:, :, :), tsn(:, :, :, :), &
42!                            & hc_loc_ini(:, :, :), sc_loc_ini(:, :, :), tmask(:, :, :), zwrkv(:, :, :),      &
43!                            & zwrkh(:, :, :), zwrks(:, :, :), zwrk(:, :, :)
44!        REAL(kind=8), SHARED       :: sdata(*)
45!        INTEGER, VALUE             :: jpi, jpj, jpk
46!        INTEGER                    :: i, si, ti, globsize
47!
48!        globsize = jpi*jpj*jpk
49!
50!        i  = blockDim%x * (blockIdx%x-1) + threadIdx%x
51!        ti = threadIdx%x
52!        si = MOD(i, jpk)+1 !Shared memory indexing
53!
54!        !Compute volume (zwrk), volume variation (zwrkv), heat deviation (zwrkh)
55!        !and salinity deviation (zwrks)
56!        IF ( i .le. globsize ) THEN
57!            sdata(ti) = surf(si) * e3t_n(i)
58!            zwrkv(i) = ( sdata(ti) - surf_ini(si) * e3t_ini(i) ) * tmask(i) * surf(si)
59!            zwrkh(i) = ( sdata(ti) * tsn(i) - surf_ini(si) * hc_loc_ini(i) ) * tmask(i) * surf(si)
60!            zwrks(i) = ( sdata(ti) * tsn(globsize + i) - surf_ini(si) * sc_loc_ini(i) ) &
61!                    * tmask(i) * surf(si)
62!            zwrk(i) = sdata(ti) * tmask(i) * surf(si)
63!        END IF
64!     END SUBROUTINE dia_hsb_kernel1d
65
66
67       ATTRIBUTES(global) &
68       SUBROUTINE filter_cuda(ptab, mask, jpi, jpj, jpk)
69           !
70           REAL(kind=8)         :: ptab(jpi,jpj,jpk) , mask(jpi,jpj)
71           !
72           INTEGER, VALUE  :: jpi, jpj, jpk                 ! MPI process sub-domain
73           INTEGER         :: i, j, k                 ! dummy indexes
74           !
75           i = blockDim%x * (blockIdx%x -1) + threadIdx%x
76           j = blockDim%y * (blockIdx%y -1) + threadIdx%y
77           k = blockDim%z * (blockIdx%z -1) + threadIdx%z
78           !
79           tile = 1
80           IF ( (i .le. jpi ) .AND. (j .le. jpj) .AND. (k .le. jpk-1) ) THEN
81               ptab(i, j, k)  = ptab(i, j, k) * mask(i, j)
82           END IF
83       END SUBROUTINE filter_cuda
84
85       ATTRIBUTES(global) SUBROUTINE array3dto1d(d_inp,d_out,ipi, ipj, ipk)
86           !
87           REAL(kind=8)                 :: d_inp(:,:,:)
88           REAL(kind=8), intent(out)    :: d_out(:)
89           !
90           INTEGER, VALUE  :: ipi, ipj, ipk                 ! MPI process sub-domain
91           INTEGER         :: i, j, k                       ! dummy indexes
92           !
93           i = blockDim%x * (blockIdx%x -1) + threadIdx%x
94           j = blockDim%y * (blockIdx%y -1) + threadIdx%y
95           k = blockDim%z * (blockIdx%z -1) + threadIdx%z
96           !
97           IF ( (i .le. ipi ) .AND. (j .le. ipj) .AND. (k .le. ipk) ) THEN
98               d_out(i + ipi*(j-1) + ipi*ipj*(k-1) )  = d_inp(i, j, k)
99           END IF
100       END SUBROUTINE array3dto1d
101
102       !Knuth's trick
103       ATTRIBUTES(device) &
104       SUBROUTINE DDPDD_d(ydda, yddb)
105           IMPLICIT NONE
106           COMPLEX(kind=8), INTENT(in   ) :: ydda              !Scalar to add
107           COMPLEX(kind=8), INTENT(inout) :: yddb              !Total sum
108           REAL(kind=8)                   :: zerr, zt1, zt2
109       
110           zt1  = REAL(ydda) + REAL(yddb)
111           zerr = zt1 - REAL(ydda)
112           zt2  = ( (REAL(yddb) - zerr) + (REAL(ydda) - (zt1 - zerr)) ) &
113                 & + AIMAG(ydda) + AIMAG(yddb)
114           !The result is t1 + t2 after normalization
115           yddb = CMPLX( zt1 + zt2, zt2 - ((zt1 + zt2) - zt1), 8 )
116       END SUBROUTINE DDPDD_d       
117END MODULE
118#endif
Note: See TracBrowser for help on using the repository browser.