Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 116 additions & 47 deletions source/source_cell/module_neighlist/bin_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
#include <algorithm>
#include <cassert>
#include "bin_manager.h"
#include "source_base/timer.h"

#ifdef _OPENMP
#include <omp.h>
#endif

namespace
{
constexpr int neighbor_build_openmp_threshold = 256;
}

// ========== Bin class implementation ==========

Expand Down Expand Up @@ -69,6 +79,7 @@ void BinManager::init_bins(
const std::vector<NeighborAtom>& ghost_atoms
)
{
ModuleBase::timer::start("BinManager", "init_bins");
sradius_ = sr;
if(inside_atoms.empty() && ghost_atoms.empty())
{
Expand All @@ -77,6 +88,7 @@ void BinManager::init_bins(
nbinx_ = nbiny_ = nbinz_ = 1;
bins_.clear();
bins_.resize(1);
ModuleBase::timer::end("BinManager", "init_bins");
return;
}

Expand Down Expand Up @@ -129,13 +141,15 @@ void BinManager::init_bins(
}
}
}
ModuleBase::timer::end("BinManager", "init_bins");
}

void BinManager::do_binning(
const std::vector<NeighborAtom>& inside_atoms,
const std::vector<NeighborAtom>& ghost_atoms
)
{
ModuleBase::timer::start("BinManager", "do_binning");
auto bin_atom = [&](const NeighborAtom& atom)
{
int ix = std::min(
Expand All @@ -160,77 +174,131 @@ void BinManager::do_binning(

for (const auto& atom : inside_atoms) bin_atom(atom);
for (const auto& atom : ghost_atoms) bin_atom(atom);
ModuleBase::timer::end("BinManager", "do_binning");
}

int BinManager::bin_index(int ix, int iy, int iz) const {
return ix * nbiny_ * nbinz_ + iy * nbinz_ + iz;
}

template <typename Emit>
void BinManager::visit_neighbors(
int i,
const std::vector<NeighborAtom>& atoms,
double sradius2,
const Emit& emit
) const
{
const int ix = std::min(
std::max(int((atoms[i].position_x - x_min_) / bin_sizex_), 0),
nbinx_ - 1
);

const int iy = std::min(
std::max(int((atoms[i].position_y - y_min_) / bin_sizey_), 0),
nbiny_ - 1
);

const int iz = std::min(
std::max(int((atoms[i].position_z - z_min_) / bin_sizez_), 0),
nbinz_ - 1
);

for (int dx = -1; dx <= 1; dx++)
{
for (int dy = -1; dy <= 1; dy++)
{
for (int dz = -1; dz <= 1; dz++)
{
const int jx = ix + dx;
const int jy = iy + dy;
const int jz = iz + dz;

if (jx < 0 || jx >= nbinx_ ||
jy < 0 || jy >= nbiny_ ||
jz < 0 || jz >= nbinz_)
{
continue;
}

const int nidx = bin_index(jx, jy, jz);

for (const NeighborAtom& natom : bins_[nidx].get_atoms())
{
const double dx = atoms[i].position_x - natom.position_x;
const double dy = atoms[i].position_y - natom.position_y;
const double dz = atoms[i].position_z - natom.position_z;

const double dist2 = dx * dx + dy * dy + dz * dz;

if (dist2 <= sradius2 && dist2 != 0)
{
emit(natom.atom_id);
}
}
}
}
}
}

void BinManager::build_atom_neighbors(
NeighborList& neighbor_list,
std::vector<NeighborAtom>& atoms
)
{
ModuleBase::timer::start("BinManager", "build_atom_neighbors");
assert(atoms.size() == static_cast<size_t>(neighbor_list.get_nlocal()));

double sradius2 = sradius_ * sradius_;
const int nlocal = static_cast<int>(atoms.size());
const double sradius2 = sradius_ * sradius_;

neighbor_list.reset();

std::vector<int> neigh_tmp;

for (int i = 0; i < atoms.size(); i++)
#ifdef _OPENMP
const bool use_parallel = nlocal >= neighbor_build_openmp_threshold && omp_get_max_threads() > 1;
if (use_parallel)
{
neigh_tmp.clear();

int ix = std::min(
std::max(int((atoms[i].position_x - x_min_) / bin_sizex_), 0),
nbinx_ - 1
);
std::vector<int> neighbor_counts(nlocal, 0);

int iy = std::min(
std::max(int((atoms[i].position_y - y_min_) / bin_sizey_), 0),
nbiny_ - 1
);
#pragma omp parallel for schedule(static)
for (int i = 0; i < nlocal; i++)
{
int count = 0;
visit_neighbors(i, atoms, sradius2, [&](const int) { ++count; });
neighbor_counts[i] = count;
}

int iz = std::min(
std::max(int((atoms[i].position_z - z_min_) / bin_sizez_), 0),
nbinz_ - 1
);
for (int i = 0; i < nlocal; i++)
{
const int n = neighbor_counts[i];
neighbor_list.firstneigh_[i] = neighbor_list.allocator_.allocate(n);
neighbor_list.numneigh_[i] = n;
}

for (int dx = -1; dx <= 1; dx++)
#pragma omp parallel for schedule(static)
for (int i = 0; i < nlocal; i++)
{
for (int dy = -1; dy <= 1; dy++)
int* ptr = neighbor_list.firstneigh_[i];
int k = 0;
visit_neighbors(i, atoms, sradius2, [&](const int atom_id)
{
for (int dz = -1; dz <= 1; dz++)
{
int jx = ix + dx;
int jy = iy + dy;
int jz = iz + dz;

if (jx < 0 || jx >= nbinx_ ||
jy < 0 || jy >= nbiny_ ||
jz < 0 || jz >= nbinz_)
continue;

int nidx = bin_index(jx, jy, jz);
assert(ptr != nullptr);
ptr[k++] = atom_id;
});
assert(k == neighbor_counts[i]);
}

for (const NeighborAtom& natom : bins_[nidx].get_atoms())
{
double dx = atoms[i].position_x - natom.position_x;
double dy = atoms[i].position_y - natom.position_y;
double dz = atoms[i].position_z - natom.position_z;
ModuleBase::timer::end("BinManager", "build_atom_neighbors");
return;
}
#endif

double dist2 = dx * dx + dy * dy + dz * dz;
std::vector<int> neigh_tmp;

if (dist2 <= sradius2 && dist2 != 0)
{
neigh_tmp.push_back(natom.atom_id);
}
}
}
}
}
for (int i = 0; i < nlocal; i++)
{
neigh_tmp.clear();
visit_neighbors(i, atoms, sradius2, [&](const int atom_id) { neigh_tmp.push_back(atom_id); });

int n = neigh_tmp.size();

Expand All @@ -245,6 +313,7 @@ void BinManager::build_atom_neighbors(
neighbor_list.firstneigh_[i] = ptr;
neighbor_list.numneigh_[i] = n;
}
ModuleBase::timer::end("BinManager", "build_atom_neighbors");
}

void BinManager::clear()
Expand All @@ -255,4 +324,4 @@ void BinManager::clear()
}

bins_.clear();
}
}
10 changes: 9 additions & 1 deletion source/source_cell/module_neighlist/bin_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ class BinManager
* @return Flat index in the bins_ array.
*/
int bin_index(int ix, int iy, int iz) const;

template <typename Emit>
void visit_neighbors(
int i,
const std::vector<NeighborAtom>& atoms,
double sradius2,
const Emit& emit
) const;
};

#endif // BIN_MANAGER_H
#endif // BIN_MANAGER_H
9 changes: 8 additions & 1 deletion source/source_cell/module_neighlist/neighbor_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <algorithm>
#include <limits>
#include <cassert>
#include "source_base/timer.h"

// ========== Getter methods ==========

Expand Down Expand Up @@ -171,6 +172,7 @@ void NeighborSearch::check_expand_condition(const AtomProvider& ucell)

void NeighborSearch::set_member_variables(const AtomProvider& ucell)
{
ModuleBase::timer::start("NeighborSearch", "set_member_variables");
all_atoms_.clear();

ModuleBase::Vector3<double> vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13);
Expand Down Expand Up @@ -209,12 +211,14 @@ void NeighborSearch::set_member_variables(const AtomProvider& ucell)
}
}
}
ModuleBase::timer::end("NeighborSearch", "set_member_variables");
}

// ========== Main public interface ==========

void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
{
ModuleBase::timer::start("NeighborSearch", "init");
// clear possible residual data from previous runs
inside_atoms_.clear();
ghost_atoms_.clear();
Expand Down Expand Up @@ -322,13 +326,16 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
}

neighbor_list_.initialize(inside_atoms_.size(), all_atoms_.size() * neighbor_reserve_factor);
ModuleBase::timer::end("NeighborSearch", "init");
}

void NeighborSearch::build_neighbors()
{
ModuleBase::timer::start("NeighborSearch", "build_neighbors");
bin_manager_.init_bins(search_radius_, inside_atoms_, ghost_atoms_);
bin_manager_.do_binning(inside_atoms_, ghost_atoms_);
bin_manager_.build_atom_neighbors(neighbor_list_, inside_atoms_);
ModuleBase::timer::end("NeighborSearch", "build_neighbors");
}

// ========== Utility methods ==========
Expand Down Expand Up @@ -374,4 +381,4 @@ void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz)
break;
}
}
}
}
Loading