TABLE OF CONTENTS
ABINIT/m_octree [ Modules ]
NAME
m_octree
FUNCTION
A structure to find nearest-neighbour k-points
COPYRIGHT
Copyright (C) 2010-2024 ABINIT group (HM) This file is distributed under the terms of the GNU General Public License, see ~abinit/COPYING or http://www.gnu.org/copyleft/gpl.txt .
SOURCE
#if defined HAVE_CONFIG_H
#include "config.h"
#endif

#include "abi_common.h"

!> Octree over a set of 3D points supporting nearest-point queries,
!> optionally with periodic boundary conditions.
!>
!> NOTE: every distance computed, accepted or returned by this module is a
!> SQUARED Euclidean distance (see dist_points and box_dist).
module m_octree

  use defs_basis
  use m_abicore
  implicit none

  ! Fractional corners of the 8 octants of a box, in units of the parent box
  ! edge lengths: shifts(:,1,ioctant) is the lower corner and
  ! shifts(:,2,ioctant) the upper corner of octant ioctant.
  ! Filled by octree_init; read by get_lo_hi.
  real(dp),protected :: shifts(3,2,8)

  ! Node of the octree. octree_node_build allocates exactly one of:
  !   childs -> internal node with 8 children (one per octant)
  !   ids    -> leaf holding indices into octree_t%points
  type :: octree_node_t
    type(octree_node_t),pointer :: childs(:) => null() ! allocated for internal nodes
    integer,allocatable :: ids(:)                      ! allocated for leaves
  end type octree_node_t

  type :: octree_t
    real(dp) :: hi(3)                          ! upper corner of the root box
    real(dp) :: lo(3)                          ! lower corner of the root box
    integer :: max_npoints                     ! nodes with fewer points than this are leaves
    real(dp),pointer :: points(:,:) => null()  ! points(1:3,ipoint); not owned by the tree
    type(octree_node_t) :: first               ! root node
  end type octree_t

!--------------------------------------------------------------------

contains

  !> Build an octree over the columns of points.
  !>
  !> points(:,ipoint) holds the coordinates of point ipoint. The tree stores a
  !> POINTER to this array. The actual argument must therefore have the TARGET
  !> (or POINTER) attribute and must outlive the returned octree.
  !> max_npoints: a node holding fewer than max_npoints points is not split further.
  !> lo, hi: corners of the root box. Presumably all points lie inside it;
  !> this is not checked here (TODO confirm with callers).
  type(octree_t) function octree_init(points,max_npoints,lo,hi) result (new)
    integer,intent(in) :: max_npoints
    real(dp), target, intent(in) :: points(:,:)
    real(dp), intent(in) :: lo(3), hi(3)
    integer,allocatable :: ids(:)
    integer :: ii, jj, kk, npoints, ioctant

    ! Octant ioctant = 4*ii + 2*jj + kk + 1 spans [ii,jj,kk]/2 .. [ii,jj,kk]/2 + 1/2
    ! of its parent box. This numbering must agree with get_octant.
    do ii=0,1
      do jj=0,1
        do kk=0,1
          ioctant = ii*4+jj*2+kk+1
          shifts(:,1,ioctant) = [half*ii,half*jj,half*kk]
          shifts(:,2,ioctant) = shifts(:,1,ioctant) + [half,half,half]
        end do
      end do
    end do

    ! Box dimensions
    new%lo = lo
    new%hi = hi

    ! The root node contains all the points
    npoints = size(points,2)
    new%points => points
    new%max_npoints = max_npoints
    ids = [(ii,ii=1,npoints)]
    new%first = octree_node_build(new,new%lo,new%hi,npoints,ids)

  end function octree_init

  !> Recursively build the subtree for box [lo,hi] holding points ids(1:nids).
  !>
  !> A node becomes a leaf (storing its ids) in any of these cases:
  !>   - it holds fewer than octree%max_npoints points;
  !>   - it holds at most one point, so splitting cannot help;
  !>   - the recursion depth has reached max_depth.
  !> Because of the depth limit, a leaf may hold more than max_npoints points,
  !> for example when points coincide.
  !> Otherwise the points are distributed among the 8 octants of the box and
  !> each octant is built recursively.
  !> depth: optional, defaults to 0; it is passed down by the recursion itself.
  type(octree_node_t) recursive function octree_node_build(octree,lo,hi,nids,ids,depth) result (new)
    type(octree_t),intent(in) :: octree
    real(dp),intent(in) :: lo(3), hi(3)
    integer,intent(in) :: nids
    integer,intent(in) :: ids(nids)
    integer,intent(in),optional :: depth
    ! Coincident points can never be separated by bisection. Past this depth
    ! the node is kept as a leaf, so the recursion always terminates.
    integer,parameter :: max_depth = 64
    integer :: id, counter, ioctant, my_depth
    integer,allocatable :: octants(:), new_ids(:)
    real(dp) :: new_lo(3), new_hi(3)

    my_depth = 0
    if (present(depth)) my_depth = depth

    ! Check if this is a leaf node
    if (nids<octree%max_npoints .or. nids<=1 .or. my_depth>=max_depth) then
      ABI_MALLOC(new%ids,(nids))
      new%ids = ids(:nids)
      return
    end if

    ! Heap work arrays: at the root, nids is the total number of points,
    ! which could overflow the stack if these were automatic arrays.
    ABI_MALLOC(octants,(nids))
    ABI_MALLOC(new_ids,(nids))

    ! Determine the octant of each point
    call get_octants(lo,hi,nids,ids,octree%points,octants)

    ABI_MALLOC(new%childs,(8))
    do ioctant=1,8
      ! Collect the points of this octant, keeping their original order
      counter = 0
      do id=1,nids
        if (octants(id) /= ioctant) cycle
        counter = counter + 1
        new_ids(counter) = ids(id)
      end do
      ! Build this octant
      call get_lo_hi(lo,hi,new_lo,new_hi,ioctant)
      new%childs(ioctant) = octree_node_build(octree,new_lo,new_hi,counter,new_ids,my_depth+1)
    end do

    ABI_FREE(octants)
    ABI_FREE(new_ids)

  end function octree_node_build

  !> Approximate search: descend to the leaf whose box contains point and
  !> return the closest point stored in that leaf only. Points in neighbouring
  !> leaves are not examined; use octree_find_nearest for an exact search.
  !>
  !> Returns:
  !>   -1 (dist = huge) if point lies outside the root box;
  !>    0 (dist = huge) if the leaf reached is empty;
  !>   otherwise the index of the closest point in that leaf, with dist set
  !>   to its squared distance to point.
  integer function octree_find(octree,point,dist) result(closest_id)
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(out) :: dist
    type(octree_node_t),pointer :: octn
    integer :: id, ipoint, ioctant
    real(dp) :: hi(3),lo(3),hi_out(3),lo_out(3)
    real(dp) :: trial_dist

    closest_id = 0
    dist = huge(dist)
    octn => octree%first
    lo = octree%lo
    hi = octree%hi

    ! Check if the point is inside the initial box
    trial_dist = box_dist(lo,hi,point)
    if (trial_dist>0) then
      closest_id = -1
      return
    end if

    do
      ! Leaf node: scan its points
      if (allocated(octn%ids)) then
        do id=1,size(octn%ids)
          ipoint = octn%ids(id)
          trial_dist = dist_points(octree%points(:,ipoint),point)
          if (trial_dist > dist) cycle
          dist = trial_dist
          closest_id = ipoint
        end do
        return
      end if
      ! Internal node: move to the child octant that contains the point
      ioctant = get_octant_lohi(lo,hi,point)
      octn => octn%childs(ioctant)
      call get_lo_hi(lo,hi,lo_out,hi_out,ioctant)
      lo = lo_out; hi = hi_out
    end do
  end function octree_find

  !> Exact search for the point closest to point.
  !>
  !> dist (inout): on entry, the largest squared distance accepted; on return,
  !> the squared distance to the point found. It is left unchanged if nothing
  !> is found or point is outside the root box.
  !> Returns the index of the nearest point, 0 if no point lies within dist,
  !> or -1 if point lies outside the root box.
  integer function octree_find_nearest(octree,point,dist) result(nearest_id)
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(inout) :: dist
    real(dp) :: check_dist
    ! Check if the point is inside the initial box
    check_dist = box_dist(octree%lo,octree%hi,point)
    if (check_dist>0) then
      nearest_id = -1
      return
    end if
    nearest_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,point,dist)
  end function octree_find_nearest

  !> Recursive worker: exact nearest-point search in the subtree rooted at
  !> octn, whose box is [lo,hi].
  !>
  !> min_dist (inout) is the current best squared distance. Only points with
  !> squared distance <= min_dist are accepted; on return it holds the best
  !> value found. Subtrees whose box is farther than min_dist are skipped.
  !> Returns the index of a point achieving min_dist, or 0 if this subtree has
  !> no point within min_dist. Among equidistant points, the last one visited wins.
  integer recursive function octn_find_nearest(octree,octn,lo,hi,point,min_dist) result(closest_id)
    type(octree_t),intent(in) :: octree
    type(octree_node_t),intent(in) :: octn
    real(dp),intent(in) :: point(3)
    real(dp),intent(in) :: lo(3), hi(3)
    real(dp),intent(inout) :: min_dist
    real(dp) :: new_lo(3), new_hi(3)
    real(dp) :: dist
    integer :: id, ioctant, ipoint, trial_id
    closest_id = 0
    ! Squared distance of the point to this box (octant)
    dist = box_dist(lo,hi,point)
    ! Prune: nothing in this box can beat the best point found so far
    if (dist>min_dist) return
    ! Leaf: compare point by point
    if (allocated(octn%ids)) then
      do id=1,size(octn%ids)
        ipoint = octn%ids(id)
        dist = dist_points(octree%points(:,ipoint),point)
        if (dist>min_dist) cycle
        min_dist = dist
        closest_id = ipoint
      end do
      return
    end if
    ! Internal node: search every child. min_dist only shrinks, so the last
    ! non-zero id returned by a child is the overall best.
    do ioctant=1,8
      call get_lo_hi(lo,hi,new_lo,new_hi,ioctant)
      trial_id = octn_find_nearest(octree,octn%childs(ioctant),new_lo,new_hi,point,min_dist)
      if (trial_id==0) cycle
      closest_id = trial_id
    end do
  end function octn_find_nearest

  !> Same as octree_find_nearest but with periodic boundary conditions.
  !>
  !> The point is folded into [0,1)^3 with modulo, which presumably assumes
  !> the root box is the unit cell in reduced coordinates (TODO confirm).
  !> The folded point and its 26 images displaced by -1/0/+1 in each
  !> direction are then searched.
  !> dist (inout): largest squared distance accepted on entry; squared
  !> distance of the point found on return.
  !> shift (out): the displacement such that point+shift is the image closest
  !> to the point found. It equals the folding shift if no displaced image won.
  !> Returns the index of the nearest point, or 0 if none lies within dist.
  integer function octree_find_nearest_pbc(octree,point,dist,shift) result(id)
    type(octree_t),target,intent(in) :: octree
    real(dp),intent(in) :: point(3)
    real(dp),intent(inout) :: dist
    real(dp),intent(out) :: shift(3)
    logical :: found
    integer :: trial_id, first_id
    integer :: ii,jj,kk
    real(dp) :: trial_dist
    real(dp) :: first_shift(3), trial_shift(3)
    real(dp) :: po(3)
    id = 0
    ! Bring the point inside the box
    po = modulo(point,one)
    ! Shift that folds the point into the box
    first_shift = po-point
    trial_dist = dist
    first_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,po,trial_dist)
    if (first_id>0.and.trial_dist<dist) then
      id = first_id
      dist = trial_dist
    end if
    ! Try the 26 neighbouring unit shifts
    found=.false.
    do ii=-1,1
      do jj=-1,1
        do kk=-1,1
          if (ii==0.and.jj==0.and.kk==0) cycle
          trial_shift = first_shift+[ii,jj,kk]
          ! Slightly enlarged radius. If nothing is found, trial_dist stays
          ! above dist and the image is rejected below.
          trial_dist = dist+tol12
          trial_id = octn_find_nearest(octree,octree%first,octree%lo,octree%hi,&
                                       point+trial_shift,trial_dist)
          ! If not farther than the best so far, store this shift and distance
          if (trial_dist>dist) cycle
          found = .true.
          dist = trial_dist
          id = trial_id
          shift = trial_shift
        end do
      end do
    end do
    if (.not.found) shift = first_shift
  end function octree_find_nearest_pbc

  !> Recursively release the memory owned by node octn and all its descendants.
  !> This covers the leaf ids and, for internal nodes, the childs array itself.
  !> Returns 0.
  integer recursive function octn_free(octn) result(ierr)
    type(octree_node_t),intent(inout) :: octn
    integer :: ioctant
    ierr = 0
    ! Leaf: deallocate ids
    if (allocated(octn%ids)) then
      ABI_FREE(octn%ids)
    end if
    ! Internal node: free the children first, then the childs array itself
    ! (deallocating a pointer also disassociates it)
    if (associated(octn%childs)) then
      do ioctant=1,size(octn%childs)
        ierr = octn_free(octn%childs(ioctant))
      end do
      ABI_FREE(octn%childs)
    end if
  end function octn_free

  !> Free the octree data structure. The points array is not owned by the
  !> tree and is left untouched. Returns 0.
  integer function octree_free(octree) result(ierr)
    ! intent(inout): the nodes reachable from octree%first are deallocated
    type(octree_t),target,intent(inout) :: octree
    ierr = octn_free(octree%first)
  end function octree_free

  !> Squared Euclidean distance between p1 and p2.
  pure real(dp) function dist_points(p1,p2) result(dist)
    real(dp),intent(in) :: p1(3),p2(3)
    dist = pow2(p1(1)-p2(1))+&
           pow2(p1(2)-p2(2))+&
           pow2(p1(3)-p2(3))
  end function dist_points

  !> True if po lies strictly inside the box [lo,hi]; boundary points are outside.
  pure logical function box_contains(lo,hi,po) result(inside)
    real(dp),intent(in) :: lo(3), hi(3), po(3)
    inside = (po(1)>lo(1).and.po(1)<hi(1).and.&
              po(2)>lo(2).and.po(2)<hi(2).and.&
              po(3)>lo(3).and.po(3)<hi(3))
  end function box_contains

  !> Squared distance from po to the nearest point of the box [lo,hi].
  !> Zero when po is inside the box or on its boundary.
  pure real(dp) function box_dist(lo,hi,po) result(dist)
    real(dp),intent(in) :: lo(3), hi(3), po(3)
    dist = zero
    if (po(1)<lo(1)) dist = dist + pow2(po(1)-lo(1))
    if (po(1)>hi(1)) dist = dist + pow2(po(1)-hi(1))
    if (po(2)<lo(2)) dist = dist + pow2(po(2)-lo(2))
    if (po(2)>hi(2)) dist = dist + pow2(po(2)-hi(2))
    if (po(3)<lo(3)) dist = dist + pow2(po(3)-lo(3))
    if (po(3)>hi(3)) dist = dist + pow2(po(3)-hi(3))
  end function box_dist

  !> x squared.
  pure real(dp) function pow2(x) result(x2)
    real(dp),intent(in) :: x
    x2 = x*x
  end function pow2

  !> Octant index (1..8) of po relative to the midpoint mi.
  !> A coordinate equal to the midpoint goes to the upper half.
  pure integer function get_octant(mi,po) result(ioctant)
    real(dp),intent(in) :: po(3), mi(3)
    integer :: ii,jj,kk
    ii = 0; if (po(1)>=mi(1)) ii = 1
    jj = 0; if (po(2)>=mi(2)) jj = 1
    kk = 0; if (po(3)>=mi(3)) kk = 1
    ioctant = ii*4+jj*2+kk+1
  end function get_octant

  !> Octant index (1..8) of po within the box [lo,hi].
  pure integer function get_octant_lohi(lo,hi,po) result(ioctant)
    real(dp),intent(in) :: lo(3),hi(3),po(3)
    real(dp) :: mi(3)
    mi = lo+half*(hi-lo)
    ioctant = get_octant(mi,po)
  end function get_octant_lohi

  !> From a list of point indices, return the octant of each point within the
  !> box [lo,hi]: octants(id) is the octant of points(:,ids(id)).
  pure subroutine get_octants(lo,hi,nids,ids,points,octants)
    real(dp),intent(in) :: lo(3), hi(3)
    real(dp),intent(in) :: points(:,:)
    integer,intent(in) :: nids
    integer,intent(in) :: ids(nids)
    integer,intent(out) :: octants(nids)
    real(dp) :: mi(3)
    integer :: id, ipoint
    ! Calculate the midpoint
    mi = lo+half*(hi-lo)
    do id=1,nids
      ipoint = ids(id)
      octants(id) = get_octant(mi,points(:,ipoint))
    end do
  end subroutine get_octants

  !> Corners [lo_out,hi_out] of octant ioctant of the box [lo_in,hi_in].
  !> Relies on the module array shifts, which octree_init fills.
  pure subroutine get_lo_hi(lo_in,hi_in,lo_out,hi_out,ioctant)
    integer,intent(in) :: ioctant
    real(dp),intent(in) :: lo_in(3), hi_in(3)
    real(dp),intent(out) :: lo_out(3), hi_out(3)
    real(dp) :: de(3)
    de = hi_in-lo_in
    lo_out = lo_in + shifts(:,1,ioctant)*de
    hi_out = lo_in + shifts(:,2,ioctant)*de
  end subroutine get_lo_hi

end module m_octree