TABLE OF CONTENTS


ABINIT/m_octree [ Modules ]

[ Top ] [ Modules ]

NAME

 m_octree

FUNCTION

  An octree data structure to find nearest-neighbour k-points

COPYRIGHT

  Copyright (C) 2010-2024 ABINIT group (HM)
  This file is distributed under the terms of the
  GNU General Public License, see ~abinit/COPYING
  or http://www.gnu.org/copyleft/gpl.txt .

SOURCE

 16 #if defined HAVE_CONFIG_H
 17 #include "config.h"
 18 #endif
 19 
 20 #include "abi_common.h"
 21 
 module m_octree

  use defs_basis
  use m_abicore
  implicit none

  ! Corners of the 8 octants of a box, as fractions of the parent box size:
  !   shifts(:,1,ioctant) -> lower corner, shifts(:,2,ioctant) -> upper corner.
  ! Octants are numbered ioctant = ii*4 + jj*2 + kk + 1, where ii (jj, kk) is 1 when the
  ! x (y, z) coordinate lies in the upper half of the parent box (see get_octant).
  ! Being a named constant it is valid before, and independently of, octree_init.
  real(dp),parameter :: shifts(3,2,8) = reshape([ &
    zero,zero,zero,  half,half,half, &  ! octant 1: (ii,jj,kk) = (0,0,0)
    zero,zero,half,  half,half,one,  &  ! octant 2: (0,0,1)
    zero,half,zero,  half,one, half, &  ! octant 3: (0,1,0)
    zero,half,half,  half,one, one,  &  ! octant 4: (0,1,1)
    half,zero,zero,  one, half,half, &  ! octant 5: (1,0,0)
    half,zero,half,  one, half,one,  &  ! octant 6: (1,0,1)
    half,half,zero,  one, one, half, &  ! octant 7: (1,1,0)
    half,half,half,  one, one, one], &  ! octant 8: (1,1,1)
    [3,2,8])

  ! A node of the tree. Once built, a node is either a leaf (ids allocated) or an
  ! internal node (childs associated with its 8 octants), never both.
  type :: octree_node_t
    type(octree_node_t),pointer :: childs(:) => null() ! internal node: the 8 octants
    integer,allocatable :: ids(:)                      ! leaf: column indices into octree%points
  end type octree_node_t

  type :: octree_t
    real(dp) :: hi(3)                         ! upper corner of the root box
    real(dp) :: lo(3)                         ! lower corner of the root box
    integer :: max_npoints                    ! leaf-size threshold (see octree_node_build)
    real(dp),pointer :: points(:,:) => null() ! borrowed coordinates, points(:,i) is point i
    type(octree_node_t) :: first              ! root node
  end type octree_t

!--------------------------------------------------------------------
! NOTE: every "distance" in this module is a SQUARED Euclidean distance
! (see dist_points and box_dist).

  contains

  type(octree_t) function octree_init(points,max_npoints,lo,hi) result (new)
    ! Build an octree over all columns of points inside the root box [lo, hi].
    !   points      : coordinates, points(:,i) is point i (first 3 rows are used)
    !   max_npoints : a node holding fewer than max_npoints points becomes a leaf
    !   lo, hi      : lower and upper corners of the root box
    ! The tree stores a pointer to points (no copy). Per the Fortran rules on TARGET
    ! dummies, the actual argument must have the TARGET (or POINTER) attribute and must
    ! outlive the tree, otherwise new%points becomes undefined after this call.
    integer,intent(in) :: max_npoints
    real(dp), target, intent(in) :: points(:,:)
    real(dp), intent(in) :: lo(3), hi(3)
    integer,allocatable :: ids(:)
    integer :: ii, npoints

    new%lo = lo
    new%hi = hi

    ! The root node starts with every point.
    npoints = size(points,2)
    new%points => points
    new%max_npoints = max_npoints
    ids = [(ii,ii=1,npoints)]
    new%first = octree_node_build(new,new%lo,new%hi,npoints,ids)

  end function octree_init

  type(octree_node_t) recursive function octree_node_build(octree,lo,hi,nids,ids) result (new)
    ! Recursively build the node covering the box [lo, hi] that holds points ids(1:nids).
    ! The node becomes a leaf when either condition holds:
    !  * it holds fewer than max_npoints points (a single point is always a leaf, so
    !    max_npoints <= 1 no longer recurses forever);
    !  * its box has shrunk below min_rel_size times the root box in every direction
    !    (about 40 levels with tol12). Coincident points can never be separated by
    !    subdivision, and without this guard they recursed until stack overflow.
    type(octree_t),intent(in) :: octree
    real(dp),intent(in) :: lo(3), hi(3)
    integer,intent(in) :: nids
    integer,intent(in) :: ids(nids)
    real(dp),parameter :: min_rel_size = tol12
    integer :: id, counter, ioctant
    integer,allocatable :: octants(:), new_ids(:)
    real(dp) :: new_lo(3), new_hi(3)

    if (nids < max(octree%max_npoints,2) .or. &
        all(abs(hi-lo) <= min_rel_size*abs(octree%hi-octree%lo))) then
      ABI_MALLOC(new%ids,(nids))
      new%ids = ids(:nids)
      return
    end if

    ! Work arrays live on the heap: the root call has nids = npoints, which could
    ! overflow the stack as automatic arrays.
    ABI_MALLOC(octants,(nids))
    ABI_MALLOC(new_ids,(nids))

    ! Octant of each point with respect to the midpoint of this box.
    call get_octants(lo,hi,nids,ids,octree%points,octants)

    ABI_MALLOC(new%childs,(8))
    do ioctant=1,8
      ! Gather, in their original order, the ids that fall in this octant.
      counter = 0
      do id=1,nids
        if (octants(id) /= ioctant) cycle
        counter = counter + 1
        new_ids(counter) = ids(id)
      end do
      call get_lo_hi(lo,hi,new_lo,new_hi,ioctant)
      new%childs(ioctant) = octree_node_build(octree,new_lo,new_hi,counter,new_ids)
    end do

    ABI_FREE(new_ids)
    ABI_FREE(octants)

  end function octree_node_build

  integer function octree_find(octree,point,dist) result(closest_id)
    ! Approximate nearest-point lookup. It descends to the leaf whose box contains point
    ! and returns the closest point stored in THAT leaf only. Closer points in
    ! neighbouring leaves are not considered; octree_find_nearest does an exact search.
    ! Returns:
    !   -1  if point lies outside the root box,
    !    0  if the containing leaf is empty,
    !   otherwise the column index (in octree%points) of the closest point in the leaf.
    ! dist: squared distance to the returned point, huge(dist) when none was found.
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(out) :: dist
    type(octree_node_t),pointer :: octn
    integer :: id, ipoint, ioctant
    real(dp) :: hi(3),lo(3),hi_out(3),lo_out(3)
    real(dp) :: trial_dist

    closest_id = 0
    dist = huge(dist)
    octn => octree%first
    lo = octree%lo
    hi = octree%hi

    ! A positive distance to the root box means the point is outside it.
    if (box_dist(lo,hi,point) > zero) then
      closest_id = -1
      return
    end if

    ! Walk down to the leaf whose box contains point.
    do while (.not.allocated(octn%ids))
      ioctant = get_octant_lohi(lo,hi,point)
      octn => octn%childs(ioctant)
      call get_lo_hi(lo,hi,lo_out,hi_out,ioctant)
      lo = lo_out; hi = hi_out
    end do

    ! Linear scan of the leaf; on ties the later id wins.
    do id=1,size(octn%ids)
      ipoint = octn%ids(id)
      trial_dist = dist_points(octree%points(:,ipoint),point)
      if (trial_dist > dist) cycle
      dist = trial_dist
      closest_id = ipoint
    end do
  end function octree_find

  integer function octree_find_nearest(octree,point,dist) result(nearest_id)
    ! Exact nearest-point search within a squared-distance cutoff.
    !   dist (in)  : squared-distance cutoff; only points with squared distance <= dist qualify
    !   dist (out) : squared distance to the returned point (unchanged if none qualifies)
    ! Returns -1 if point lies outside the root box, 0 if no point qualifies, otherwise
    ! the column index (in octree%points) of the nearest point.
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(inout) :: dist
    real(dp) :: check_dist

    check_dist = box_dist(octree%lo,octree%hi,point)
    if (check_dist>0) then
      nearest_id = -1
      return
    end if
    nearest_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,point,dist)
  end function octree_find_nearest

  integer recursive function octn_find_nearest(octree,octn,lo,hi,point,min_dist) result(closest_id)
    ! Recursive exact search in the subtree rooted at octn, whose box is [lo, hi].
    ! Subtrees whose box is farther than min_dist from point are pruned.
    !   min_dist (in/out): best squared distance so far. It is lowered whenever a point
    !                      at least as close is found; on ties the later point wins.
    ! Returns the index of the best point found in this subtree, or 0 if no point in it
    ! was within min_dist.
    type(octree_t),intent(in) :: octree
    type(octree_node_t),intent(in) :: octn
    real(dp),intent(in) :: point(3)
    real(dp),intent(in) :: lo(3), hi(3)
    real(dp),intent(inout) :: min_dist
    real(dp) :: new_lo(3), new_hi(3)
    real(dp) :: dist
    integer :: id, ioctant, ipoint, trial_id

    closest_id = 0
    ! Nothing in this box can beat the current best.
    if (box_dist(lo,hi,point) > min_dist) return

    ! Leaf: compare point by point.
    if (allocated(octn%ids)) then
      do id=1,size(octn%ids)
        ipoint = octn%ids(id)
        dist = dist_points(octree%points(:,ipoint),point)
        if (dist>min_dist) cycle
        min_dist = dist
        closest_id = ipoint
      end do
      return
    end if

    ! Internal node: search every octant. min_dist shrinks as we go, tightening pruning.
    do ioctant=1,8
      call get_lo_hi(lo,hi,new_lo,new_hi,ioctant)
      trial_id = octn_find_nearest(octree,octn%childs(ioctant),new_lo,new_hi,point,min_dist)
      if (trial_id==0) cycle
      closest_id = trial_id
    end do
  end function octn_find_nearest

  integer function octree_find_nearest_pbc(octree,point,dist,shift) result(id)
    ! Exact nearest-point search with periodicity one along each axis.
    ! The point is folded with modulo(point,one) and searched. Then its 26 images,
    ! shifted by -1/0/+1 along each axis, are searched; the closest hit wins, and
    ! later images win on ties.
    ! NOTE(review): the folding assumes the root box covers the unit cube [0,1]^3;
    ! confirm against callers.
    !   dist (in/out): squared-distance cutoff / squared distance to the returned point
    !   shift (out)  : point+shift is the searched image that produced the result
    !                  (modulo(point,one)-point when nothing was found)
    ! Returns 0 when no image has a point within dist.
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(inout) :: dist
    real(dp),intent(out) :: shift(3)
    logical :: found
    integer :: trial_id, first_id
    integer :: ii,jj,kk
    real(dp) :: trial_dist
    real(dp) :: first_shift(3), trial_shift(3)
    real(dp) :: po(3)

    id = 0
    ! Bring the point inside [0,1)^3 and remember the shift that does it.
    po = modulo(point,one)
    first_shift = po-point
    trial_dist = dist
    first_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,po,trial_dist)
    ! A positive id already implies trial_dist <= dist. Exact ties are accepted here,
    ! as they are for the shifted images below.
    if (first_id > 0) then
      id   = first_id
      dist = trial_dist
    end if

    ! Try the 26 neighbouring images.
    found = .false.
    do ii=-1,1
      do jj=-1,1
        do kk=-1,1
          if (ii==0.and.jj==0.and.kk==0) cycle
          trial_shift = first_shift+[ii,jj,kk]
          ! The tol12 margin only widens the pruning radius. Acceptance below still
          ! requires trial_dist <= dist.
          trial_dist = dist+tol12
          trial_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,&
                                       point+trial_shift,trial_dist)
          ! A failed search leaves trial_dist = dist+tol12, which rounds back to dist
          ! for large cutoffs. The trial_id check keeps such a miss from overwriting
          ! id with 0.
          if (trial_id <= 0 .or. trial_dist > dist) cycle
          found = .true.
          dist  = trial_dist
          id    = trial_id
          shift = trial_shift
        end do
      end do
    end do
    if (.not.found) shift = first_shift
  end function octree_find_nearest_pbc

  integer recursive function octn_free(octn) result(ierr)
    ! Release the subtree rooted at octn: leaf ids and child arrays, recursively.
    ! Safe on a node that was never built, because childs defaults to null().
    ! Returns 0; a nonzero code from a child would be propagated.
    type(octree_node_t),intent(inout) :: octn
    integer :: ioctant, jerr

    ierr = 0
    if (allocated(octn%ids)) then
      ABI_FREE(octn%ids)
    end if
    if (associated(octn%childs)) then
      do ioctant=1,size(octn%childs)
        jerr = octn_free(octn%childs(ioctant))
        if (jerr /= 0) ierr = jerr
      end do
      ! The child array itself was leaked before; deallocate also nullifies it.
      ABI_FREE(octn%childs)
    end if
  end function octn_free

  integer function octree_free(octree) result(ierr)
    ! Free every node of the tree. intent(inout) because the nodes are deallocated.
    ! The coordinates behind octree%points are borrowed and are not touched.
    type(octree_t),target,intent(inout) :: octree
    ierr = octn_free(octree%first)
  end function octree_free

  pure real(dp) function dist_points(p1,p2) result(dist)
    ! Squared Euclidean distance between two 3D points.
    real(dp),intent(in) :: p1(3),p2(3)
    dist = pow2(p1(1)-p2(1))+&
           pow2(p1(2)-p2(2))+&
           pow2(p1(3)-p2(3))
  end function dist_points

  pure logical function box_contains(lo,hi,po) result(inside)
    ! True when po lies strictly inside the box [lo, hi]; points on a face count as outside.
    real(dp),intent(in) :: lo(3), hi(3), po(3)
    inside = (po(1)>lo(1).and.po(1)<hi(1).and.&
              po(2)>lo(2).and.po(2)<hi(2).and.&
              po(3)>lo(3).and.po(3)<hi(3))
  end function box_contains

  pure real(dp) function box_dist(lo,hi,po) result(dist)
    ! Squared distance from po to the closest point of the box [lo, hi].
    ! Zero when po is inside the box or on its surface.
    real(dp),intent(in) :: lo(3), hi(3), po(3)
    dist = zero
    if (po(1)<lo(1)) dist = dist + pow2(po(1)-lo(1))
    if (po(1)>hi(1)) dist = dist + pow2(po(1)-hi(1))
    if (po(2)<lo(2)) dist = dist + pow2(po(2)-lo(2))
    if (po(2)>hi(2)) dist = dist + pow2(po(2)-hi(2))
    if (po(3)<lo(3)) dist = dist + pow2(po(3)-lo(3))
    if (po(3)>hi(3)) dist = dist + pow2(po(3)-hi(3))
  end function box_dist

  pure real(dp) function pow2(x) result(x2)
    ! Square of x.
    real(dp),intent(in) :: x
    x2 = x*x
  end function pow2

  pure integer function get_octant(mi,po) result(ioctant)
    ! Octant (1..8) of po relative to the box midpoint mi, numbered as in shifts.
    ! Coordinates equal to the midpoint go to the upper half.
    real(dp),intent(in) :: po(3), mi(3)
    integer :: ii,jj,kk
    ii = 0; if (po(1)>=mi(1)) ii = 1
    jj = 0; if (po(2)>=mi(2)) jj = 1
    kk = 0; if (po(3)>=mi(3)) kk = 1
    ioctant = ii*4+jj*2+kk+1
  end function get_octant

  pure integer function get_octant_lohi(lo,hi,po) result(ioctant)
    ! Octant of po within the box [lo, hi].
    real(dp),intent(in) :: lo(3),hi(3),po(3)
    real(dp) :: mi(3)
    mi = lo+half*(hi-lo)
    ioctant = get_octant(mi,po)
  end function get_octant_lohi

  pure subroutine get_octants(lo,hi,nids,ids,points,octants)
    ! For each point points(:,ids(i)), return in octants(i) its octant within the box [lo, hi].
    real(dp),intent(in) :: lo(3), hi(3)
    real(dp),intent(in) :: points(:,:)
    integer,intent(in) :: nids
    integer,intent(in) :: ids(nids)
    integer,intent(out) :: octants(nids)
    real(dp) :: mi(3)
    integer :: id, ipoint
    ! The midpoint is shared by all points.
    mi = lo+half*(hi-lo)
    do id=1,nids
      ipoint = ids(id)
      octants(id) = get_octant(mi,points(:,ipoint))
    end do
  end subroutine get_octants

  pure subroutine get_lo_hi(lo_in,hi_in,lo_out,hi_out,ioctant)
    ! Corners [lo_out, hi_out] of octant ioctant of the box [lo_in, hi_in].
    integer,intent(in) :: ioctant
    real(dp),intent(in) :: lo_in(3), hi_in(3)
    real(dp),intent(out) :: lo_out(3), hi_out(3)
    real(dp) :: de(3)
    de = hi_in-lo_in
    lo_out = lo_in + shifts(:,1,ioctant)*de
    hi_out = lo_in + shifts(:,2,ioctant)*de
  end subroutine get_lo_hi

 end module m_octree