diff --git a/source/source_cell/module_neighlist/bin_manager.cpp b/source/source_cell/module_neighlist/bin_manager.cpp index 1077b91dd75..545ff8da21b 100644 --- a/source/source_cell/module_neighlist/bin_manager.cpp +++ b/source/source_cell/module_neighlist/bin_manager.cpp @@ -3,6 +3,16 @@ #include #include #include "bin_manager.h" +#include "source_base/timer.h" + +#ifdef _OPENMP +#include +#endif + +namespace +{ +constexpr int neighbor_build_openmp_threshold = 256; +} // ========== Bin class implementation ========== @@ -69,6 +79,7 @@ void BinManager::init_bins( const std::vector& ghost_atoms ) { + ModuleBase::timer::start("BinManager", "init_bins"); sradius_ = sr; if(inside_atoms.empty() && ghost_atoms.empty()) { @@ -77,6 +88,7 @@ void BinManager::init_bins( nbinx_ = nbiny_ = nbinz_ = 1; bins_.clear(); bins_.resize(1); + ModuleBase::timer::end("BinManager", "init_bins"); return; } @@ -129,6 +141,7 @@ void BinManager::init_bins( } } } + ModuleBase::timer::end("BinManager", "init_bins"); } void BinManager::do_binning( @@ -136,6 +149,7 @@ void BinManager::do_binning( const std::vector& ghost_atoms ) { + ModuleBase::timer::start("BinManager", "do_binning"); auto bin_atom = [&](const NeighborAtom& atom) { int ix = std::min( @@ -160,77 +174,131 @@ void BinManager::do_binning( for (const auto& atom : inside_atoms) bin_atom(atom); for (const auto& atom : ghost_atoms) bin_atom(atom); + ModuleBase::timer::end("BinManager", "do_binning"); } int BinManager::bin_index(int ix, int iy, int iz) const { return ix * nbiny_ * nbinz_ + iy * nbinz_ + iz; } +template +void BinManager::visit_neighbors( + int i, + const std::vector& atoms, + double sradius2, + const Emit& emit +) const +{ + const int ix = std::min( + std::max(int((atoms[i].position_x - x_min_) / bin_sizex_), 0), + nbinx_ - 1 + ); + + const int iy = std::min( + std::max(int((atoms[i].position_y - y_min_) / bin_sizey_), 0), + nbiny_ - 1 + ); + + const int iz = std::min( + std::max(int((atoms[i].position_z - z_min_) / bin_sizez_), 0), + nbinz_ - 1 + ); + + for (int dx = -1; dx <= 1; dx++) + { + for (int dy = -1; dy <= 1; dy++) + { + for (int dz = -1; dz <= 1; dz++) + { + const int jx = ix + dx; + const int jy = iy + dy; + const int jz = iz + dz; + + if (jx < 0 || jx >= nbinx_ || + jy < 0 || jy >= nbiny_ || + jz < 0 || jz >= nbinz_) + { + continue; + } + + const int nidx = bin_index(jx, jy, jz); + + for (const NeighborAtom& natom : bins_[nidx].get_atoms()) + { + const double dx = atoms[i].position_x - natom.position_x; + const double dy = atoms[i].position_y - natom.position_y; + const double dz = atoms[i].position_z - natom.position_z; + + const double dist2 = dx * dx + dy * dy + dz * dz; + + if (dist2 <= sradius2 && dist2 != 0) + { + emit(natom.atom_id); + } + } + } + } + } +} + void BinManager::build_atom_neighbors( NeighborList& neighbor_list, std::vector& atoms ) { + ModuleBase::timer::start("BinManager", "build_atom_neighbors"); assert(atoms.size() == static_cast(neighbor_list.get_nlocal())); - double sradius2 = sradius_ * sradius_; + const int nlocal = static_cast(atoms.size()); + const double sradius2 = sradius_ * sradius_; neighbor_list.reset(); - std::vector neigh_tmp; - - for (int i = 0; i < atoms.size(); i++) +#ifdef _OPENMP + const bool use_parallel = nlocal >= neighbor_build_openmp_threshold && omp_get_max_threads() > 1; + if (use_parallel) { - neigh_tmp.clear(); - - int ix = std::min( - std::max(int((atoms[i].position_x - x_min_) / bin_sizex_), 0), - nbinx_ - 1 - ); + std::vector neighbor_counts(nlocal, 0); - int iy = std::min( - std::max(int((atoms[i].position_y - y_min_) / bin_sizey_), 0), - nbiny_ - 1 - ); +#pragma omp parallel for schedule(static) + for (int i = 0; i < nlocal; i++) + { + int count = 0; + visit_neighbors(i, atoms, sradius2, [&](const int) { ++count; }); + neighbor_counts[i] = count; + } - int iz = std::min( - std::max(int((atoms[i].position_z - z_min_) / bin_sizez_), 0), - nbinz_ - 1 - ); + for (int i = 0; i < nlocal; i++) + { + const int n = neighbor_counts[i]; + neighbor_list.firstneigh_[i] = neighbor_list.allocator_.allocate(n); + neighbor_list.numneigh_[i] = n; + } - for (int dx = -1; dx <= 1; dx++) +#pragma omp parallel for schedule(static) + for (int i = 0; i < nlocal; i++) { - for (int dy = -1; dy <= 1; dy++) + int* ptr = neighbor_list.firstneigh_[i]; + int k = 0; + visit_neighbors(i, atoms, sradius2, [&](const int atom_id) { - for (int dz = -1; dz <= 1; dz++) - { - int jx = ix + dx; - int jy = iy + dy; - int jz = iz + dz; - - if (jx < 0 || jx >= nbinx_ || - jy < 0 || jy >= nbiny_ || - jz < 0 || jz >= nbinz_) - continue; - - int nidx = bin_index(jx, jy, jz); + assert(ptr != nullptr); + ptr[k++] = atom_id; + }); + assert(k == neighbor_counts[i]); + } - for (const NeighborAtom& natom : bins_[nidx].get_atoms()) - { - double dx = atoms[i].position_x - natom.position_x; - double dy = atoms[i].position_y - natom.position_y; - double dz = atoms[i].position_z - natom.position_z; + ModuleBase::timer::end("BinManager", "build_atom_neighbors"); + return; + } +#endif - double dist2 = dx * dx + dy * dy + dz * dz; + std::vector neigh_tmp; - if (dist2 <= sradius2 && dist2 != 0) - { - neigh_tmp.push_back(natom.atom_id); - } - } - } - } - } + for (int i = 0; i < nlocal; i++) + { + neigh_tmp.clear(); + visit_neighbors(i, atoms, sradius2, [&](const int atom_id) { neigh_tmp.push_back(atom_id); }); int n = neigh_tmp.size(); @@ -245,6 +313,7 @@ void BinManager::build_atom_neighbors( neighbor_list.firstneigh_[i] = ptr; neighbor_list.numneigh_[i] = n; } + ModuleBase::timer::end("BinManager", "build_atom_neighbors"); } void BinManager::clear() @@ -255,4 +324,4 @@ void BinManager::clear() } bins_.clear(); -} \ No newline at end of file +} diff --git a/source/source_cell/module_neighlist/bin_manager.h b/source/source_cell/module_neighlist/bin_manager.h index 22b94d394a0..e069ade809c 100644 --- a/source/source_cell/module_neighlist/bin_manager.h +++ b/source/source_cell/module_neighlist/bin_manager.h @@ -218,6 +218,14 @@ class BinManager * @return Flat index in the bins_ array. */ int bin_index(int ix, int iy, int iz) const; + + template + void visit_neighbors( + int i, + const std::vector& atoms, + double sradius2, + const Emit& emit + ) const; }; -#endif // BIN_MANAGER_H \ No newline at end of file +#endif // BIN_MANAGER_H diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index 912515bf9d5..63333cf3907 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -3,6 +3,7 @@ #include #include #include +#include "source_base/timer.h" // ========== Getter methods ========== @@ -171,6 +172,7 @@ void NeighborSearch::check_expand_condition(const AtomProvider& ucell) void NeighborSearch::set_member_variables(const AtomProvider& ucell) { + ModuleBase::timer::start("NeighborSearch", "set_member_variables"); all_atoms_.clear(); ModuleBase::Vector3 vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13); @@ -209,12 +211,14 @@ void NeighborSearch::set_member_variables(const AtomProvider& ucell) } } } + ModuleBase::timer::end("NeighborSearch", "set_member_variables"); } // ========== Main public interface ========== void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) { + ModuleBase::timer::start("NeighborSearch", "init"); // clear possible residual data from previous runs inside_atoms_.clear(); ghost_atoms_.clear(); @@ -322,13 +326,16 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) } neighbor_list_.initialize(inside_atoms_.size(), all_atoms_.size() * neighbor_reserve_factor); + ModuleBase::timer::end("NeighborSearch", "init"); } void NeighborSearch::build_neighbors() { + ModuleBase::timer::start("NeighborSearch", "build_neighbors"); bin_manager_.init_bins(search_radius_, inside_atoms_, ghost_atoms_); bin_manager_.do_binning(inside_atoms_, ghost_atoms_); bin_manager_.build_atom_neighbors(neighbor_list_, inside_atoms_); + ModuleBase::timer::end("NeighborSearch", "build_neighbors"); } // ========== Utility methods ========== @@ -374,4 +381,4 @@ void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz) break; } } -} \ No newline at end of file +}