Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,31 @@ def T(self):
"""
return self.transpose()

def _mpi_advanced_1d_target(self, glb_idx, axis):
"""
Return the raw local ndarray addressed by ``glb_idx`` without

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: This seems out of step with the function name?

advanced indexing along ``axis``.

The MPI advanced-indexing helper code in ``devito.data.utils`` owns
the communication pattern; this hook is kept on ``Data`` because only
the subclass can bypass its own ``__getitem__`` and obtain a plain
ndarray view.
"""
target_idx = list(glb_idx)
target_idx[axis] = slice(None)
loc_idx = self._index_glb_to_loc(tuple(target_idx))
target_axis = sum(not is_integer(i) for i in glb_idx[:axis])
return super().__getitem__(loc_idx).view(np.ndarray), target_axis

@_check_idx
def __getitem__(self, glb_idx, comm_type, gather_rank=None):
advanced = mpi_advanced_1d_index(self, glb_idx)
if advanced is not None:
# Global integer indices may refer to data owned by any rank.
return mpi_advanced_1d_get(
self, *advanced, target_getter=self._mpi_advanced_1d_target
)

loc_idx = self._index_glb_to_loc(glb_idx)
is_gather = isinstance(gather_rank, int)
if is_gather and comm_type is gather:
Expand Down Expand Up @@ -383,6 +406,17 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):

@_check_idx
def __setitem__(self, glb_idx, val, comm_type):
advanced = mpi_advanced_1d_index(self, glb_idx)
if advanced is not None:
# ``val`` is rank-local and ordered by the caller's global integer
# indices; route entries to their owner ranks.
glb_idx, axis, indices, decomposition = advanced
mpi_advanced_1d_set(
self, glb_idx, val, axis, indices, decomposition,
target_getter=self._mpi_advanced_1d_target
)
return

loc_idx = self._index_glb_to_loc(glb_idx)

if loc_idx is NONLOCAL:
Expand Down Expand Up @@ -461,7 +495,9 @@ def __setitem__(self, glb_idx, val, comm_type):
raise ValueError(f"Cannot insert obj of type `{type(val)}` into a Data")

def _normalize_index(self, idx):
if isinstance(idx, np.ndarray):
if isinstance(idx, np.ndarray) or (
isinstance(idx, list) and index_contains_integer_sequence(idx, self.ndim)
):
# Advanced indexing mode
return (idx,)
else:
Expand Down
313 changes: 312 additions & 1 deletion devito/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from devito.tools import Tag, as_list, as_tuple, is_integer
from devito.tools import Tag, as_list, as_tuple, dtype_to_mpidtype, is_integer, prod

__all__ = [
'NONLOCAL',
Expand All @@ -9,10 +9,15 @@
'convert_index',
'flip_idx',
'index_apply_modulo',
'index_contains_integer_sequence',
'index_dist_to_repl',
'index_handle_oob',
'index_is_basic',
'index_is_integer_sequence',
'loc_data_idx',
'mpi_advanced_1d_get',
'mpi_advanced_1d_index',
'mpi_advanced_1d_set',
'mpi_index_maps',
]

Expand All @@ -32,6 +37,312 @@ def index_is_basic(idx):
return all(is_integer(i) or (i is NONLOCAL) for i in idx)


def index_is_integer_sequence(idx):
"""
Return True for a one-dimensional integer array-like index.

NumPy treats ``a[[...]]`` and ``a[np.array([...])]`` as advanced indexing.
This helper recognizes only the integer-sequence form used by
``mpi_advanced_1d_*``; slices and scalars continue through the existing
global-to-local index conversion.
"""
if not isinstance(idx, (list, tuple, np.ndarray)):
return False

arr = np.asarray(idx)
return (arr.ndim == 1 and
(arr.size == 0 or np.issubdtype(arr.dtype, np.integer)))


def index_is_legacy_multidim_basic(idx, ndim):
"""
Return True when a top-level Python list is a legacy basic index.

Historically, Devito accepted expressions such as ``data[[0, 1, 2]]`` as
shorthand for ``data[(0, 1, 2)]`` on 3D ``Data`` objects. NumPy treats a
plain list of integers as advanced indexing, but changing this top-level
Devito behavior would break existing MPI tests and user code. Lists nested
inside an index tuple, e.g. ``data[:, [0, 2]]``, remain available for the
routed MPI advanced-indexing path.
"""
return (ndim > 1 and isinstance(idx, list) and len(idx) == ndim and
all(is_integer(i) for i in idx))


def index_contains_integer_sequence(idx, ndim):
"""
Return True if an index expression contains a 1D integer array-like item.

This is the full-index-expression counterpart to
:func:`index_is_integer_sequence`, which checks one index component. It is
used to cheaply reject ordinary basic indexing before normalizing ``idx``,
keeping the MPI advanced-indexing path out of the hot path for
scalar/slice-only accesses.
"""
if isinstance(idx, list) and index_is_legacy_multidim_basic(idx, ndim):
return False
elif isinstance(idx, (np.ndarray, list)):
return index_is_integer_sequence(idx)
elif isinstance(idx, tuple):
return any(index_is_integer_sequence(i) for i in idx)
else:
return False


def mpi_advanced_1d_index(data, glb_idx):
"""
Normalize the supported 1D MPI advanced-indexing subset.

Returns ``None`` when ``glb_idx`` can be handled by the regular Data
indexing path. Otherwise returns the normalized index, advanced axis,
global integer indices, and the axis decomposition used by
:func:`mpi_advanced_1d_get` or :func:`mpi_advanced_1d_set`.

The supported case is deliberately narrow: exactly one MPI-distributed
dimension, indexed by exactly one one-dimensional integer sequence. The
integer sequence is interpreted as global indices in that distributed
dimension. All other dimensions must use basic indexing, i.e. slices or
scalar integers.
"""
if not index_contains_integer_sequence(glb_idx, data.ndim):
return None

if not data._is_decomposed:
return None

glb_idx = data._normalize_index(glb_idx)
if len(glb_idx) > data.ndim:
return None
elif len(glb_idx) < data.ndim:
glb_idx = glb_idx + (slice(None),)*(data.ndim - len(glb_idx))

distributed = []
advanced = []
for i, d in enumerate(data._decomposition):
if d is None:
continue

distributed.append(i)
if index_is_integer_sequence(glb_idx[i]):
advanced.append(i)

if not advanced:
return None
elif len(distributed) != 1:
raise NotImplementedError(
"Advanced indexing with MPI-distributed Data is currently "
"supported only for data with a single distributed dimension"
)
elif len(advanced) != 1:
raise NotImplementedError(
"Advanced indexing with MPI-distributed Data supports a single "
"integer index array"
)

axis = advanced[0]
for i, idx in enumerate(glb_idx):
if i != axis and index_is_integer_sequence(idx):
raise NotImplementedError(
"Advanced indexing with MPI-distributed Data supports a single "
"integer index array"
)

indices = np.asarray(glb_idx[axis], dtype=np.int64)
return glb_idx, axis, indices, data._decomposition[axis]


def mpi_advanced_1d_get(data, glb_idx, axis, indices, decomposition,
target_getter):
"""
Read MPI-distributed ``data`` using one global integer index sequence.

This implements the read side of the supported NumPy advanced-indexing
subset. Each rank supplies the global indices it wants in its local output.
The helper asks owner ranks for those entries and returns a normal NumPy
array ordered exactly like the caller's index sequence.
"""
indices, owners = _mpi_advanced_1d_owners(data, indices, decomposition)
shape = _mpi_advanced_1d_result_shape(data, glb_idx, axis, indices)
positions, scount, rcount, recv_indices = \
_mpi_advanced_1d_indices_alltoall(data, indices, owners)

source, source_axis = target_getter(glb_idx, axis)
source = np.moveaxis(source, source_axis, 0)
local_offsets = _mpi_advanced_1d_local_offsets(recv_indices, decomposition)
send_data = np.ascontiguousarray(source[local_offsets])

payload_shape = send_data.shape[1:]
recv_data = _mpi_advanced_1d_data_alltoall(data, send_data, rcount, scount,
payload_shape)

ret = np.empty(shape, dtype=data.dtype)
ret_view = np.moveaxis(ret, source_axis, 0)
ret_view[np.concatenate(positions)] = recv_data
return ret


def mpi_advanced_1d_set(data, glb_idx, val, axis, indices, decomposition,
target_getter):
"""
Assign into MPI-distributed ``data`` using one global integer index sequence.

``val`` is interpreted as local to the calling rank and ordered according
to ``indices``. The helper routes each value to the rank that owns the
corresponding global index, preserving NumPy broadcasting for the
non-distributed dimensions.
"""
indices, owners = _mpi_advanced_1d_owners(data, indices, decomposition)
shape = _mpi_advanced_1d_result_shape(data, glb_idx, axis, indices)
val = np.asarray(val, dtype=data.dtype)
val = np.broadcast_to(val, shape)
value_axis = sum(not is_integer(i) for i in glb_idx[:axis])
val = np.ascontiguousarray(np.moveaxis(val, value_axis, 0))
payload_shape = val.shape[1:]

positions, scount, rcount, recv_indices = \
_mpi_advanced_1d_indices_alltoall(data, indices, owners)

send_data = _mpi_advanced_1d_pack_axis0(val, positions)
recv_data = _mpi_advanced_1d_data_alltoall(data, send_data, scount, rcount,
payload_shape)

error = None
if recv_indices.size and np.unique(recv_indices).size != recv_indices.size:
error = "Duplicate global indices in MPI-distributed advanced assignment"
_mpi_advanced_1d_error(data, error)

target, target_axis = target_getter(glb_idx, axis)
target = np.moveaxis(target, target_axis, 0)
local_offsets = _mpi_advanced_1d_local_offsets(recv_indices, decomposition)
target[local_offsets] = recv_data


def _mpi_advanced_1d_error(data, error):
"""Raise the first error reported by any rank, on every rank."""
if data._distributor.nprocs > 1:
errors = data._distributor.comm.allgather(error)
error = next((i for i in errors if i is not None), None)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth reporting the rank(s) that raised the error here?


if error is not None:
raise ValueError(error)


def _mpi_advanced_1d_owners(data, indices, decomposition):
"""Map global indices to owning ranks, normalizing negative indices."""
indices = indices.copy()
if decomposition.glb_max is not None:
indices[indices < 0] += decomposition.glb_max + 1

owners = np.full(indices.size, -1, dtype=np.int32)
for i, r in enumerate(decomposition):
if r.size:
owners[np.isin(indices, r)] = i

error = None
if np.any(owners < 0):
error = "Advanced index contains out-of-bounds global indices"
_mpi_advanced_1d_error(data, error)

return indices, owners


def _mpi_advanced_1d_result_shape(data, glb_idx, axis, indices):
"""Return the NumPy result shape for the supported advanced-index case."""
shape = []
for i, idx in enumerate(glb_idx):
if is_integer(idx):
continue
elif i == axis:
shape.append(indices.size)
elif isinstance(idx, slice):
shape.append(len(range(*idx.indices(data.shape[i]))))
else:
raise NotImplementedError(
"Advanced indexing with MPI-distributed Data supports only "
"integer arrays, integer indices, and slices"
)
return tuple(shape)


def _mpi_advanced_1d_local_offsets(indices, decomposition):
"""Convert global indices received by an owner rank to local offsets."""
return np.array([decomposition.index_glb_to_loc(int(i)) for i in indices],
dtype=np.int64)


def _mpi_advanced_1d_positions(owners, nprocs):
"""
Group local index positions by destination rank.

``positions[r]`` are the entries in the caller's advanced index whose data
must be exchanged with rank ``r``.
"""
return [np.where(owners == i)[0] for i in range(nprocs)]


def _mpi_advanced_1d_indices_alltoall(data, indices, owners):
"""
Exchange requested global indices with their owner ranks.

The returned counts describe the same exchange pattern used later for the
payload values.
"""
comm = data._distributor.comm
nprocs = data._distributor.nprocs
positions = _mpi_advanced_1d_positions(owners, nprocs)
scount = np.array([i.size for i in positions], dtype=np.int32)
rcount = _mpi_advanced_1d_count_alltoall(comm, scount)

send_indices = _mpi_advanced_1d_pack_axis0(indices, positions)
recv_indices = np.empty(int(np.sum(rcount)), dtype=np.int64)
idtype = dtype_to_mpidtype(np.int64)
_mpi_advanced_1d_alltoallv(comm, send_indices, recv_indices, scount,
rcount, idtype)

return positions, scount, rcount, recv_indices


def _mpi_advanced_1d_data_alltoall(data, send_data, scount, rcount, payload_shape):
"""Exchange payload arrays using an already established index exchange."""
comm = data._distributor.comm
payload_size = prod(payload_shape)
recv_data = np.empty((int(np.sum(rcount)), *payload_shape), dtype=data.dtype)

dscount = scount * payload_size
drcount = rcount * payload_size
mpitype = dtype_to_mpidtype(data.dtype)
_mpi_advanced_1d_alltoallv(comm, send_data, recv_data, dscount, drcount,
mpitype)

return recv_data


def _mpi_advanced_1d_count_alltoall(comm, scount):
rcount = np.empty_like(scount)
comm.Alltoall(scount, rcount)
return rcount


def _mpi_advanced_1d_alltoallv(comm, send, recv, scount, rcount, mpitype):
sdisp = _mpi_advanced_1d_displacements(scount)
rdisp = _mpi_advanced_1d_displacements(rcount)
comm.Alltoallv([send, scount, sdisp, mpitype],
[recv, rcount, rdisp, mpitype])


def _mpi_advanced_1d_displacements(counts):
displacements = np.empty_like(counts, dtype=np.int32)
displacements[0] = 0
displacements[1:] = np.cumsum(counts[:-1], dtype=np.int64)
return displacements


def _mpi_advanced_1d_pack_axis0(array, positions):
"""Pack an axis-0 array in destination-rank order."""
return np.ascontiguousarray(np.concatenate([array[i] for i in positions],
axis=0))


def index_apply_modulo(idx, modulo):
if is_integer(idx):
return idx % modulo
Expand Down
Loading
Loading