diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 081148dd60..a18607af7e 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -735,12 +735,20 @@ class ContEngine : public BinaryEngine { // the synthetic mode from the one-rank mismatch and pad their folded views // with a unit extent. const unsigned int u = synthetic_unit_left_external(); + // - a FUSED BROADCAST ON THE RIGHT (the right operand is entirely fused, + // no contraction, e.g. C("b,k") = A("b,k") * B("b")) folds to a rank-0 + // RIGHT operand; a synthetic unit RIGHT-external mode restores a + // supported (M,K) x (K,1) -> (M,1) shape (see + // synthetic_unit_right_external()). + const unsigned int u_right = synthetic_unit_right_external(); // the tile op operates on the folded (fused-mode-free) shapes; the - // synthetic unit mode leads the folded left operand, so it is NoTrans + // synthetic unit mode leads the folded left operand (trails the folded + // right operand), so it is NoTrans on that side const auto left_op = u ? math::blas::NoTranspose : to_cblas_op(left_outer_permtype_); - const auto right_op = to_cblas_op(right_outer_permtype_); + const auto right_op = + u_right ? math::blas::NoTranspose : to_cblas_op(right_outer_permtype_); // As in init_struct, the ContractReduce tile op needs the per-cell inner // element op when the operands are nested -- including the dot_inner regime // (denest_to_scalar), where the result tile is plain but the nested modes @@ -749,9 +757,10 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_tensor_of_tensor_v || denest_to_scalar; if constexpr (!tot_aware_op) { - op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh + u, + op_ = op_type(left_op, right_op, factor_, + outer_size(indices_) - nh + u + u_right, outer_size(left_indices_) - nh + u, - outer_size(right_indices_) - nh); + outer_size(right_indices_) - nh + u_right); } else { // the batched tile op must be perm-free (BatchedContractReduce cannot // host the folded-rank result permutation); the outer perm is handled @@ -772,10 +781,11 @@ class ContEngine : public BinaryEngine { // factor_ is absorbed into element_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), - outer_size(indices_) - nh + u, + outer_size(indices_) - nh + u + u_right, outer_size(left_indices_) - nh + u, - outer_size(right_indices_) - nh, BipartitePermutation{}, - this->element_nonreturn_op_, std::move(this->arena_plan_)); + outer_size(right_indices_) - nh + u_right, + BipartitePermutation{}, this->element_nonreturn_op_, + std::move(this->arena_plan_)); // ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one // is non-null and only one install fires (see init_struct) if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { @@ -826,6 +836,20 @@ class ContEngine : public BinaryEngine { : 0u; } + /// \return 1 if the folded general product needs a SYNTHETIC unit + /// right-external mode, else 0. Mirror of synthetic_unit_left_external() for + /// the RIGHT operand: a rank-0 RIGHT operand arises when the right argument + /// is entirely fused with no contraction (a fused broadcast on the right, + /// e.g. C("b,k") = A("b,k") * B("b")), i.e. + /// outer_size(right_indices_) == n_fused_modes_. A left-external unit cannot + /// fix it (the right operand stays rank-0), so a unit right-external mode + /// (carried only in the GemmHelper) restores a supported (M,K) x (K,1) -> + /// (M,1) shape. The result gains a trailing unit mode that is absent from the + /// actual tranges/shapes/tiles, exactly as the left case prepends one. + unsigned int synthetic_unit_right_external() const { + return (outer_size(right_indices_) == n_fused_modes_) ? 1u : 0u; + } + /// Tiled range factory function for a general product /// \return The result tiled range: the fused mode ranges followed by the @@ -837,8 +861,9 @@ class ContEngine : public BinaryEngine { // GemmHelper only (see synthetic_unit_left_external()); the actual tranges // do not have it const unsigned int u = synthetic_unit_left_external(); + const unsigned int u_right = synthetic_unit_right_external(); const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; - const unsigned int neB = op_.gemm_helper().right_rank() - nc; + const unsigned int neB = op_.gemm_helper().right_rank() - nc - u_right; typename trange_type::Ranges ranges(nh + neA + neB); unsigned int i = 0ul; @@ -893,8 +918,9 @@ class ContEngine : public BinaryEngine { // GemmHelper only (see synthetic_unit_left_external()); the actual tranges // do not have it const unsigned int u = synthetic_unit_left_external(); + const unsigned int u_right = synthetic_unit_right_external(); const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; - const unsigned int neB = op_.gemm_helper().right_rank() - nc; + const unsigned int neB = op_.gemm_helper().right_rank() - nc - u_right; // Get pointers to the argument sizes const auto* MADNESS_RESTRICT const left_tiles_size = diff --git a/src/TiledArray/sparse_shape.h b/src/TiledArray/sparse_shape.h index 5ab12a3dee..21988fe248 100644 --- a/src/TiledArray/sparse_shape.h +++ b/src/TiledArray/sparse_shape.h @@ -1727,13 +1727,21 @@ class SparseShape { const bool unit_external = (tile_norms_.range().rank() + 1u == nfused + gemm_helper.left_rank()); const unsigned int u = unit_external ? 1u : 0u; + // a fused broadcast on the RIGHT carries a SYNTHETIC unit right-external + // mode in the GemmHelper only (mirror of unit_external; see + // ContEngine::synthetic_unit_right_external); detect it from the one-rank + // mismatch with the actual right norm tensor and pad the folded + // right/result views with a unit extent + const bool unit_right_external = (other.tile_norms_.range().rank() + 1u == + nfused + gemm_helper.right_rank()); + const unsigned int u_right = unit_right_external ? 1u : 0u; // check that the ranks match the folded gemm ranks plus the fused modes, // and that the fused and contracted mode extents of the two shapes are // congruent TA_ASSERT(tile_norms_.range().rank() + u == nfused + gemm_helper.left_rank()); - TA_ASSERT(other.tile_norms_.range().rank() == + TA_ASSERT(other.tile_norms_.range().rank() + u_right == nfused + gemm_helper.right_rank()); for (unsigned int d = 0u; d < nfused; ++d) TA_ASSERT(left_extent[d] == right_extent[d]); @@ -1751,14 +1759,16 @@ class SparseShape { for (unsigned int i = gemm_helper.left_inner_begin(); i < gemm_helper.left_inner_end(); ++i) K *= left_extent[nfused + i - u]; - for (unsigned int i = gemm_helper.right_outer_begin(); - i < gemm_helper.right_outer_end(); ++i) - N *= right_extent[nfused + i]; + if (!unit_right_external) + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i) + N *= right_extent[nfused + i]; // result size vectors: fused modes (from this), then the left and right - // outer modes (the synthetic unit left-external mode is absent from the - // actual result) - const unsigned int result_rank = nfused + gemm_helper.result_rank() - u; + // outer modes (the synthetic unit left-/right-external modes are absent + // from the actual result) + const unsigned int result_rank = + nfused + gemm_helper.result_rank() - u - u_right; std::shared_ptr result_size_vectors( new vector_type[result_rank], std::default_delete()); unsigned int x = 0ul; @@ -1768,9 +1778,10 @@ class SparseShape { for (unsigned int i = gemm_helper.left_outer_begin(); i < gemm_helper.left_outer_end(); ++i, ++x) result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; - for (unsigned int i = gemm_helper.right_outer_begin(); - i < gemm_helper.right_outer_end(); ++i, ++x) - result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i]; + if (!unit_right_external) + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i, ++x) + result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i]; // the result norm tensor over (fused..., left outer..., right outer...) using range_type = typename Tensor::range_type; @@ -1788,22 +1799,28 @@ class SparseShape { lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); } - for (unsigned int i = gemm_helper.right_outer_begin(); - i < gemm_helper.right_outer_end(); ++i) { - lobounds.push_back(other.tile_norms_.range().lobound_data()[nfused + i]); - upbounds.push_back(other.tile_norms_.range().upbound_data()[nfused + i]); - } + if (!unit_right_external) + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i) { + lobounds.push_back( + other.tile_norms_.range().lobound_data()[nfused + i]); + upbounds.push_back( + other.tile_norms_.range().upbound_data()[nfused + i]); + } Tensor result_norms(range_type(lobounds, upbounds), 0); // the range spanned by modes [nfused, rank) of \p r, rebased to zero // lobounds (scratch view for the slab-batched norm GEMM) auto fold_range = [nfused](const range_type& r, - const bool prepend_unit = false) { + const bool prepend_unit = false, + const bool append_unit = false) { const auto* extent = r.extent_data(); container::svector extents; - extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u)); + extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u) + + (append_unit ? 1u : 0u)); if (prepend_unit) extents.push_back(1); extents.insert(extents.end(), extent + nfused, extent + r.rank()); + if (append_unit) extents.push_back(1); return range_type(extents); }; @@ -1847,9 +1864,11 @@ class SparseShape { // buffer, so the accumulation lands in place auto left_folded = left.reshape(fold_range(left.range(), unit_external), H); - auto right_folded = right.reshape(fold_range(right.range()), H); + auto right_folded = right.reshape( + fold_range(right.range(), false, unit_right_external), H); auto result_folded = result_norms.reshape( - fold_range(result_norms.range(), unit_external), H); + fold_range(result_norms.range(), unit_external, unit_right_external), + H); result_folded.gemm(left_folded, right_folded, abs_factor, gemm_helper); // Hard zero tiles that are below the zero threshold. diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index 48b028e267..78ec5b8e42 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -63,15 +63,19 @@ class BatchedContractReduce { /// \return the range spanned by modes [nfused_, rank) of \p r, rebased to /// zero lobounds (the folded view is a GEMM scratch view; only extents /// matter); \p prepend_unit prepends a unit extent (the synthetic - /// left-external mode of a no-external product, see - /// ContEngine::init_struct_general) + /// left-external mode of a no-external/left-broadcast product) and + /// \p append_unit appends one (the synthetic right-external mode of a + /// right-broadcast product), see ContEngine::init_struct_general template - Range_ fold_range(const Range_& r, const bool prepend_unit = false) const { + Range_ fold_range(const Range_& r, const bool prepend_unit = false, + const bool append_unit = false) const { const auto* extent = r.extent_data(); container::svector extents; - extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u)); + extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u) + + (append_unit ? 1u : 0u)); if (prepend_unit) extents.push_back(1); extents.insert(extents.end(), extent + nfused_, extent + r.rank()); + if (append_unit) extents.push_back(1); return Range_(extents); } @@ -154,14 +158,21 @@ class BatchedContractReduce { const auto& gh = op_.gemm_helper(); const unsigned int nc = gh.num_contract_ranks(); - // a no-external product carries a SYNTHETIC unit left-external mode in - // the GemmHelper only (see ContEngine::init_struct_general); detect it - // from the one-rank mismatch with the actual left tile and pad the - // folded left/result views with a unit extent + // a no-external / left-broadcast product carries a SYNTHETIC unit + // left-external mode in the GemmHelper only (see + // ContEngine::init_struct_general); detect it from the one-rank mismatch + // with the actual left tile and pad the folded left/result views with a + // unit extent. A right-broadcast product (right operand entirely fused, no + // contraction) likewise carries a synthetic unit right-external mode, + // detected and padded symmetrically (the unit trails the folded right / + // result views). const bool unit_external = (left.range().rank() + 1u == nfused_ + gh.left_rank()); + const bool unit_right_external = + (right.range().rank() + 1u == nfused_ + gh.right_rank()); const unsigned int neA = gh.left_rank() - nc - (unit_external ? 1u : 0u); - const unsigned int neB = gh.right_rank() - nc; + const unsigned int neB = + gh.right_rank() - nc - (unit_right_external ? 1u : 0u); // both args must carry the fused modes as their leading modes, with // equal extents @@ -172,10 +183,13 @@ class BatchedContractReduce { const std::size_t batch = fused_volume(left.range()); TA_ASSERT(batch == fused_volume(right.range())); - // folded, zero-copy argument views + // folded, zero-copy argument views (the synthetic left-external unit + // leads the folded left view; the synthetic right-external unit trails the + // folded right view) auto left_folded = left.reshape(fold_range(left.range(), unit_external), batch); - auto right_folded = right.reshape(fold_range(right.range()), batch); + auto right_folded = right.reshape( + fold_range(right.range(), false, unit_right_external), batch); if (empty(result)) { // let the wrapped op allocate (and zero- or beta-0-initialize) the @@ -206,8 +220,8 @@ class BatchedContractReduce { } else { // accumulate through a folded, zero-copy view of the result const auto full_range = result.range(); - auto result_folded = - result.reshape(fold_range(full_range, unit_external), batch); + auto result_folded = result.reshape( + fold_range(full_range, unit_external, unit_right_external), batch); op_(result_folded, left_folded, right_folded); // the wrapped op may REBIND the result instead of writing in place: // the arena grow-to-cover path (a later K-panel touching inner cells diff --git a/tests/dot_inner.cpp b/tests/dot_inner.cpp index 6ca6adc73b..4f1b0132d8 100644 --- a/tests/dot_inner.cpp +++ b/tests/dot_inner.cpp @@ -94,6 +94,31 @@ BOOST_AUTO_TEST_CASE(hadamard_outer) { BOOST_REQUIRE((ToTArrayFixture::are_equal(ref, out))); } +// RIGHT operand entirely fused at the outer level (a per-fused-block scale on +// the right): bk;ab . b;ab -> bk. Outer b fused, k external (from left), the +// right operand carries NO outer external/contracted mode, so the folded right +// operand is rank-0 and needs the synthetic unit RIGHT-external mode +// (ContEngine::synthetic_unit_right_external). This is the denest (->T) analog +// of the CSV-CCk fused-broadcast-right shape; without the fix the folded right +// reshape aborts. inner ab fully contracted. +BOOST_AUTO_TEST_CASE(broadcast_right_outer) { + TA::TiledRange a_tr{{0, 2, 4}, {0, 3, 4}}; // b, k + TA::TiledRange b_tr{{0, 2, 4}}; // b + auto A = random_array(a_tr, {3, 2}); + auto B = random_array(b_tr, {3, 2}); + + ArrayT ref; + { + LegacyEinsumGuard g; + ref = TA::einsum("bk;mn,b;mn->bk", A, B); + } + + ArrayT out; + out("b,k") = A("b,k;m,n").dot_inner(B("b;m,n")); + + BOOST_REQUIRE((ToTArrayFixture::are_equal(ref, out))); +} + // outer: ipk x iqk -> ipq (Hadamard i, external p & q, contracted-outer k); // inner ab fully contracted. Exercises the Contraction/General outer routing. BOOST_AUTO_TEST_CASE(hadamard_external_contracted_outer) { diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 48b6b39329..e76fa24db9 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -257,11 +257,20 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense_broadcast) { ref("b,k") = s_rep("k,b") * v("b,k"); // (a) LEFT operand fully fused, straight through the expression layer + // (synthetic unit LEFT-external mode) TA::TArrayD cc; BOOST_REQUIRE_NO_THROW(cc("b,k") = s("b") * v("b,k")); BOOST_CHECK_SMALL(diff_norm(cc, ref, "b,k"), 1e-10); - // (b) einsum, both operand orders (RIGHT operand fully fused too) + // (b) RIGHT operand fully fused, straight through the expression layer + // (synthetic unit RIGHT-external mode). The expression layer respects + // operand order (no swap), so this is the only place the right-broadcast + // shape is exercised; einsum (case (c)) may reorder operands and dodge it. + TA::TArrayD cc_r; + BOOST_REQUIRE_NO_THROW(cc_r("b,k") = v("b,k") * s("b")); + BOOST_CHECK_SMALL(diff_norm(cc_r, ref, "b,k"), 1e-10); + + // (c) einsum, both operand orders (RIGHT operand fully fused too) auto c_lr = TA::einsum(s("b"), v("b,k"), "b,k"); BOOST_CHECK_SMALL(diff_norm(c_lr, ref, "b,k"), 1e-10); auto c_rl = TA::einsum(v("b,k"), s("b"), "b,k"); @@ -439,6 +448,26 @@ BOOST_AUTO_TEST_CASE(expression_general_product_sparse_batched_outer) { BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i,k"), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_broadcast_right) { + // block-sparse fused broadcast on the RIGHT: C(b,k) = A(b,k) * B(b), the + // right operand entirely fused (no outer external/contracted mode). Exercises + // the synthetic unit RIGHT-external mode in SparseShape::gemm_batched (the + // shape-level analog) as well as the batched tile op, through the expression + // layer (which respects operand order). + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + + TA::TiledRange tr_a{{0, 2, 5}, {0, 4, 5}}; // b, k + TA::TiledRange tr_b{{0, 2, 5}}; // b + auto a = make_patterned_sparse_array(world, tr_a, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr_b, 2.0, 2); + + TA::TSpArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,k") = a("b,k") * b("b")); + auto c_ref = TA::einsum(a("b,k"), b("b"), "b,k"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,k"), 1e-10); +} + namespace { using TArrayToT = @@ -529,6 +558,28 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_contraction) { BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_general_product_tot_broadcast_right) { + // ToT general product, RIGHT operand entirely fused at the outer level (a + // per-fused-block scale on the right), inner Hadamard: + // C("b,k;m") = A("b,k;m") * B("b;m") + // outer: b fused, k free (from left), right operand B("b") carries NO outer + // external/contracted mode -> a rank-0 RIGHT fold that needs the synthetic + // unit RIGHT-external mode (mirror of the left broadcast above). The + // expression layer respects operand order, so this exercises the + // right-broadcast path that einsum's operand reordering would otherwise hide. + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 3, 4}}; // b, k + TA::TiledRange tr_b{{0, 2, 4}}; // b + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,k;m") = a("b,k;m") * b("b;m")); + auto c_ref = TA::einsum(a("b,k;m"), b("b;m"), "b,k;m"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + BOOST_AUTO_TEST_CASE(expression_general_product_tot_times_t) { // mixed ToT x T general product (inner Scale): // C("b,i,k;m") = A("b,i,j;m") * B("b,j,k")