TABLE OF CONTENTS


ABINIT/m_xgScalapack [ Modules ]

[ Top ] [ Modules ]

NAME

  m_xgScalapack

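FUNCTION

  Wrapper used to solve the eigenproblems of xg blocks with ScaLAPACK: it sets up
  the communicators and the BLACS grid, runs the solver on a subset of the MPI
  processes and forwards the results to the processes that did not take part.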

COPYRIGHT

  Copyright (C) 2017-2022 ABINIT group (J. Bieder)
  This file is distributed under the terms of the
  GNU General Public License, see ~abinit/COPYING
  or http://www.gnu.org/copyleft/gpl.txt .

SOURCE

15 #if defined HAVE_CONFIG_H
16 #include "config.h"
17 #endif
18 
19 #include "abi_common.h"
20 
21 module m_xgScalapack
22 
23   use defs_basis, only : std_err, std_out, dp
24   use m_abicore
25   use m_xmpi
26   use m_errors
27   use m_slk
28   use m_xg
29   use m_xomp
30   use m_time,     only: timab
31 
32 #ifdef HAVE_MPI2
33  use mpi
34 #endif
35 
36   implicit none
37 
38 #ifdef HAVE_MPI1
39  include 'mpif.h'
40 #endif
41 
42 
43   private
44 
45   integer, parameter :: M__SLK = 1
46   integer, parameter :: M__ROW = 1
47   integer, parameter :: M__COL = 2
48   integer, parameter :: M__UNUSED = 4
49   integer, parameter :: M__WORLD = 5
50   integer, parameter :: M__NDATA = 5
51   integer, parameter :: M__tim_init    = 1690
52   integer, parameter :: M__tim_free    = 1691
53   integer, parameter :: M__tim_heev    = 1692
54   integer, parameter :: M__tim_hegv    = 1693
55   integer, parameter :: M__tim_scatter = 1694
56 
57   integer, parameter, public :: SLK_AUTO = -1
58   integer, parameter, public :: SLK_FORCED = 1
59   integer, parameter, public :: SLK_DISABLED = 0
60   integer, save :: M__CONFIG = SLK_AUTO
61   integer, save :: M__MAXDIM = 1000
62 
63   type, public :: xgScalapack_t
64     integer :: comms(M__NDATA)
65     integer :: rank(M__NDATA)
66     integer :: size(M__NDATA)
67     integer :: coords(2)
68     integer :: ngroup
69     integer :: verbosity
70     type(grid_scalapack) :: grid
71   end type xgScalapack_t
72 
73   public :: xgScalapack_init
74   public :: xgScalapack_free
75   public :: xgScalapack_heev
76   public :: xgScalapack_hegv
77   public :: xgScalapack_config
78   contains

m_xgScalapack/xgScalapack_init [ Functions ]

[ Top ] [ m_xgScalapack ] [ Functions ]

NAME

  xgScalapack_init

FUNCTION

  Initialize the ScaLAPACK communicator used by the subsequent operations.
  If the communicator contains more processes than needed, only a subgroup of it is used.

INPUTS

OUTPUT

SOURCE

 94   subroutine  xgScalapack_init(xgScalapack,comm,maxDim,verbosity,usable)
 95 
 96     type(xgScalapack_t), intent(inout) :: xgScalapack
 97     integer            , intent(in   ) :: comm
 98     integer            , intent(in   ) :: maxDim
 99     integer            , intent(in   ) :: verbosity
100     logical            , intent(  out) :: usable
101     double precision :: tsec(2)
102 #ifdef HAVE_LINALG_MKL_THREADS
103     integer :: mkl_get_max_threads
104 #endif
105 #ifdef HAVE_LINALG_OPENBLAS_THREADS
106     integer :: openblas_get_num_threads
107 #endif
108     integer :: nthread
109 #ifdef HAVE_LINALG_SCALAPACK
110     integer :: maxProc
111     integer :: nproc
112     integer :: ngroup
113     integer :: subgroup
114     integer :: mycomm(2)
115     integer :: ierr
116     integer :: test_row
117     integer :: test_col
118 #else
119     ABI_UNUSED(comm)
120     ABI_UNUSED(maxDim)
121 #endif
122 
123     call timab(M__tim_init,1,tsec)
124 
125     xgScalapack%comms = xmpi_comm_null
126     xgScalapack%rank = xmpi_undefined_rank
127     xgScalapack%verbosity = verbosity
128 
129     nthread = 1
130 #ifdef HAVE_LINALG_MKL_THREADS
131     nthread =  mkl_get_max_threads()
132 #elif HAVE_LINALG_OPENBLAS_THREADS
133     nthread =  openblas_get_num_threads()
134 #else
135     nthread = xomp_get_num_threads(open_parallel=.true.)
136     if ( nthread == 0 ) nthread = 1
137 #endif
138 
139 #ifdef HAVE_LINALG_SCALAPACK
140 
141     nproc = xmpi_comm_size(comm)
142     xgScalapack%comms(M__WORLD) = comm
143     xgScalapack%rank(M__WORLD) = xmpi_comm_rank(comm)
144     xgScalapack%size(M__WORLD) = nproc
145 
146     maxProc = (maxDim / (M__MAXDIM*nthread))+1 ! ( M__MAXDIM x M__MAXDIM matrice per MPI )
147     if ( M__CONFIG > 0 .and. M__CONFIG <= nproc ) then
148       maxProc = M__CONFIG
149     else if ( maxProc > nproc ) then
150       maxProc = nproc
151     end if
152 
153     if ( maxProc == 1 .or. M__CONFIG == SLK_DISABLED) then
154       usable = .false.
155       return
156     else if ( nthread > 1 ) then ! disable scalapack with threads since it is not threadsafe
157       ! This should be check with new elpa version en MPI+OpenMP
158       if ( M__CONFIG > 0 ) then
159         ABI_WARNING("xgScalapack turned off because you have threads")
160       end if
161       usable = .false.
162       return
163     else
164       usable = .true.
165       maxProc = 2*((maxProc+1)/2) ! Round to next even number
166     end if
167 
168     if ( xgScalapack%verbosity > 0 ) then
169       write(std_out,*) " xgScalapack will use", maxProc, "/", nproc, "MPIs"
170     end if
171 
172     ngroup = nproc/maxProc
173     xgScalapack%ngroup = ngroup
174 
175     if ( maxProc < nproc ) then
176       if ( xgScalapack%rank(M__WORLD) < maxProc*ngroup ) then
177         subgroup = xgScalapack%rank(M__WORLD)/maxProc
178         mycomm(1) = M__SLK
179         mycomm(2) = M__UNUSED
180       else
181         subgroup = ngroup+1
182         mycomm(1) = M__UNUSED
183         mycomm(2) = M__SLK
184       end if
185        call MPI_Comm_split(comm, subgroup, xgScalapack%rank(M__WORLD), xgScalapack%comms(mycomm(1)),ierr)
186        if ( ierr /= 0 ) then
187          ABI_ERROR("Error splitting communicator")
188        end if
189        xgScalapack%comms(mycomm(2)) = xmpi_comm_null
190        xgScalapack%rank(mycomm(1)) = xmpi_comm_rank(xgScalapack%comms(mycomm(1)))
191        xgScalapack%rank(mycomm(2)) = xmpi_undefined_rank
192        xgScalapack%size(mycomm(1)) = xmpi_comm_size(xgScalapack%comms(mycomm(1)))
193        xgScalapack%size(mycomm(2)) = nproc - xgScalapack%size(mycomm(1))
194     else
195        call MPI_Comm_dup(comm,xgScalapack%comms(M__SLK),ierr)
196        if ( ierr /= 0 ) then
197          ABI_ERROR("Error duplicating communicator")
198        end if
199        xgScalapack%rank(M__SLK) = xmpi_comm_rank(xgScalapack%comms(M__SLK))
200        xgScalapack%size(M__SLK) = nproc
201     end if
202 
203     if ( xgScalapack%comms(M__SLK) /= xmpi_comm_null ) then
204       call xgScalapack%grid%init(xgScalapack%size(M__SLK), xgScalapack%comms(M__SLK))
205       call BLACS_GridInfo(xgScalapack%grid%ictxt, &
206         xgScalapack%grid%dims(M__ROW), xgScalapack%grid%dims(M__COL),&
207         xgScalapack%coords(M__ROW), xgScalapack%coords(M__COL))
208 
209      !These values are the same as those computed by BLACS_GRIDINFO
210      !except in the case where the myproc argument is not the local proc
211       test_row = INT((xgScalapack%rank(M__SLK)) / xgScalapack%grid%dims(2))
212       test_col = MOD((xgScalapack%rank(M__SLK)), xgScalapack%grid%dims(2))
213       if ( test_row /= xgScalapack%coords(M__ROW) ) then
214         ABI_WARNING("Row id mismatch")
215       end if
216       if ( test_col /= xgScalapack%coords(M__COL) ) then
217         ABI_WARNING("Col id mismatch")
218       end if
219     end if
220 
221 #else
222     usable = .false.
223 #endif
224 
225     call timab(M__tim_init,2,tsec)
226 
227   end subroutine xgScalapack_init
228 
229   subroutine xgScalapack_config(myconfig,maxDim)
230 
231     integer, intent(in) :: myconfig
232     integer, intent(in) :: maxDim
233     if ( myconfig == SLK_AUTO) then
234       M__CONFIG = myconfig
235       ABI_COMMENT("xgScalapack in auto mode")
236     else if ( myconfig == SLK_DISABLED) then
237       M__CONFIG = myconfig
238       ABI_COMMENT("xgScalapack disabled")
239     else if ( myconfig > 0) then
240       M__CONFIG = myconfig
241       ABI_COMMENT("xgScalapack enabled")
242     else
243       ABI_WARNING("Bad value for xgScalapack config -> autodetection")
244       M__CONFIG = SLK_AUTO
245     end if
246     if ( maxDim > 0 ) then
247       M__MAXDIM = maxDim
248     end if
249 
250   end subroutine xgScalapack_config
251 
252   function toProcessorScalapack(xgScalapack) result(processor)
253 
254     type(xgScalapack_t), intent(in) :: xgScalapack
255     type(processor_scalapack) :: processor
256 
257     processor%myproc = xgScalapack%rank(M__SLK)
258     processor%comm = xgScalapack%comms(M__SLK)
259     processor%coords = xgScalapack%coords
260     processor%grid = xgScalapack%grid
261   end function toProcessorScalapack
262 
263   !This is for testing purpose.
264   !May not be optimal since I do not control old implementation but at least gives a reference.
265   subroutine xgScalapack_heev(xgScalapack,matrixA,eigenvalues)
266     use, intrinsic :: iso_c_binding
267     type(xgScalapack_t), intent(inout) :: xgScalapack
268     type(xgBlock_t)    , intent(inout) :: matrixA
269     type(xgBlock_t)    , intent(inout) :: eigenvalues
270 #ifdef HAVE_LINALG_SCALAPACK
271     double precision, pointer :: matrix(:,:) !(cplex*nbli_global,nbco_global)
272     double precision, pointer :: eigenvalues_tmp(:,:)
273     double precision, pointer :: vector(:)
274     double precision :: tsec(2)
275     integer :: cplex
276     integer :: istwf_k
277     integer :: nbli_global, nbco_global
278     type(c_ptr) :: cptr
279     integer :: req(2), status(MPI_STATUS_SIZE,2), ierr
280 #endif
281 
282 #ifdef HAVE_LINALG_SCALAPACK
283     call timab(M__tim_heev,1,tsec)
284 
285     ! Keep only working processors
286     if ( xgScalapack%comms(M__SLK) /= xmpi_comm_null ) then
287 
288       call xgBlock_getSize(eigenvalues,nbli_global,nbco_global)
289       if ( cols(matrixA) /= nbli_global ) then
290         ABI_ERROR("Number of eigen values differ from number of vectors")
291       end if
292 
293       if ( space(matrixA) == SPACE_C ) then
294         cplex = 2
295         istwf_k = 1
296       else
297         cplex = 1
298         istwf_k = 2
299       endif
300 
301       call xgBlock_getSize(matrixA,nbli_global,nbco_global)
302 
303       call xgBlock_reverseMap(matrixA,matrix,nbli_global,nbco_global)
304       call xgBlock_reverseMap(eigenvalues,eigenvalues_tmp,nbco_global,1)
305       cptr = c_loc(eigenvalues_tmp)
306       call c_f_pointer(cptr,vector,(/ nbco_global /))
307 
308       call compute_eigen1(xgScalapack%comms(M__SLK), &
309         toProcessorScalapack(xgScalapack), &
310         cplex,nbli_global,nbco_global,matrix,vector,istwf_k)
311 
312     end if
313 
314     call timab(M__tim_heev,2,tsec)
315 
316     req(1:2)=-1
317     call xgScalapack_scatter(xgScalapack,matrixA,req(1))
318     call xgScalapack_scatter(xgScalapack,eigenvalues,req(2))
319 #ifdef HAVE_MPI
320     if ( any(req/=-1)  ) then
321       call MPI_WaitAll(2,req,status,ierr)
322       if ( ierr /= 0 ) then
323           ABI_ERROR("Error waiting data")
324       endif
325     end if
326 #endif
327 #else
328    ABI_ERROR("ScaLAPACK support not available")
329    ABI_UNUSED(xgScalapack%verbosity)
330    ABI_UNUSED(matrixA%normal)
331    ABI_UNUSED(eigenvalues%normal)
332 #endif
333 
334   end subroutine xgScalapack_heev
335 
336   !This is for testing purpose.
337   !May not be optimal since I do not control old implementation but at least gives a reference.
338   subroutine xgScalapack_hegv(xgScalapack,matrixA,matrixB,eigenvalues)
339     use, intrinsic :: iso_c_binding
340     type(xgScalapack_t), intent(inout) :: xgScalapack
341     type(xgBlock_t)    , intent(inout) :: matrixA
342     type(xgBlock_t)    , intent(inout) :: matrixB
343     type(xgBlock_t)    , intent(inout) :: eigenvalues
344 #ifdef HAVE_LINALG_SCALAPACK
345     double precision, pointer :: matrix1(:,:) !(cplex*nbli_global,nbco_global)
346     double precision, pointer :: matrix2(:,:) !(cplex*nbli_global,nbco_global)
347     double precision, pointer :: eigenvalues_tmp(:,:)
348     double precision, pointer :: vector(:)
349     double precision :: tsec(2)
350     integer :: cplex
351     integer :: istwf_k
352     integer :: nbli_global, nbco_global
353     type(c_ptr) :: cptr
354     integer :: req(2), status(MPI_STATUS_SIZE,2),ierr
355 #endif
356 
357 #ifdef HAVE_LINALG_SCALAPACK
358     call timab(M__tim_hegv,1,tsec)
359 
360     ! Keep only working processors
361     if ( xgScalapack%comms(M__SLK) /= xmpi_comm_null ) then
362 
363       call xgBlock_getSize(eigenvalues,nbli_global,nbco_global)
364       if ( cols(matrixA) /= cols(matrixB) ) then
365         ABI_ERROR("Matrix A and B don't have the same number of vectors")
366       end if
367 
368       if ( cols(matrixA) /= nbli_global ) then
369         ABI_ERROR("Number of eigen values differ from number of vectors")
370       end if
371 
372       if ( space(matrixA) == SPACE_C ) then
373         cplex = 2
374         istwf_k = 1
375       else
376         cplex = 1
377         istwf_k = 2
378       endif
379 
380       call xgBlock_getSize(matrixA,nbli_global,nbco_global)
381 
382       call xgBlock_reverseMap(matrixA,matrix1,nbli_global,nbco_global)
383       call xgBlock_reverseMap(matrixB,matrix2,nbli_global,nbco_global)
384       call xgBlock_reverseMap(eigenvalues,eigenvalues_tmp,nbco_global,1)
385       cptr = c_loc(eigenvalues_tmp)
386       call c_f_pointer(cptr,vector,(/ nbco_global /))
387 
388       call compute_eigen2(xgScalapack%comms(M__SLK), &
389         toProcessorScalapack(xgScalapack), &
390         cplex,nbli_global,nbco_global,matrix1,matrix2,vector,istwf_k)
391     end if
392 
393     call timab(M__tim_hegv,2,tsec)
394 
395     req(1:2)=-1
396     call xgScalapack_scatter(xgScalapack,matrixA,req(1))
397     call xgScalapack_scatter(xgScalapack,eigenvalues,req(2))
398 #ifdef HAVE_MPI
399     if ( any(req/=-1)   ) then
400       call MPI_WaitAll(2,req,status,ierr)
401       if ( ierr /= 0 ) then
402           ABI_ERROR("Error waiting data")
403       endif
404     end if
405 #endif
406 #else
407    ABI_ERROR("ScaLAPACK support not available")
408    ABI_UNUSED(xgScalapack%verbosity)
409    ABI_UNUSED(matrixA%normal)
410    ABI_UNUSED(matrixB%normal)
411    ABI_UNUSED(eigenvalues%normal)
412 #endif
413 
414   end subroutine xgScalapack_hegv
415 
416 
417   subroutine xgScalapack_scatter(xgScalapack,matrix,req)
418 
419     type(xgScalapack_t), intent(in   ) :: xgScalapack
420     type(xgBlock_t)    , intent(inout) :: matrix
421     integer            , intent(  out) :: req
422     double precision, pointer :: tab(:,:)
423     double precision :: tsec(2)
424     integer :: cols, rows
425     integer :: ierr
426     integer :: sendto, receivefrom
427     integer :: lap
428 
429     call timab(M__tim_scatter,1,tsec)
430 
431     call xgBlock_getSize(matrix,rows,cols)
432     call xgBlock_reverseMap(matrix,tab,rows,cols)
433 
434     ! If we did the he(e|g)v and we are the first group
435     if ( xgScalapack%comms(M__SLK) /= xmpi_comm_null .and. xgScalapack%rank(M__WORLD)<xgScalapack%size(M__SLK) ) then
436       lap = xgScalapack%ngroup
437       sendto = xgScalapack%rank(M__WORLD) + lap*xgScalapack%size(M__SLK)
438       if ( sendto < xgScalapack%size(M__WORLD) ) then
439       !do while ( sendto < xgScalapack%size(M__WORLD) )
440         !call xmpi_send(tab,sendto,sendto,xgScalapack%comms(M__WORLD),ierr)
441         call xmpi_isend(tab,sendto,sendto,xgScalapack%comms(M__WORLD),req,ierr)
442         !write(*,*) xgScalapack%rank(M__WORLD), "sends to", sendto
443         if ( ierr /= 0 ) then
444           ABI_ERROR("Error sending data")
445         end if
446         !lap = lap+1
447         !sendto = xgScalapack%rank(M__WORLD) + lap*xgScalapack%size(M__SLK)
448       !end do
449       end if
450     else if ( xgScalapack%comms(M__UNUSED) /= xmpi_comm_null ) then
451       receivefrom = MODULO(xgScalapack%rank(M__WORLD), xgScalapack%size(M__SLK))
452       if ( receivefrom >= 0 ) then
453         !call xmpi_recv(tab,receivefrom,xgScalapack%rank(M__WORLD),xgScalapack%comms(M__WORLD),ierr)
454         call xmpi_irecv(tab,receivefrom,xgScalapack%rank(M__WORLD),xgScalapack%comms(M__WORLD),req,ierr)
455         !write(*,*) xgScalapack%rank(M__WORLD), "receive from", receivefrom
456         if ( ierr /= 0 ) then
457           ABI_ERROR("Error receiving data")
458         end if
459       end if
460     !else
461       !ABI_BUG("Error scattering data")
462     end if
463 
464     call timab(M__tim_scatter,2,tsec)
465 
466   end subroutine xgScalapack_scatter
467 
468 
469   subroutine  xgScalapack_free(xgScalapack)
470 
471     type(xgScalapack_t), intent(inout) :: xgScalapack
472     double precision :: tsec(2)
473 #ifdef HAVE_LINALG_SCALAPACK
474     integer :: ierr
475 #endif
476 
477     call timab(M__tim_free,1,tsec)
478 #ifdef HAVE_LINALG_SCALAPACK
479     if ( xgScalapack%comms(M__SLK) /= xmpi_comm_null ) then
480       call BLACS_GridExit(xgScalapack%grid%ictxt)
481       call MPI_Comm_free(xgScalapack%comms(M__SLK),ierr)
482     end if
483     if ( xgScalapack%comms(M__UNUSED) /= xmpi_comm_null ) then
484       call MPI_Comm_free(xgScalapack%comms(M__UNUSED),ierr)
485     end if
486 #else
487     ABI_UNUSED(xgScalapack%verbosity)
488 #endif
489     call timab(M__tim_free,2,tsec)
490 
491   end subroutine xgScalapack_free
492 
493 end module m_xgScalapack