Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,20 @@ class ContEngine : public BinaryEngine<Derived> {
// the synthetic mode from the one-rank mismatch and pad their folded views
// with a unit extent.
const unsigned int u = synthetic_unit_left_external();
// - a FUSED BROADCAST ON THE RIGHT (the right operand is entirely fused,
// no contraction, e.g. C("b,k") = A("b,k") * B("b")) folds to a rank-0
// RIGHT operand; a synthetic unit RIGHT-external mode restores a
// supported (M,K) x (K,1) -> (M,1) shape (see
// synthetic_unit_right_external()).
const unsigned int u_right = synthetic_unit_right_external();

// the tile op operates on the folded (fused-mode-free) shapes; the
// synthetic unit mode leads the folded left operand, so it is NoTrans
// synthetic unit mode leads the folded left operand (trails the folded
// right operand), so it is NoTrans on that side
const auto left_op =
u ? math::blas::NoTranspose : to_cblas_op(left_outer_permtype_);
const auto right_op = to_cblas_op(right_outer_permtype_);
const auto right_op =
u_right ? math::blas::NoTranspose : to_cblas_op(right_outer_permtype_);
// As in init_struct, the ContractReduce tile op needs the per-cell inner
// element op when the operands are nested -- including the dot_inner regime
// (denest_to_scalar), where the result tile is plain but the nested modes
Expand All @@ -749,9 +757,10 @@ class ContEngine : public BinaryEngine<Derived> {
TiledArray::detail::is_tensor_of_tensor_v<value_type> ||
denest_to_scalar;
if constexpr (!tot_aware_op) {
op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh + u,
op_ = op_type(left_op, right_op, factor_,
outer_size(indices_) - nh + u + u_right,
outer_size(left_indices_) - nh + u,
outer_size(right_indices_) - nh);
outer_size(right_indices_) - nh + u_right);
} else {
// the batched tile op must be perm-free (BatchedContractReduce cannot
// host the folded-rank result permutation); the outer perm is handled
Expand All @@ -772,10 +781,11 @@ class ContEngine : public BinaryEngine<Derived> {

// factor_ is absorbed into element_nonreturn_op_
op_ = op_type(left_op, right_op, scalar_type(1),
outer_size(indices_) - nh + u,
outer_size(indices_) - nh + u + u_right,
outer_size(left_indices_) - nh + u,
outer_size(right_indices_) - nh, BipartitePermutation{},
this->element_nonreturn_op_, std::move(this->arena_plan_));
outer_size(right_indices_) - nh + u_right,
BipartitePermutation{}, this->element_nonreturn_op_,
std::move(this->arena_plan_));
// ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one
// is non-null and only one install fires (see init_struct)
if constexpr (TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
Expand Down Expand Up @@ -826,6 +836,20 @@ class ContEngine : public BinaryEngine<Derived> {
: 0u;
}

/// Reports whether the folded general product requires a SYNTHETIC unit
/// right-external mode.
///
/// Counterpart of synthetic_unit_left_external() for the RIGHT operand. When
/// the right argument is entirely fused and nothing is contracted (a fused
/// broadcast on the right, e.g. C("b,k") = A("b,k") * B("b")), i.e. when
/// outer_size(right_indices_) == n_fused_modes_, the folded right operand is
/// rank-0. A left-external unit mode cannot repair that (the right operand
/// would remain rank-0), so a unit right-external mode, carried only in the
/// GemmHelper, restores a supported (M,K) x (K,1) -> (M,1) shape. The result
/// thereby gains a trailing unit mode that does not exist in the actual
/// tranges/shapes/tiles, just as the left case prepends one.
///
/// \return 1 if the synthetic unit right-external mode is needed, else 0
unsigned int synthetic_unit_right_external() const {
  // every outer mode of the right operand is a fused mode
  const bool right_entirely_fused =
      (outer_size(right_indices_) == n_fused_modes_);
  if (right_entirely_fused) return 1u;
  return 0u;
}

/// Tiled range factory function for a general product

/// \return The result tiled range: the fused mode ranges followed by the
Expand All @@ -837,8 +861,9 @@ class ContEngine : public BinaryEngine<Derived> {
// GemmHelper only (see synthetic_unit_left_external()); the actual tranges
// do not have it
const unsigned int u = synthetic_unit_left_external();
const unsigned int u_right = synthetic_unit_right_external();
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;
const unsigned int neB = op_.gemm_helper().right_rank() - nc - u_right;

typename trange_type::Ranges ranges(nh + neA + neB);
unsigned int i = 0ul;
Expand Down Expand Up @@ -893,8 +918,9 @@ class ContEngine : public BinaryEngine<Derived> {
// GemmHelper only (see synthetic_unit_left_external()); the actual tranges
// do not have it
const unsigned int u = synthetic_unit_left_external();
const unsigned int u_right = synthetic_unit_right_external();
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;
const unsigned int neB = op_.gemm_helper().right_rank() - nc - u_right;

// Get pointers to the argument sizes
const auto* MADNESS_RESTRICT const left_tiles_size =
Expand Down
57 changes: 38 additions & 19 deletions src/TiledArray/sparse_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -1727,13 +1727,21 @@ class SparseShape {
const bool unit_external =
(tile_norms_.range().rank() + 1u == nfused + gemm_helper.left_rank());
const unsigned int u = unit_external ? 1u : 0u;
// a fused broadcast on the RIGHT carries a SYNTHETIC unit right-external
// mode in the GemmHelper only (mirror of unit_external; see
// ContEngine::synthetic_unit_right_external); detect it from the one-rank
// mismatch with the actual right norm tensor and pad the folded
// right/result views with a unit extent
const bool unit_right_external = (other.tile_norms_.range().rank() + 1u ==
nfused + gemm_helper.right_rank());
const unsigned int u_right = unit_right_external ? 1u : 0u;

// check that the ranks match the folded gemm ranks plus the fused modes,
// and that the fused and contracted mode extents of the two shapes are
// congruent
TA_ASSERT(tile_norms_.range().rank() + u ==
nfused + gemm_helper.left_rank());
TA_ASSERT(other.tile_norms_.range().rank() ==
TA_ASSERT(other.tile_norms_.range().rank() + u_right ==
nfused + gemm_helper.right_rank());
for (unsigned int d = 0u; d < nfused; ++d)
TA_ASSERT(left_extent[d] == right_extent[d]);
Expand All @@ -1751,14 +1759,16 @@ class SparseShape {
for (unsigned int i = gemm_helper.left_inner_begin();
i < gemm_helper.left_inner_end(); ++i)
K *= left_extent[nfused + i - u];
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i)
N *= right_extent[nfused + i];
if (!unit_right_external)
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i)
N *= right_extent[nfused + i];

// result size vectors: fused modes (from this), then the left and right
// outer modes (the synthetic unit left-external mode is absent from the
// actual result)
const unsigned int result_rank = nfused + gemm_helper.result_rank() - u;
// outer modes (the synthetic unit left-/right-external modes are absent
// from the actual result)
const unsigned int result_rank =
nfused + gemm_helper.result_rank() - u - u_right;
std::shared_ptr<vector_type> result_size_vectors(
new vector_type[result_rank], std::default_delete<vector_type[]>());
unsigned int x = 0ul;
Expand All @@ -1768,9 +1778,10 @@ class SparseShape {
for (unsigned int i = gemm_helper.left_outer_begin();
i < gemm_helper.left_outer_end(); ++i, ++x)
result_size_vectors.get()[x] = size_vectors_.get()[nfused + i];
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i, ++x)
result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i];
if (!unit_right_external)
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i, ++x)
result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i];

// the result norm tensor over (fused..., left outer..., right outer...)
using range_type = typename Tensor<value_type>::range_type;
Expand All @@ -1788,22 +1799,28 @@ class SparseShape {
lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]);
upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]);
}
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i) {
lobounds.push_back(other.tile_norms_.range().lobound_data()[nfused + i]);
upbounds.push_back(other.tile_norms_.range().upbound_data()[nfused + i]);
}
if (!unit_right_external)
for (unsigned int i = gemm_helper.right_outer_begin();
i < gemm_helper.right_outer_end(); ++i) {
lobounds.push_back(
other.tile_norms_.range().lobound_data()[nfused + i]);
upbounds.push_back(
other.tile_norms_.range().upbound_data()[nfused + i]);
}
Tensor<value_type> result_norms(range_type(lobounds, upbounds), 0);

// the range spanned by modes [nfused, rank) of \p r, rebased to zero
// lobounds (scratch view for the slab-batched norm GEMM)
auto fold_range = [nfused](const range_type& r,
const bool prepend_unit = false) {
const bool prepend_unit = false,
const bool append_unit = false) {
const auto* extent = r.extent_data();
container::svector<index1_type> extents;
extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u));
extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u) +
(append_unit ? 1u : 0u));
if (prepend_unit) extents.push_back(1);
extents.insert(extents.end(), extent + nfused, extent + r.rank());
if (append_unit) extents.push_back(1);
return range_type(extents);
};

Expand Down Expand Up @@ -1847,9 +1864,11 @@ class SparseShape {
// buffer, so the accumulation lands in place
auto left_folded =
left.reshape(fold_range(left.range(), unit_external), H);
auto right_folded = right.reshape(fold_range(right.range()), H);
auto right_folded = right.reshape(
fold_range(right.range(), false, unit_right_external), H);
auto result_folded = result_norms.reshape(
fold_range(result_norms.range(), unit_external), H);
fold_range(result_norms.range(), unit_external, unit_right_external),
H);
result_folded.gemm(left_folded, right_folded, abs_factor, gemm_helper);

// Hard zero tiles that are below the zero threshold.
Expand Down
40 changes: 27 additions & 13 deletions src/TiledArray/tile_op/batched_contract_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,19 @@ class BatchedContractReduce {
/// \return the range spanned by modes [nfused_, rank) of \p r, rebased to
/// zero lobounds (the folded view is a GEMM scratch view; only extents
/// matter); \p prepend_unit prepends a unit extent (the synthetic
/// left-external mode of a no-external product, see
/// ContEngine::init_struct_general)
/// left-external mode of a no-external/left-broadcast product) and
/// \p append_unit appends one (the synthetic right-external mode of a
/// right-broadcast product), see ContEngine::init_struct_general
template <typename Range_>
Range_ fold_range(const Range_& r, const bool prepend_unit = false) const {
Range_ fold_range(const Range_& r, const bool prepend_unit = false,
const bool append_unit = false) const {
const auto* extent = r.extent_data();
container::svector<typename Range_::index1_type> extents;
extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u));
extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u) +
(append_unit ? 1u : 0u));
if (prepend_unit) extents.push_back(1);
extents.insert(extents.end(), extent + nfused_, extent + r.rank());
if (append_unit) extents.push_back(1);
return Range_(extents);
}

Expand Down Expand Up @@ -154,14 +158,21 @@ class BatchedContractReduce {

const auto& gh = op_.gemm_helper();
const unsigned int nc = gh.num_contract_ranks();
// a no-external product carries a SYNTHETIC unit left-external mode in
// the GemmHelper only (see ContEngine::init_struct_general); detect it
// from the one-rank mismatch with the actual left tile and pad the
// folded left/result views with a unit extent
// a no-external / left-broadcast product carries a SYNTHETIC unit
// left-external mode in the GemmHelper only (see
// ContEngine::init_struct_general); detect it from the one-rank mismatch
// with the actual left tile and pad the folded left/result views with a
// unit extent. A right-broadcast product (right operand entirely fused, no
// contraction) likewise carries a synthetic unit right-external mode,
// detected and padded symmetrically (the unit trails the folded right /
// result views).
const bool unit_external =
(left.range().rank() + 1u == nfused_ + gh.left_rank());
const bool unit_right_external =
(right.range().rank() + 1u == nfused_ + gh.right_rank());
const unsigned int neA = gh.left_rank() - nc - (unit_external ? 1u : 0u);
const unsigned int neB = gh.right_rank() - nc;
const unsigned int neB =
gh.right_rank() - nc - (unit_right_external ? 1u : 0u);

// both args must carry the fused modes as their leading modes, with
// equal extents
Expand All @@ -172,10 +183,13 @@ class BatchedContractReduce {
const std::size_t batch = fused_volume(left.range());
TA_ASSERT(batch == fused_volume(right.range()));

// folded, zero-copy argument views
// folded, zero-copy argument views (the synthetic left-external unit
// leads the folded left view; the synthetic right-external unit trails the
// folded right view)
auto left_folded =
left.reshape(fold_range(left.range(), unit_external), batch);
auto right_folded = right.reshape(fold_range(right.range()), batch);
auto right_folded = right.reshape(
fold_range(right.range(), false, unit_right_external), batch);

if (empty(result)) {
// let the wrapped op allocate (and zero- or beta-0-initialize) the
Expand Down Expand Up @@ -206,8 +220,8 @@ class BatchedContractReduce {
} else {
// accumulate through a folded, zero-copy view of the result
const auto full_range = result.range();
auto result_folded =
result.reshape(fold_range(full_range, unit_external), batch);
auto result_folded = result.reshape(
fold_range(full_range, unit_external, unit_right_external), batch);
op_(result_folded, left_folded, right_folded);
// the wrapped op may REBIND the result instead of writing in place:
// the arena grow-to-cover path (a later K-panel touching inner cells
Expand Down
25 changes: 25 additions & 0 deletions tests/dot_inner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ BOOST_AUTO_TEST_CASE(hadamard_outer) {
BOOST_REQUIRE((ToTArrayFixture::are_equal<ShapeComp::True>(ref, out)));
}

// RIGHT operand entirely fused at the outer level (a per-fused-block scale on
// the right): bk;mn . b;mn -> bk. Outer b fused, k external (from left), the
// right operand carries NO outer external/contracted mode, so the folded right
// operand is rank-0 and needs the synthetic unit RIGHT-external mode
// (ContEngine::synthetic_unit_right_external). This is the denest (->T) analog
// of the CSV-CCk fused-broadcast-right shape; without the fix the folded right
// reshape aborts. Inner mn fully contracted.
BOOST_AUTO_TEST_CASE(broadcast_right_outer) {
  // outer tilings: the left operand spans (b, k); the right operand spans
  // only b
  TA::TiledRange lhs_trange{{0, 2, 4}, {0, 3, 4}};
  TA::TiledRange rhs_trange{{0, 2, 4}};
  auto lhs = random_array<ArrayToT>(lhs_trange, {3, 2});
  auto rhs = random_array<ArrayToT>(rhs_trange, {3, 2});

  // reference result, computed by einsum while a LegacyEinsumGuard is alive
  ArrayT expected;
  {
    LegacyEinsumGuard legacy_guard;
    expected = TA::einsum<DeNest::True>("bk;mn,b;mn->bk", lhs, rhs);
  }

  // the same product evaluated through the expression layer's dot_inner
  ArrayT actual;
  actual("b,k") = lhs("b,k;m,n").dot_inner(rhs("b;m,n"));

  BOOST_REQUIRE(
      (ToTArrayFixture::are_equal<ShapeComp::True>(expected, actual)));
}

// outer: ipk x iqk -> ipq (Hadamard i, external p & q, contracted-outer k);
// inner ab fully contracted. Exercises the Contraction/General outer routing.
BOOST_AUTO_TEST_CASE(hadamard_external_contracted_outer) {
Expand Down
Loading
Loading