From 77d87a3e8d321a2958c8962308c2a63e03abe92c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Fri, 12 Jun 2026 10:17:58 +0200 Subject: [PATCH 1/8] zipper_algebra: Add XOR --- src/experimental/zipper_algebra.rs | 694 ++++++++++++++++++++++++++++- 1 file changed, 675 insertions(+), 19 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 13181e1..56937e3 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -29,7 +29,7 @@ pub use zipper_algebra_poly::ZipperMergeF; /// without visiting unrelated regions. /// /// Each method delegates to a corresponding free function ([`zipper_join`], -/// [`zipper_meet`], [`zipper_subtract`]), preserving their performance +/// [`zipper_meet`], [`zipper_subtract`], [`zipper_xor`]), preserving their performance /// characteristics and semantics. /// /// # Semantics @@ -39,6 +39,7 @@ pub use zipper_algebra_poly::ZipperMergeF; /// - [`join`](Self::join): least upper bound (union-like merge), /// - [`meet`](Self::meet): greatest lower bound (intersection), /// - [`subtract`](Self::subtract): asymmetric difference (`lhs \ rhs`). +/// - [`xor`](Self::xor): symmetric difference (`(lhs \/ rhs) \ (lhs /\ rhs)`). /// /// All operations write their result into a separate output zipper implementing /// [`ZipperWriting`]. 
@@ -56,6 +57,7 @@ pub use zipper_algebra_poly::ZipperMergeF; /// - [`zipper_join`] /// - [`zipper_meet`] /// - [`zipper_subtract`] +/// - [`zipper_xor`] pub trait ZipperAlgebraExt: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving + Sized { @@ -88,6 +90,16 @@ pub trait ZipperAlgebraExt: { zipper_subtract(self, rhs, out); } + + #[inline] + fn xor(&mut self, rhs: &mut ZR, out: &mut Out) + where + V: DistributiveLattice + Lattice, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + { + zipper_xor(self, rhs, out); + } } impl ZipperAlgebraExt @@ -327,6 +339,90 @@ pub fn zipper_subtract3( zipper_merge3::(lhs, mid, rhs, out); } +/// Performs a symmetric difference of two radix-256 tries using zipper traversal. +/// +/// This function computes the symmetric difference (`XOR`) of two tries whose +/// values form a distributive lattice. The traversal strategy is identical to +/// [`zipper_join`]: +/// +/// - Child edges are treated as sorted sequences and merged lexicographically. +/// - Subtries that exist only on one side are grafted directly into the output. +/// - The traversal descends only into child edges present in both inputs. +/// - Descent is implemented iteratively via zipper movement and an explicit +/// depth counter. +/// +/// # Value semantics +/// +/// When both tries contain a value at the same key, the values are combined +/// using the lattice symmetric-difference operation. +/// +/// For distributive lattices, symmetric difference is defined as: +/// +/// ```text +/// (a ∨ b) \ (a ∧ b) +/// ``` +/// +/// Equivalently: +/// +/// ```text +/// (a \ b) ∨ (b \ a) +/// ``` +/// +/// If the resulting value is the lattice bottom element, no value is written +/// to the output. +/// +/// Values that occur in only one input trie are propagated unchanged. 
+/// +/// # Complexity +/// +/// Let: +/// +/// - `h` be the maximum key length, +/// - `d` be the size of overlapping subtries, +/// - `f` be the number of distinct child edges encountered during traversal. +/// +/// Then: +/// +/// - Best case (disjoint tries): `O(h)` +/// - Typical case: `O(h + f)` +/// - Worst case (identical structure): `O(n)` +/// +/// Entire disjoint subtries are copied without traversal whenever possible. +/// +/// # Notes +/// +/// The traversal logic is identical to [`zipper_join`]. Only the value +/// combination operation differs. +pub fn zipper_xor(lhs: &mut ZL, rhs: &mut ZR, out: &mut Out) +where + V: DistributiveLattice + Lattice + Clone + Send + Sync, + A: Allocator, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, +{ + zipper_merge::(lhs, rhs, out); +} + +/// Performs a symmetric difference (XOR) of three radix-256 tries using zipper traversal. +/// That is, it performs: (`lhs` △ `mid`) △ `rhs`, where `△` = [`zipper_xor`] +/// +/// # See also +/// +/// [`zipper_xor`] +/// +pub fn zipper_xor3(lhs: &mut ZL, mid: &mut ZM, rhs: &mut ZR, out: &mut Out) +where + V: DistributiveLattice + Lattice + Clone + Send + Sync, + A: Allocator, + ZL: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZM: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + ZR: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, +{ + zipper_merge3::(lhs, mid, rhs, out); +} + trait MergePolicy { #[inline] fn on_left_only(z: &mut Z, range: ByteMask, out: &mut Out) @@ -355,7 +451,8 @@ trait MergePolicy { Out: ZipperWriting; fn descend_on_some_equal(mask: u64) -> bool; - fn on_id(z: &mut Z, out: &mut Out) + + fn on_id(z: &mut Z, n: usize, out: &mut Out) where A: Allocator, Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, @@ -427,7 +524,7 @@ where if let Some(v) = P::combine_n_times(lift(lhs.val()), 2) { 
out.set_val(v); } - P::on_id(lhs, out); + P::on_id(lhs, 2, out); return; } @@ -476,7 +573,7 @@ where if let Some(v) = P::combine_n_times(lift(lhs.val()), 2) { out.set_val(v); } - P::on_id(lhs, out); + P::on_id(lhs, 2, out); rhs.ascend_byte(); rhs_next = rhs_mask.next_bit(lhs_byte); @@ -604,7 +701,7 @@ where if let Some(v) = P::combine_n_times(lift(lhs.val()), 3) { out.set_val(v); } - P::on_id(lhs, out); + P::on_id(lhs, 3, out); return; } @@ -709,7 +806,7 @@ where if let Some(v) = P::combine_n_times(lift(lhs.val()), 3) { out.set_val(v); } - P::on_id(lhs, out); + P::on_id(lhs, 3, out); rhs.ascend_byte(); r = rhs_mask.next_bit(min); @@ -811,7 +908,7 @@ fn zipper_merge4( if let Some(v) = P::combine_n_times(lift(z0.val()), 4) { out.set_val(v); } - P::on_id(z0, out); + P::on_id(z0, 4, out); return; } @@ -873,7 +970,7 @@ fn zipper_merge4( if let Some(v) = P::combine_n_times(lift(z0.val()), 4) { out.set_val(v); } - P::on_id(z0, out); + P::on_id(z0, 4, out); z3.ascend_byte(); b3 = m3.next_bit(min); @@ -1201,6 +1298,69 @@ where zipper_merge_n_mono::(zs, (1 << N) - 1, out); } +/// Performs an n-way symmetric difference of radix-256 tries using zipper traversal. +/// +/// The input tries are interpreted as sparse mappings from paths to values, +/// where missing paths implicitly contain the lattice bottom element. +/// +/// For each path `p`, the resulting value is computed as the n-ary symmetric +/// difference of all values present at `p`: +/// +/// ```text +/// result(p) = xor(v₁(p), v₂(p), ..., vₙ(p)) +/// ``` +/// +/// where absent values are treated as bottom. +/// +/// Operationally, the traversal behaves similarly to an n-way join: +/// +/// - A child edge is traversed whenever it is present in at least one input. +/// - Subtries that appear in only one input are grafted directly into the +/// output. +/// - Values at coincident paths are combined using the lattice symmetric +/// difference operation. 
+/// - Paths whose resulting value is bottom are omitted from the output. +/// +/// # Why traversal follows join semantics +/// +/// Although symmetric difference is often associated with parity ("a path +/// survives iff it appears an odd number of times"), this implementation +/// performs XOR on values rather than on path presence. +/// +/// This distinction is important because cancellation at a path does not imply +/// cancellation of its descendants. A node whose value XORs to bottom may still +/// contain descendant paths whose values survive. +/// +/// Consequently, traversal proceeds whenever any input contains a child edge, +/// just as in join. Pruning is only valid when an entire subtrie is known to +/// cancel. +/// +/// # Complexity +/// +/// Let: +/// +/// - `h` be the maximum key length, +/// - `f` be the number of frontier edges visited during traversal, +/// - `d` be the size of overlapping subtries. +/// +/// Then: +/// +/// - Best case (disjoint tries): `O(h)` +/// - Typical case: `O(h + f)` +/// - Worst case: `O(n)` +/// +/// As with join, large disjoint subtries are copied without traversal whenever +/// possible. 
+pub fn zipper_n_xor(zs: &mut [Z; N], out: &mut Out) +where + V: Lattice + DistributiveLattice + Clone + Send + Sync + Unpin, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + A: Allocator, +{ + zipper_merge_n_mono::(zs, (1 << N) - 1, out); +} + // small micro-helpers #[inline(always)] fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { @@ -1302,10 +1462,11 @@ where // check for node-sharing first if all_active_share(zs, active) { let z0 = first_active_mut(zs, active); - if let Some(v) = P::combine_n_times(lift(z0.val()), active.count_ones() as usize) { + let n = active.count_ones() as usize; + if let Some(v) = P::combine_n_times(lift(z0.val()), n) { out.set_val(v); } - P::on_id(z0, out); + P::on_id(z0, n, out); return; } @@ -1383,6 +1544,7 @@ where } Some(a) => { // Dispatch + let cnt = frontier.count_ones() as usize; // - Case A: full match (frontier == all bits) if frontier == active { @@ -1396,12 +1558,10 @@ where // check structural sharing first if all_active_share(zs, active) { let z0 = first_active_mut(zs, active); - if let Some(v) = - P::combine_n_times(lift(z0.val()), active.count_ones() as usize) - { + if let Some(v) = P::combine_n_times(lift(z0.val()), cnt) { out.set_val(v); } - P::on_id(z0, out); + P::on_id(z0, cnt, out); for_each_bit(active, |i| { zs[i].ascend_byte(); @@ -1424,7 +1584,6 @@ where continue 'merge_level; } - let cnt = frontier.count_ones(); // - Case B: singleton (|frontier| = 1) if (cnt == 1) { let i = frontier.trailing_zeros() as usize; @@ -1613,7 +1772,7 @@ where if let Some(v) = z0.val() { out.set_val(v.clone()); } - Meet::on_id(z0, out); + Meet::on_id(z0, 1, out); return; } [z0, z1] => { @@ -1755,7 +1914,7 @@ impl MergePolicy for Join { } #[inline] - fn on_id(z: &mut Z, out: &mut Out) + fn on_id(z: &mut Z, _n: usize, out: &mut Out) where A: Allocator, Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, @@ -1812,7 +1971,7 @@ impl MergePolicy for Meet { } #[inline(always)] - fn 
on_id(z: &mut Z, out: &mut Out) + fn on_id(z: &mut Z, _n: usize, out: &mut Out) where A: Allocator, Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, @@ -1934,7 +2093,7 @@ impl MergePolicy for Subtract { } #[inline] - fn on_id(_z: &mut Z, _out: &mut Out) + fn on_id(_z: &mut Z, _n: usize, _out: &mut Out) where A: Allocator, Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, @@ -1991,6 +2150,95 @@ fn subtract_impl<'a, V: DistributiveLattice + Clone>( } } +// ==================== XOR ==================== + +struct Xor; +impl MergePolicy for Xor { + #[inline(always)] + fn on_single(z: &mut Z, _mask: u64, range: ByteMask, out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperMoving, + Out: ZipperWriting, + { + out.graft_masked_branches(z, range, false) + } + + #[inline(always)] + fn descend_on_some_equal(_mask: u64) -> bool { + // the traversal shape is essentially the same as for join + true + } + + #[inline(always)] + fn on_id(z: &mut Z, n: usize, out: &mut Out) + where + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, + Out: ZipperWriting, + { + // an element belongs to A△A△A ... iff it belongs to an odd number of the sets. 
+ if n % 2 == 1 { + out.graft(z) + } + } +} + +impl ValuePolicy for Xor { + fn combine_impl<'a>(l: Option>, r: Option>) -> Option> { + match (l, r) { + (None, x) | (x, None) => x, + + (Some(a), Some(b)) => match a.pjoin(&b) { + AlgebraicResult::None => None, + AlgebraicResult::Identity(join) => match a.pmeet(&b) { + AlgebraicResult::None => { + if join & SELF_IDENT != 0 { + Some(a) + } else { + Some(b) + } + } + AlgebraicResult::Identity(meet) => { + if join == meet { + None + } else if join & SELF_IDENT != 0 { + subtract_impl(a, b) + } else { + subtract_impl(b, a) + } + } + AlgebraicResult::Element(meet) => { + if join & SELF_IDENT != 0 { + subtract_impl(a, Cow::Owned(meet)) + } else { + subtract_impl(b, Cow::Owned(meet)) + } + } + }, + AlgebraicResult::Element(join) => match a.pmeet(&b) { + AlgebraicResult::None => Some(Cow::Owned(join)), + AlgebraicResult::Identity(meet) => { + if meet & SELF_IDENT != 0 { + subtract_impl(Cow::Owned(join), a) + } else { + subtract_impl(Cow::Owned(join), b) + } + } + AlgebraicResult::Element(meet) => { + subtract_impl(Cow::Owned(join), Cow::Owned(meet)) + } + }, + }, + } + } + + fn combine_n_times(val: Option>, n: usize) -> Option { + // an element belongs to A△A△A ... iff it belongs to an odd number of the sets. + if n % 2 == 1 { unlift(val) } else { None } + } +} + mod zipper_algebra_poly { // ==================== Machinery for zipper_merge_n ==================== use crate as pathmap; @@ -2095,6 +2343,16 @@ mod zipper_algebra_poly { self.merge_n::(out); } + /// Performs an N-way ordered symmetric difference of radix-256 trie zippers using a stackless traversal. + /// + /// This function generalizes pairwise [`super::zipper_xor`] to an arbitrary number of input tries. + fn xor_n(self, out: &mut Out) + where + V: Lattice + DistributiveLattice, + { + self.merge_n::(out); + } + fn merge_n

(self, out: &mut Out) where P: super::MergePolicy + super::ValuePolicy; @@ -2328,6 +2586,28 @@ mod zipper_algebra_poly { ( $( &mut $z ),+ ).subtract_n(&mut $out) }}; } + + /// Performs an N-ary zipper symmetric difference by borrowing all inputs mutably + /// and forwarding them to [`ZipperMergeF::xor_n`]. + /// + /// # Example + /// ```ignore + /// zipper_xor_n!(z1, z2, z3 => out); + /// ``` + /// + /// Expands roughly to: + /// ```ignore + /// (&mut z1, &mut z2, &mut z3).xor_n(&mut out) + /// ``` + /// + /// # See also + /// [`ZipperMergeF::xor_n`] + #[macro_export] + macro_rules! zipper_xor_n { + ( $($z:ident),+ => $out:ident ) => {{ + ( $( &mut $z ),+ ).xor_n(&mut $out) + }}; +} } #[cfg(test)] @@ -3801,6 +4081,382 @@ mod tests { } } + mod xor { + use super::*; + use crate::experimental::zipper_algebra::{ + ZipperAlgebraExt, ZipperMergeF, zipper_join, zipper_xor3, + }; + use crate::zipper_xor_n; + + #[test] + fn test_disjoint() { + check2( + &DISJOINT_PATHS, + &[DISJOINT_PATHS.0, DISJOINT_PATHS.1].concat(), + |lhs, rhs, out| lhs.xor(rhs, out), + ); + } + + #[test] + fn test_disjoint3() { + check3( + &DISJOINT_PATHS_3, + &[DISJOINT_PATHS_3.0, DISJOINT_PATHS_3.1, DISJOINT_PATHS_3.2].concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_disjoint_n() { + checkn( + &DISJOINT_PATHS_N, + &[ + DISJOINT_PATHS_N[0], + DISJOINT_PATHS_N[1], + DISJOINT_PATHS_N[2], + DISJOINT_PATHS_N[3], + DISJOINT_PATHS_N[4], + DISJOINT_PATHS_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_deep_shared_prefix_then_split() { + check2( + &PATHS_WITH_SHARED_PREFIX, + &[PATHS_WITH_SHARED_PREFIX.0, PATHS_WITH_SHARED_PREFIX.1].concat(), + |lhs, rhs, out| lhs.xor(rhs, out), + ); + } + + #[test] + fn test_deep_shared_prefix_then_split3() { + check3( + &PATHS_WITH_SHARED_PREFIX_3, + &[ + PATHS_WITH_SHARED_PREFIX_3.0, + 
PATHS_WITH_SHARED_PREFIX_3.1, + PATHS_WITH_SHARED_PREFIX_3.2, + ] + .concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_deep_shared_prefix_then_split_n() { + checkn( + &PATHS_WITH_SHARED_PREFIX_N, + &[ + PATHS_WITH_SHARED_PREFIX_N[0], + PATHS_WITH_SHARED_PREFIX_N[1], + PATHS_WITH_SHARED_PREFIX_N[2], + PATHS_WITH_SHARED_PREFIX_N[3], + PATHS_WITH_SHARED_PREFIX_N[4], + PATHS_WITH_SHARED_PREFIX_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_interleaving_paths() { + check2( + &INTERLEAVING_PATHS, + &[INTERLEAVING_PATHS.0, INTERLEAVING_PATHS.1].concat(), + |lhs, rhs, out| lhs.xor(rhs, out), + ); + } + + #[test] + fn test_interleaving_paths3() { + check3( + &INTERLEAVING_PATHS_3, + &[ + INTERLEAVING_PATHS_3.0, + INTERLEAVING_PATHS_3.1, + INTERLEAVING_PATHS_3.2, + ] + .concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_interleaving_paths_n() { + checkn( + &INTERLEAVING_PATHS_N, + &[ + INTERLEAVING_PATHS_N[0], + INTERLEAVING_PATHS_N[1], + INTERLEAVING_PATHS_N[2], + INTERLEAVING_PATHS_N[3], + INTERLEAVING_PATHS_N[4], + INTERLEAVING_PATHS_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_one_side_empty_at_many_levels() { + let expected: Paths = &[ + (&[0x00, 0x01], 1), + (&[0x00, 0x01, 0x02], 2), + (&[0x01], 4), + (&[0x01, 0x02], 5), + (&[0x01, 0x02, 0x03], 6), + (&[0x01, 0x02, 0x03, 0x04], 7), + (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), + ]; + check2(&ONE_SIDED_PATHS, expected, |lhs, rhs, out| { + lhs.xor(rhs, out) + }); + } + + #[test] + fn test_one_side_empty_at_many_levels3() { + let expected: Paths = &[ + (&[0x00], 0), + (&[0x00, 0x01], 1), + (&[0x01], 4), + (&[0x01, 0x02], 5), + (&[0x01, 0x02, 0x03], 6), + (&[0x01, 0x02, 0x03, 0x04], 7), + (&[0x01, 0x02, 
0x03, 0x04, 0x05, 0x06], 9), + ]; + check3(&ONE_SIDED_PATHS_3, expected, |lhs, mid, rhs, out| { + zipper_xor3(lhs, mid, rhs, out) + }); + } + + #[test] + fn test_one_side_empty_at_many_levels_n() { + let expected: Paths = &[ + (&[0x00], 0), + (&[0x00, 0x01], 1), + (&[0x01, 0x02, 0x03, 0x04], 7), + (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), + ]; + checkn( + &ONE_SIDED_PATHS_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_almost_identical_paths() { + let expected: Paths = &[(b"hijklmnop", 1), (b"2", 5), (b"3", 6)]; + check2(&ALMOST_IDENTICAL_PATHS, expected, |lhs, rhs, out| { + lhs.xor(rhs, out) + }); + } + + #[test] + fn test_almost_identical_paths3() { + let expected: Paths = &[(b"abcdefg", 0), (b"1", 4), (b"4", 7), (b"5", 8)]; + check3(&ALMOST_IDENTICAL_PATHS_3, expected, |lhs, mid, rhs, out| { + zipper_xor3(lhs, mid, rhs, out) + }); + } + + #[test] + fn test_almost_identical_paths_n() { + let expected: Paths = &[(b"abcdefg", 0), (b"hijklmnop", 1), (b"2", 5), (b"3", 6)]; + checkn( + &ALMOST_IDENTICAL_PATHS_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_one_side_empty() { + check2(&LHS_EMPTY, LHS_EMPTY.1, |lhs, rhs, out| lhs.xor(rhs, out)); + check2(&RHS_EMPTY, RHS_EMPTY.0, |lhs, rhs, out| lhs.xor(rhs, out)); + } + + #[test] + fn test_one_side_empty3() { + check3( + &LHS_EMPTY_3, + &[LHS_EMPTY_3.1, LHS_EMPTY_3.2].concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + check3( + &MID_EMPTY, + &[MID_EMPTY.0, MID_EMPTY.2].concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + check3( + &RHS_EMPTY_3, + &[RHS_EMPTY_3.0, RHS_EMPTY_3.1].concat(), + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_one_side_empty_n() { + checkn( + &LHS_EMPTY_N, + &[ + LHS_EMPTY_N[1], + LHS_EMPTY_N[2], + 
LHS_EMPTY_N[3], + LHS_EMPTY_N[4], + LHS_EMPTY_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &MID_EMPTY_N, + &[ + MID_EMPTY_N[0], + MID_EMPTY_N[1], + MID_EMPTY_N[3], + MID_EMPTY_N[4], + MID_EMPTY_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &RHS_EMPTY_N, + &[ + RHS_EMPTY_N[0], + RHS_EMPTY_N[1], + RHS_EMPTY_N[2], + RHS_EMPTY_N[3], + RHS_EMPTY_N[4], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + + #[test] + fn test_exact_overlap_divergent_subtries() { + let expected: Paths = &[ + (&[1, 2, 3, 4], 1), + (&[1, 2, 3, 5], 11), + (&[1, 2, 3, 10, 11, 0], 12), + (&[1, 2, 3, 10, 11, 12], 2), + ]; + check2( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN, + expected, + |lhs, rhs, out| lhs.xor(rhs, out), + ); + } + + #[test] + fn test_exact_overlap_divergent_subtries3() { + let expected: Paths = &[ + (&[1, 2, 3], 20), + (&[1, 2, 3, 4], 1), + (&[1, 2, 3, 5], 11), + (&[1, 2, 3, 6], 21), + (&[1, 2, 3, 10, 11, 0], 12), + (&[1, 2, 3, 10, 11, 1], 22), + (&[1, 2, 3, 10, 11, 12], 2), + ]; + check3( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, + expected, + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_exact_overlap_divergent_subtries_n() { + let expected: Paths = &[ + (&[1, 2, 3, 4], 1), + (&[1, 2, 3, 5], 11), + (&[1, 2, 3, 6], 21), + (&[1, 2, 3, 7], 31), + (&[1, 2, 3, 8], 41), + (&[1, 2, 3, 9], 51), + (&[1, 2, 3, 10, 11, 0], 12), + (&[1, 2, 3, 10, 11, 1], 22), + (&[1, 2, 3, 10, 11, 2], 32), + (&[1, 2, 3, 10, 11, 3], 42), + (&[1, 2, 3, 10, 11, 4], 52), + (&[1, 2, 3, 10, 11, 12], 2), + ]; + checkn( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => 
out), + ); + } + + #[test] + fn test_zigzag() { + let expected: Paths = &[ + (&[1, 1], 0), + (&[2], 1), + (&[3, 2, 1], 4), + (&[4], 4), + (&[4, 3, 2, 1], 5), + (&[1], 0), + (&[1, 2], 1), + (&[3, 4], 4), + (&[4, 3], 5), + ]; + check2(&ZIGZAG_PATHS, expected, |lhs, rhs, out| lhs.xor(rhs, out)); + } + + #[test] + fn test_zigzag3() { + let expected: Paths = &[ + (&[1, 1], 0), + (&[2, 1], 2), + (&[3, 2, 1], 4), + (&[4], 4), + (&[1, 2], 1), + (&[3, 4], 4), + (&[4, 3], 5), + (&[3, 2, 1, 0], 3), + (&[4, 3, 2, 1, 0], 4), + ]; + check3(&ZIGZAG_PATHS_3, expected, |lhs, mid, rhs, out| { + zipper_xor3(lhs, mid, rhs, out) + }); + } + + #[test] + fn test_root_values() { + check2(&PATHS_WITH_ROOT_VALS_AND_CHILDREN, &[], |lhs, rhs, out| { + lhs.xor(rhs, out) + }); + } + + #[test] + fn test_root_values3() { + check3( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, + PATHS_WITH_ROOT_VALS_AND_CHILDREN_3.2, + |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_root_values_n() { + checkn( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_N, + &[], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_xor_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + } + const FOR_MERKLEIZATION: Paths = &[ // X (&[0b100000, 0b00, 0b0001], 1), From dc6573baa520b809c23b0674d466611bfb982c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Fri, 12 Jun 2026 12:43:49 +0200 Subject: [PATCH 2/8] zipper_algebra: Implement majority-of-three using DNF zipper merge --- src/experimental/zipper_algebra.rs | 516 +++++++++++++++++++++-------- 1 file changed, 371 insertions(+), 145 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 56937e3..3ff2e9d 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1894,6 +1894,46 @@ where zipper_merge_dnf_branch(clauses, ((1 << M) - 1), out); } +/// Computes the majority (2-of-3) combination of three zippers. 
+/// +/// A value is present in the result iff it is present in at least two +/// of the three inputs. +/// +/// Algebraically: +/// +/// ```text +/// maj(a, b, c) +/// = (a ∧ b) +/// ∨ (a ∧ c) +/// ∨ (b ∧ c) +/// ``` +/// +/// This operation is monotone and can be expressed as a Disjunctive +/// Normal Form (DNF) evaluated by [`zipper_merge_dnf`]. +/// +/// The implementation reuses the DNF merge engine rather than performing +/// a specialized traversal. +/// +/// This is the lattice analogue of the Boolean majority function. +pub fn zipper_majority(x: Z, y: Z, z: Z, out: &mut Out) +where + V: Lattice + Clone + Send + Sync + Unpin, + A: Allocator, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving + Clone, + Out: ZipperWriting, +{ + let x_1 = x.clone(); + let y_1 = y.clone(); + let z_1 = z.clone(); + let mut xy = [x, y]; + let mut xz = [x_1, z]; + let mut yz = [y_1, z_1]; + + let mut clauses = [xy.as_mut_slice(), xz.as_mut_slice(), yz.as_mut_slice()]; + + zipper_merge_dnf::(&mut clauses, out); +} + // ==================== JOIN ==================== struct Join; @@ -2647,8 +2687,8 @@ mod tests { 'x, T: IntoIterator, F: for<'a> FnOnce( - &mut ReadZipperUntracked<'a, 'x, u64>, - &mut ReadZipperUntracked<'a, 'x, u64>, + ReadZipperUntracked<'a, 'x, u64>, + ReadZipperUntracked<'a, 'x, u64>, &mut WriteZipperUntracked<'a, 'x, u64>, ), >( @@ -2664,7 +2704,7 @@ mod tests { let mut rhs = right.read_zipper(); let mut out = result.write_zipper(); - op(&mut lhs, &mut rhs, &mut out); + op(lhs, rhs, &mut out); assert_trie(expected.into_iter().copied(), result); } @@ -2673,9 +2713,9 @@ mod tests { 'x, T: IntoIterator, F: for<'a> FnOnce( - &mut ReadZipperUntracked<'a, 'x, u64>, - &mut ReadZipperUntracked<'a, 'x, u64>, - &mut ReadZipperUntracked<'a, 'x, u64>, + ReadZipperUntracked<'a, 'x, u64>, + ReadZipperUntracked<'a, 'x, u64>, + ReadZipperUntracked<'a, 'x, u64>, &mut WriteZipperUntracked<'a, 'x, u64>, ), >( @@ -2692,7 +2732,7 @@ mod tests { let mut rhs = 
right.read_zipper(); let mut out = result.write_zipper(); - op(&mut lhs, &mut mid, &mut rhs, &mut out); + op(lhs, mid, rhs, &mut out); assert_trie(expected.into_iter().copied(), result); } @@ -3215,7 +3255,7 @@ mod tests { check2( &DISJOINT_PATHS, &[DISJOINT_PATHS.0, DISJOINT_PATHS.1].concat(), - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3224,7 +3264,7 @@ mod tests { check3( &DISJOINT_PATHS_3, &[DISJOINT_PATHS_3.0, DISJOINT_PATHS_3.1, DISJOINT_PATHS_3.2].concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3250,7 +3290,7 @@ mod tests { check2( &PATHS_WITH_SHARED_PREFIX, &[PATHS_WITH_SHARED_PREFIX.0, PATHS_WITH_SHARED_PREFIX.1].concat(), - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3264,7 +3304,7 @@ mod tests { PATHS_WITH_SHARED_PREFIX_3.2, ] .concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3290,7 +3330,7 @@ mod tests { check2( &INTERLEAVING_PATHS, &[INTERLEAVING_PATHS.0, INTERLEAVING_PATHS.1].concat(), - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3304,7 +3344,7 @@ mod tests { INTERLEAVING_PATHS_3.2, ] .concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3327,9 +3367,11 @@ mod tests { #[test] fn test_one_side_empty_at_many_levels() { - check2(&ONE_SIDED_PATHS, ONE_SIDED_PATHS.0, |lhs, rhs, out| { - lhs.join(rhs, out) - }); + check2( + &ONE_SIDED_PATHS, + ONE_SIDED_PATHS.0, + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), + ); } #[test] @@ -3337,7 +3379,7 @@ mod tests { check3( &ONE_SIDED_PATHS_3, ONE_SIDED_PATHS_3.0, - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, 
mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3355,7 +3397,7 @@ mod tests { check2( &ALMOST_IDENTICAL_PATHS, ALMOST_IDENTICAL_PATHS.0, - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3364,7 +3406,7 @@ mod tests { check3( &ALMOST_IDENTICAL_PATHS_3, ALMOST_IDENTICAL_PATHS_3.0, - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3379,8 +3421,12 @@ mod tests { #[test] fn test_one_side_empty() { - check2(&LHS_EMPTY, LHS_EMPTY.1, |lhs, rhs, out| lhs.join(rhs, out)); - check2(&RHS_EMPTY, RHS_EMPTY.0, |lhs, rhs, out| lhs.join(rhs, out)); + check2(&LHS_EMPTY, LHS_EMPTY.1, |mut lhs, mut rhs, out| { + lhs.join(&mut rhs, out) + }); + check2(&RHS_EMPTY, RHS_EMPTY.0, |mut lhs, mut rhs, out| { + lhs.join(&mut rhs, out) + }); } #[test] @@ -3388,17 +3434,17 @@ mod tests { check3( &LHS_EMPTY_3, &[LHS_EMPTY_3.1, LHS_EMPTY_3.2].concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); check3( &MID_EMPTY, &[MID_EMPTY.0, MID_EMPTY.2].concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); check3( &RHS_EMPTY_3, &[RHS_EMPTY_3.0, RHS_EMPTY_3.1].concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3454,7 +3500,7 @@ mod tests { check2( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN, expected, - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3472,7 +3518,7 @@ mod tests { check3( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, expected, - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), 
); } @@ -3505,7 +3551,7 @@ mod tests { check2( &ZIGZAG_PATHS, &[ZIGZAG_PATHS.0, ZIGZAG_PATHS.1].concat(), - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3514,7 +3560,7 @@ mod tests { check3( &ZIGZAG_PATHS_3, &[ZIGZAG_PATHS_3.0, ZIGZAG_PATHS_3.1, ZIGZAG_PATHS_3.2].concat(), - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3523,7 +3569,7 @@ mod tests { check2( &PATHS_WITH_ROOT_VALS_AND_CHILDREN, PATHS_WITH_ROOT_VALS_AND_CHILDREN.0, - |lhs, rhs, out| lhs.join(rhs, out), + |mut lhs, mut rhs, out| lhs.join(&mut rhs, out), ); } @@ -3532,7 +3578,7 @@ mod tests { check3( &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, PATHS_WITH_ROOT_VALS_AND_CHILDREN_3.0, - |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_join3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -3555,15 +3601,15 @@ mod tests { #[test] fn test_disjoint() { - check2(&DISJOINT_PATHS, [], |lhs, rhs, out| { - lhs.meet(rhs, out); + check2(&DISJOINT_PATHS, [], |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); }); } #[test] fn test_disjoint3() { - check3(&DISJOINT_PATHS_3, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out); + check3(&DISJOINT_PATHS_3, [], |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); }); } @@ -3578,16 +3624,20 @@ mod tests { #[test] fn test_deep_shared_prefix_then_split() { - check2(&PATHS_WITH_SHARED_PREFIX, [], |lhs, rhs, out| { - lhs.meet(rhs, out); + check2(&PATHS_WITH_SHARED_PREFIX, [], |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); }); } #[test] fn test_deep_shared_prefix_then_split3() { - check3(&PATHS_WITH_SHARED_PREFIX_3, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out); - }); + check3( + &PATHS_WITH_SHARED_PREFIX_3, + [], + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ 
-3601,16 +3651,20 @@ mod tests { #[test] fn test_interleaving_paths() { - check2(&INTERLEAVING_PATHS, [], |lhs, rhs, out| { - lhs.meet(rhs, out); + check2(&INTERLEAVING_PATHS, [], |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); }); } #[test] fn test_interleaving_paths3() { - check3(&INTERLEAVING_PATHS_3, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out); - }); + check3( + &INTERLEAVING_PATHS_3, + [], + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3629,17 +3683,21 @@ mod tests { (&[0x00, 0x01, 0x02, 0x03], 3), (&[0x01, 0x02, 0x03, 0x04, 0x05], 8), ]; - check2(&ONE_SIDED_PATHS, expected, |lhs, rhs, out| { - lhs.meet(rhs, out); + check2(&ONE_SIDED_PATHS, expected, |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); }); } #[test] fn test_one_side_empty_at_many_levels3() { let expected: Paths = &[(&[0x00], 0)]; - check3(&ONE_SIDED_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out); - }); + check3( + &ONE_SIDED_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3656,16 +3714,22 @@ mod tests { check2( &ALMOST_IDENTICAL_PATHS, ALMOST_IDENTICAL_PATHS.1, - |lhs, rhs, out| lhs.meet(rhs, out), + |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); + }, ); } #[test] fn test_almost_identical_paths3() { let expected: Paths = &[(b"abcdefg", 0), (b"1", 4), (b"4", 7), (b"5", 8)]; - check3(&ALMOST_IDENTICAL_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out); - }); + check3( + &ALMOST_IDENTICAL_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3680,20 +3744,24 @@ mod tests { #[test] fn test_one_side_empty() { - check2(&LHS_EMPTY, [], |lhs, rhs, out| lhs.meet(rhs, out)); - check2(&RHS_EMPTY, [], |lhs, rhs, out| lhs.meet(rhs, out)); + check2(&LHS_EMPTY, [], |mut lhs, mut rhs, out| { + 
lhs.meet(&mut rhs, out); + }); + check2(&RHS_EMPTY, [], |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); + }); } #[test] fn test_one_side_empty3() { - check3(&LHS_EMPTY_3, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out) + check3(&LHS_EMPTY_3, [], |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); }); - check3(&MID_EMPTY, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out) + check3(&MID_EMPTY, [], |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); }); - check3(&RHS_EMPTY_3, [], |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out) + check3(&RHS_EMPTY_3, [], |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); }); } @@ -3722,7 +3790,9 @@ mod tests { check2( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN, expected, - |lhs, rhs, out| lhs.meet(rhs, out), + |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); + }, ); } @@ -3732,7 +3802,9 @@ mod tests { check3( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, expected, - |lhs, mid, rhs, out| zipper_meet3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, ); } @@ -3749,17 +3821,21 @@ mod tests { #[test] fn test_zigzag() { let expected: Paths = &[(&[2, 1], 2), (&[3], 3)]; - check2(&ZIGZAG_PATHS, expected, |lhs, rhs, out| { - lhs.meet(rhs, out); + check2(&ZIGZAG_PATHS, expected, |mut lhs, mut rhs, out| { + lhs.meet(&mut rhs, out); }); } #[test] fn test_zigzag3() { let expected: Paths = &[(&[2, 1], 2)]; - check3(&ZIGZAG_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_meet3(lhs, mid, rhs, out) - }); + check3( + &ZIGZAG_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3767,7 +3843,9 @@ mod tests { check2( &PATHS_WITH_ROOT_VALS_AND_CHILDREN, PATHS_WITH_ROOT_VALS_AND_CHILDREN.0, - |lhs, rhs, out| lhs.meet(rhs, out), + |mut lhs, mut rhs, out| { + 
lhs.meet(&mut rhs, out); + }, ); } @@ -3776,7 +3854,9 @@ mod tests { check3( &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, PATHS_WITH_ROOT_VALS_AND_CHILDREN.0, - |lhs, mid, rhs, out| zipper_meet3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| { + zipper_meet3(&mut lhs, &mut mid, &mut rhs, out); + }, ); } @@ -3799,9 +3879,13 @@ mod tests { #[test] fn test_disjoint() { - check2(&DISJOINT_PATHS, DISJOINT_PATHS.0, |lhs, rhs, out| { - lhs.subtract(rhs, out); - }); + check2( + &DISJOINT_PATHS, + DISJOINT_PATHS.0, + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, + ); } #[test] @@ -3809,8 +3893,8 @@ mod tests { check3( &DISJOINT_PATHS_3, DISJOINT_PATHS_3.0, - |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }, ); } @@ -3829,7 +3913,9 @@ mod tests { check2( &PATHS_WITH_SHARED_PREFIX, PATHS_WITH_SHARED_PREFIX.0, - |lhs, rhs, out| lhs.subtract(rhs, out), + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, ); } @@ -3838,8 +3924,8 @@ mod tests { check3( &PATHS_WITH_SHARED_PREFIX_3, PATHS_WITH_SHARED_PREFIX_3.0, - |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }, ); } @@ -3858,7 +3944,9 @@ mod tests { check2( &INTERLEAVING_PATHS, INTERLEAVING_PATHS.0, - |lhs, rhs, out| lhs.subtract(rhs, out), + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, ); } @@ -3867,8 +3955,8 @@ mod tests { check3( &INTERLEAVING_PATHS_3, INTERLEAVING_PATHS_3.0, - |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }, ); } @@ -3895,8 +3983,8 @@ mod tests { (&[0x01, 0x02, 0x03, 0x04, 0x05], 8), (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), ]; - check2(&ONE_SIDED_PATHS, expected, |lhs, rhs, out| { - lhs.subtract(rhs, out) + check2(&ONE_SIDED_PATHS, expected, 
|mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); }); } @@ -3913,9 +4001,13 @@ mod tests { (&[0x01, 0x02, 0x03, 0x04, 0x05], 8), (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), ]; - check3(&ONE_SIDED_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); - }); + check3( + &ONE_SIDED_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3939,16 +4031,24 @@ mod tests { #[test] fn test_almost_identical_paths() { let expected: Paths = &[(b"hijklmnop", 1), (b"2", 5), (b"3", 6)]; - check2(&ALMOST_IDENTICAL_PATHS, expected, |lhs, rhs, out| { - lhs.subtract(rhs, out) - }); + check2( + &ALMOST_IDENTICAL_PATHS, + expected, + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, + ); } #[test] fn test_almost_identical_paths3() { - check3(&ALMOST_IDENTICAL_PATHS_3, [], |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); - }); + check3( + &ALMOST_IDENTICAL_PATHS_3, + [], + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -3962,23 +4062,29 @@ mod tests { #[test] fn test_one_side_empty() { - check2(&LHS_EMPTY, [], |lhs, rhs, out| lhs.subtract(rhs, out)); - check2(&RHS_EMPTY, RHS_EMPTY.0, |lhs, rhs, out| { - lhs.subtract(rhs, out) + check2(&LHS_EMPTY, [], |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }); + check2(&RHS_EMPTY, RHS_EMPTY.0, |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); }); } #[test] fn test_one_side_empty3() { - check3(&LHS_EMPTY_3, [], |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); - }); - check3(&MID_EMPTY, MID_EMPTY.0, |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + check3(&LHS_EMPTY_3, [], |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }); - check3(&RHS_EMPTY_3, RHS_EMPTY_3.0, |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + check3(&MID_EMPTY, MID_EMPTY.0, 
|mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }); + check3( + &RHS_EMPTY_3, + RHS_EMPTY_3.0, + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -4005,7 +4111,9 @@ mod tests { check2( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN, PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN.0, - |lhs, rhs, out| lhs.subtract(rhs, out), + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, ); } @@ -4014,8 +4122,8 @@ mod tests { check3( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3.0, - |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }, ); } @@ -4038,17 +4146,21 @@ mod tests { (&[4], 4), (&[4, 3, 2, 1], 5), ]; - check2(&ZIGZAG_PATHS, expected, |lhs, rhs, out| { - lhs.subtract(rhs, out) + check2(&ZIGZAG_PATHS, expected, |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); }); } #[test] fn test_zigzag3() { let expected: Paths = &[(&[1, 1], 0), (&[3, 2, 1], 4), (&[4], 4)]; - check3(&ZIGZAG_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); - }); + check3( + &ZIGZAG_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); + }, + ); } #[test] @@ -4056,7 +4168,9 @@ mod tests { check2( &PATHS_WITH_ROOT_VALS_AND_CHILDREN, PATHS_WITH_ROOT_VALS_AND_CHILDREN.0, - |lhs, rhs, out| lhs.subtract(rhs, out), + |mut lhs, mut rhs, out| { + lhs.subtract(&mut rhs, out); + }, ); } @@ -4065,8 +4179,8 @@ mod tests { check3( &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, PATHS_WITH_ROOT_VALS_AND_CHILDREN_3.0, - |lhs, mid, rhs, out| { - zipper_subtract3(lhs, mid, rhs, out); + |mut lhs, mut mid, mut rhs, out| { + zipper_subtract3(&mut lhs, &mut mid, &mut rhs, out); }, ); } @@ -4093,7 +4207,7 @@ mod tests { check2( &DISJOINT_PATHS, &[DISJOINT_PATHS.0, 
DISJOINT_PATHS.1].concat(), - |lhs, rhs, out| lhs.xor(rhs, out), + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), ); } @@ -4102,7 +4216,7 @@ mod tests { check3( &DISJOINT_PATHS_3, &[DISJOINT_PATHS_3.0, DISJOINT_PATHS_3.1, DISJOINT_PATHS_3.2].concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4128,7 +4242,7 @@ mod tests { check2( &PATHS_WITH_SHARED_PREFIX, &[PATHS_WITH_SHARED_PREFIX.0, PATHS_WITH_SHARED_PREFIX.1].concat(), - |lhs, rhs, out| lhs.xor(rhs, out), + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), ); } @@ -4142,7 +4256,7 @@ mod tests { PATHS_WITH_SHARED_PREFIX_3.2, ] .concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4168,7 +4282,7 @@ mod tests { check2( &INTERLEAVING_PATHS, &[INTERLEAVING_PATHS.0, INTERLEAVING_PATHS.1].concat(), - |lhs, rhs, out| lhs.xor(rhs, out), + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), ); } @@ -4182,7 +4296,7 @@ mod tests { INTERLEAVING_PATHS_3.2, ] .concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4214,8 +4328,8 @@ mod tests { (&[0x01, 0x02, 0x03, 0x04], 7), (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), ]; - check2(&ONE_SIDED_PATHS, expected, |lhs, rhs, out| { - lhs.xor(rhs, out) + check2(&ONE_SIDED_PATHS, expected, |mut lhs, mut rhs, out| { + lhs.xor(&mut rhs, out) }); } @@ -4230,9 +4344,11 @@ mod tests { (&[0x01, 0x02, 0x03, 0x04], 7), (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), ]; - check3(&ONE_SIDED_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_xor3(lhs, mid, rhs, out) - }); + check3( + &ONE_SIDED_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), + ); } #[test] @@ -4253,17 +4369,21 @@ mod tests { #[test] fn test_almost_identical_paths() { let 
expected: Paths = &[(b"hijklmnop", 1), (b"2", 5), (b"3", 6)]; - check2(&ALMOST_IDENTICAL_PATHS, expected, |lhs, rhs, out| { - lhs.xor(rhs, out) - }); + check2( + &ALMOST_IDENTICAL_PATHS, + expected, + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), + ); } #[test] fn test_almost_identical_paths3() { let expected: Paths = &[(b"abcdefg", 0), (b"1", 4), (b"4", 7), (b"5", 8)]; - check3(&ALMOST_IDENTICAL_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_xor3(lhs, mid, rhs, out) - }); + check3( + &ALMOST_IDENTICAL_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), + ); } #[test] @@ -4278,8 +4398,12 @@ mod tests { #[test] fn test_one_side_empty() { - check2(&LHS_EMPTY, LHS_EMPTY.1, |lhs, rhs, out| lhs.xor(rhs, out)); - check2(&RHS_EMPTY, RHS_EMPTY.0, |lhs, rhs, out| lhs.xor(rhs, out)); + check2(&LHS_EMPTY, LHS_EMPTY.1, |mut lhs, mut rhs, out| { + lhs.xor(&mut rhs, out) + }); + check2(&RHS_EMPTY, RHS_EMPTY.0, |mut lhs, mut rhs, out| { + lhs.xor(&mut rhs, out) + }); } #[test] @@ -4287,17 +4411,17 @@ mod tests { check3( &LHS_EMPTY_3, &[LHS_EMPTY_3.1, LHS_EMPTY_3.2].concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); check3( &MID_EMPTY, &[MID_EMPTY.0, MID_EMPTY.2].concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); check3( &RHS_EMPTY_3, &[RHS_EMPTY_3.0, RHS_EMPTY_3.1].concat(), - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4352,7 +4476,7 @@ mod tests { check2( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN, expected, - |lhs, rhs, out| lhs.xor(rhs, out), + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), ); } @@ -4370,7 +4494,7 @@ mod tests { check3( &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, expected, - |lhs, mid, rhs, out| 
zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4410,7 +4534,9 @@ mod tests { (&[3, 4], 4), (&[4, 3], 5), ]; - check2(&ZIGZAG_PATHS, expected, |lhs, rhs, out| lhs.xor(rhs, out)); + check2(&ZIGZAG_PATHS, expected, |mut lhs, mut rhs, out| { + lhs.xor(&mut rhs, out) + }); } #[test] @@ -4426,16 +4552,20 @@ mod tests { (&[3, 2, 1, 0], 3), (&[4, 3, 2, 1, 0], 4), ]; - check3(&ZIGZAG_PATHS_3, expected, |lhs, mid, rhs, out| { - zipper_xor3(lhs, mid, rhs, out) - }); + check3( + &ZIGZAG_PATHS_3, + expected, + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), + ); } #[test] fn test_root_values() { - check2(&PATHS_WITH_ROOT_VALS_AND_CHILDREN, &[], |lhs, rhs, out| { - lhs.xor(rhs, out) - }); + check2( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN, + &[], + |mut lhs, mut rhs, out| lhs.xor(&mut rhs, out), + ); } #[test] @@ -4443,7 +4573,7 @@ mod tests { check3( &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, PATHS_WITH_ROOT_VALS_AND_CHILDREN_3.2, - |lhs, mid, rhs, out| zipper_xor3(lhs, mid, rhs, out), + |mut lhs, mut mid, mut rhs, out| zipper_xor3(&mut lhs, &mut mid, &mut rhs, out), ); } @@ -4754,4 +4884,100 @@ mod tests { assert_trie(expected, result); } } + + mod maj { + use super::*; + use crate::experimental::zipper_algebra::zipper_majority; + + #[test] + fn test_disjoint() { + check3(&DISJOINT_PATHS_3, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out); + }); + } + + #[test] + fn test_deep_shared_prefix_then_split() { + check3(&PATHS_WITH_SHARED_PREFIX_3, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out); + }); + } + + #[test] + fn test_interleaving_paths() { + check3(&INTERLEAVING_PATHS_3, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out); + }); + } + + #[test] + fn test_one_side_empty_at_many_levels() { + let expected: Paths = &[ + (&[0x00], 0), + (&[0x00, 0x01, 0x02], 2), + (&[0x00, 0x01, 0x02, 0x03], 3), + (&[0x01, 0x02, 0x03, 0x04, 0x05], 
8), + ]; + check3(&ONE_SIDED_PATHS_3, expected, |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out); + }); + } + + #[test] + fn test_almost_identical_paths() { + check3( + &ALMOST_IDENTICAL_PATHS_3, + ALMOST_IDENTICAL_PATHS_3.0, + |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out); + }, + ); + } + + #[test] + fn test_one_side_empty() { + check3(&LHS_EMPTY_3, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out) + }); + check3(&MID_EMPTY, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out) + }); + check3(&RHS_EMPTY_3, [], |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out) + }); + } + + #[test] + fn test_exact_overlap_divergent_subtries() { + let expected: Paths = &[(&[1, 2, 3], 0)]; + check3( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_3, + expected, + |lhs, mid, rhs, out| zipper_majority(lhs, mid, rhs, out), + ); + } + + #[test] + fn test_zigzag() { + let expected: Paths = &[ + (&[2, 1], 2), + (&[1], 0), + (&[2], 1), + (&[3], 3), + (&[4, 3, 2, 1], 5), + ]; + check3(&ZIGZAG_PATHS_3, expected, |lhs, mid, rhs, out| { + zipper_majority(lhs, mid, rhs, out) + }); + } + + #[test] + fn test_root_values() { + check3( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_3, + PATHS_WITH_ROOT_VALS_AND_CHILDREN.0, + |lhs, mid, rhs, out| zipper_majority(lhs, mid, rhs, out), + ); + } + } } From 9917e7d720f2a5c83def50cd280c84f90e0c3312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Thu, 18 Jun 2026 12:37:16 +0200 Subject: [PATCH 3/8] zipper_algebra: convert DNF engine to a much stronger abstraction than the original slice-of-slices version Everything is now expressed in terms of two bounded bitspaces --- src/experimental/zipper_algebra.rs | 237 +++++++++++++++-------------- 1 file changed, 122 insertions(+), 115 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 3ff2e9d..a04d258 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs 
@@ -1383,6 +1383,13 @@ fn only_active<'a, T, const N: usize>( active_bits::(active).map(|i| (i, &ts[i])) } +#[inline(always)] +fn first_active(ts: &[T; N], active: u64) -> &T { + debug_assert_ne!(active, 0); + let i0 = active.trailing_zeros() as usize; + &ts[i0] +} + #[inline(always)] fn first_active_mut(ts: &mut [T; N], active: u64) -> &mut T { debug_assert_ne!(active, 0); @@ -1390,6 +1397,35 @@ fn first_active_mut(ts: &mut [T; N], active: u64) -> &mut T { &mut ts[i0] } +#[inline(always)] +fn with_k( + xs: &mut [T], + mut bits: u64, + f: impl FnOnce([&mut T; K]) -> R, +) -> R { + debug_assert!(bits.count_ones() as usize >= K); + + // collect raw pointers first (safe) + let mut ptrs: [*mut T; K] = [std::ptr::null_mut(); K]; + + let mut i = 0; + while i < K { + let idx = bits.trailing_zeros() as usize; + bits &= bits - 1; + ptrs[i] = unsafe { xs.as_mut_ptr().add(idx) }; + i += 1; + } + + // SAFETY: + // - indices are distinct (bitmask) + // - derived from same slice + + // should be zero-cost after inlining + let refs = unsafe { ptrs.map(|p| &mut *p) }; + + f(refs) +} + // - The function is fully monomorphized over `Z` and `N` and uses a bitmask (`active`) // to track participating zippers. 
// - Small frontiers (`k ≤ 4`) are dispatched to specialized implementations @@ -1430,35 +1466,6 @@ where } } - #[inline(always)] - fn with_k( - xs: &mut [T], - mut bits: u64, - f: impl FnOnce([&mut T; K]) -> R, - ) -> R { - debug_assert!(bits.count_ones() as usize >= K); - - // collect raw pointers first (safe) - let mut ptrs: [*mut T; K] = [std::ptr::null_mut(); K]; - - let mut i = 0; - while i < K { - let idx = bits.trailing_zeros() as usize; - bits &= bits - 1; - ptrs[i] = unsafe { xs.as_mut_ptr().add(idx) }; - i += 1; - } - - // SAFETY: - // - indices are distinct (bitmask) - // - derived from same slice - - // should be zero-cost after inlining - let refs = unsafe { ptrs.map(|p| &mut *p) }; - - f(refs) - } - // check for node-sharing first if all_active_share(zs, active) { let z0 = first_active_mut(zs, active); @@ -1683,23 +1690,26 @@ where } } -pub fn zipper_merge_dnf(clauses: &mut [&mut [Z]; M], out: &mut Out) -where +pub fn zipper_merge_dnf( + zs: &mut [Z; N], + clauses: [u64; M], + out: &mut Out, +) where V: Lattice + Clone + Send + Sync + Unpin, A: Allocator, Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { #[inline(always)] - fn clause_mask(zs: &[Z]) -> ByteMask + fn clause_mask(zs: &[Z; N], members: u64) -> ByteMask where Z: Zipper, { - if zs.is_empty() { + if members == 0 { return ByteMask::EMPTY; }; - zs.iter() - .try_fold(ByteMask::FULL, |mut mask, z| { + only_active(zs, members) + .try_fold(ByteMask::FULL, |mut mask, (_, z)| { mask &= z.child_mask(); if mask.is_empty_mask() { None @@ -1711,27 +1721,32 @@ where } #[inline(always)] - fn clause_value(zs: &[Z]) -> Option + fn clause_value(zs: &[Z; N], members: u64) -> Option where V: Lattice + Clone, Z: ZipperValues, { - Meet::combine_n(zs.iter().map(|z| lift(z.val()))) + Meet::combine_n(only_active(zs, members).map(|(_, z)| lift(z.val()))) } - fn active_clauses_value(clauses: &[&mut [Z]; M], active: u64) -> Option + fn active_clauses_value( + zs: &[Z; N], + clauses: 
&[u64; M], + active: u64, + ) -> Option where V: Lattice + Clone, Z: ZipperValues, { Join::combine_n( - only_active(clauses, active).map(|(_, zs)| clause_value(zs).map(Cow::Owned)), + only_active(clauses, active).map(|(_, i)| clause_value(zs, *i).map(Cow::Owned)), ) } #[inline(always)] - fn compute_masks( - clauses: &[&mut [Z]; M], + fn compute_masks( + zs: &[Z; N], + clauses: &[u64; M], active: u64, clause_masks: &mut [ByteMask; M], ) -> ByteMask @@ -1740,18 +1755,19 @@ where { let mut global = ByteMask::EMPTY; - for (i, zs) in only_active(clauses, active) { - let m = clause_mask(zs); + for_each_bit(active, |i| { + let m = clause_mask(zs, clauses[i]); clause_masks[i] = m; global |= m; - } + }); global } - fn zipper_merge_dnf_branch( - clauses: &mut [&mut [Z]; M], + fn zipper_merge_dnf_branch( + zs: &mut [Z; N], + clauses: &[u64; M], active: u64, out: &mut Out, ) where @@ -1766,25 +1782,32 @@ where // Single clause fast path // ------------------------------------------------- if active.count_ones() == 1 { - let single_clause = first_active_mut(clauses, active); - match single_clause { - [z0] => { + let members = first_active(clauses, active); + match members.count_ones() { + 1 => { + let z0 = first_active_mut(zs, *members); if let Some(v) = z0.val() { out.set_val(v.clone()); } Meet::on_id(z0, 1, out); return; } - [z0, z1] => { - zipper_meet(z0, z1, out); + 2 => { + with_k::<2, _, _>(zs, *members, |[z0, z1]| { + zipper_meet(z0, z1, out); + }); return; } - [z0, z1, z2] => { - zipper_meet3(z0, z1, z2, out); + 3 => { + with_k::<3, _, _>(zs, *members, |[z0, z1, z2]| { + zipper_meet3(z0, z1, z2, out); + }); return; } - [z0, z1, z2, z3] => { - zipper_merge4::(z0, z1, z2, z3, out); + 4 => { + with_k::<4, _, _>(zs, *members, |[z0, z1, z2, z3]| { + zipper_merge4::(z0, z1, z2, z3, out); + }); return; } _ => {} // do nothing special @@ -1798,14 +1821,14 @@ where // Emit values // ------------------------------------------------- - if let Some(v) = 
active_clauses_value(clauses, active) { + if let Some(v) = active_clauses_value(zs, clauses, active) { out.set_val(v); } // ------------------------------------------------- // Compute clause masks // ------------------------------------------------- - let mut global = compute_masks(clauses, active, &mut clause_masks); + let mut global = compute_masks(zs, clauses, active, &mut clause_masks); let mut next = global.indexed_bit::(0); 'descend: loop { // ------------------------------------------------- @@ -1815,17 +1838,18 @@ where out.descend_to_byte(byte); let mut sub_active = 0u64; + let mut participating = 0u64; // descend participating clauses for_each_bit(active, |i| { if clause_masks[i].test_bit(byte) { sub_active |= 1 << i; - - for z in clauses[i].iter_mut() { - z.descend_to_byte(byte); - } + participating |= clauses[i]; } }); + for_each_bit(participating, |i| { + zs[i].descend_to_byte(byte); + }); // ------------------------------------------------- // Tail-descent fast path @@ -1834,11 +1858,11 @@ where if sub_active == active { depth += 1; - if let Some(v) = active_clauses_value(clauses, active) { + if let Some(v) = active_clauses_value(zs, clauses, active) { out.set_val(v); } - global = compute_masks(clauses, active, &mut clause_masks); + global = compute_masks(zs, clauses, active, &mut clause_masks); next = global.indexed_bit::(0); continue 'descend; } @@ -1847,13 +1871,11 @@ where // Branching recursion // ------------------------------------------------- - zipper_merge_dnf_branch(clauses, sub_active, out); + zipper_merge_dnf_branch(zs, clauses, sub_active, out); // ascend - for_each_bit(sub_active, |i| { - for z in clauses[i].iter_mut() { - z.ascend_byte(); - } + for_each_bit(participating, |i| { + zs[i].ascend_byte(); }); out.ascend_byte(); @@ -1868,30 +1890,33 @@ where break; } - let byte_from = first_active_mut(clauses, active) - .first() - .and_then(|z| z.path().last().copied()) + let byte_from = *first_active(zs, *first_active(clauses, active)) + 
.path() + .last() .expect("non-empty path at depth > 0"); + let mut active_zippers = 0; for_each_bit(active, |i| { - for z in clauses[i].iter_mut() { - z.ascend_byte(); - } + active_zippers |= clauses[i]; }); + for_each_bit(active_zippers, |i| { + zs[i].ascend_byte(); + }); out.ascend_byte(); depth -= 1; // recompute masks after ascent - global = compute_masks(clauses, active, &mut clause_masks); + global = compute_masks(zs, clauses, active, &mut clause_masks); // resume sibling traversal next = global.next_bit(byte_from); } } - debug_assert!(M > 0 && M <= 64); - zipper_merge_dnf_branch(clauses, ((1 << M) - 1), out); + assert!(N > 0 && N <= 64); + assert!(M > 0 && M <= 64); + zipper_merge_dnf_branch(zs, &clauses, ((1 << M) - 1), out); } /// Computes the majority (2-of-3) combination of three zippers. @@ -1919,19 +1944,16 @@ pub fn zipper_majority(x: Z, y: Z, z: Z, out: &mut Out) where V: Lattice + Clone + Send + Sync + Unpin, A: Allocator, - Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving + Clone, + Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { - let x_1 = x.clone(); - let y_1 = y.clone(); - let z_1 = z.clone(); - let mut xy = [x, y]; - let mut xz = [x_1, z]; - let mut yz = [y_1, z_1]; - - let mut clauses = [xy.as_mut_slice(), xz.as_mut_slice(), yz.as_mut_slice()]; + let clauses = [ + 0b011, // x ∧ y + 0b101, // x ∧ z + 0b110, // y ∧ z + ]; - zipper_merge_dnf::(&mut clauses, out); + zipper_merge_dnf(&mut [x, y, z], clauses, out); } // ==================== JOIN ==================== @@ -4731,7 +4753,7 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf(&mut [&mut [&mut z1, &mut z2, &mut z3]], &mut out); + zipper_merge_dnf(&mut [&mut z1, &mut z2, &mut z3], [0b111], &mut out); let mut expected = PathMap::new(); { @@ -4757,7 +4779,8 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); zipper_merge_dnf( - &mut [&mut [&mut z1], &mut [&mut 
z2], &mut [&mut z3]], + &mut [&mut z1, &mut z2, &mut z3], + [0b001, 0b010, 0b100], &mut out, ); @@ -4778,18 +4801,13 @@ mod tests { let mut trie2 = PathMap::from_iter(SMALL_TRIE_2); let mut trie3 = PathMap::from_iter(SMALL_TRIE_3); + let mut z1 = trie1.read_zipper(); let mut z2 = trie2.read_zipper(); let mut z3 = trie3.read_zipper(); let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf( - &mut [ - &mut [&mut trie1.read_zipper(), &mut z2], - &mut [&mut trie1.read_zipper(), &mut z3], - ], - &mut out, - ); + zipper_merge_dnf(&mut [z1, z2, z3], [0b011, 0b101], &mut out); let expected = trie1.meet(&trie2.join(&trie3)); assert_trie(expected, result); } @@ -4806,13 +4824,7 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf( - &mut [ - &mut [&mut z1, &mut trie2.read_zipper()], - &mut [&mut trie2.read_zipper(), &mut z3], - ], - &mut out, - ); + zipper_merge_dnf(&mut [z1, z2, z3], [0b011, 0b110], &mut out); let expected = trie2.meet(&trie1.join(&trie3)); assert_trie(expected, result); } @@ -4823,17 +4835,13 @@ mod tests { let mut trie2 = PathMap::from_iter(SMALL_TRIE_2); let mut trie3 = PathMap::from_iter(SMALL_TRIE_3); + let mut z1 = trie1.read_zipper(); + let mut z2 = trie2.read_zipper(); let mut z3 = trie3.read_zipper(); let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf( - &mut [ - &mut [&mut trie1.read_zipper(), &mut trie2.read_zipper(), &mut z3], - &mut [&mut trie1.read_zipper(), &mut trie2.read_zipper()], - ], - &mut out, - ); + zipper_merge_dnf(&mut [z1, z2, z3], [0b111, 0b011], &mut out); let expected = trie2.meet(&trie1); assert_trie(expected, result); } @@ -4868,16 +4876,15 @@ mod tests { let a_shallow_chain = prefixed(&trie2.read_zipper(), a); let a_branching = prefixed(&trie3.read_zipper(), a); - let mut z3 = trie3.read_zipper(); - let mut result = PathMap::new(); let mut out = result.write_zipper(); zipper_merge_dnf( &mut [ - &mut [&mut 
a_deep_chain.read_zipper()], - &mut [&mut a_shallow_chain.read_zipper()], - &mut [&mut a_branching.read_zipper()], + a_deep_chain.read_zipper(), + a_shallow_chain.read_zipper(), + a_branching.read_zipper(), ], + [0b001, 0b010, 0b100], &mut out, ); let expected = a_deep_chain.join(&a_shallow_chain.join(&a_branching)); From ab62492ba25ba02fea1ef129a33e1fe55b3e3932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Thu, 18 Jun 2026 17:54:54 +0200 Subject: [PATCH 4/8] Add DNF zipper merge and clause-based expression representation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce zipper_merge_dnf(), a generalized trie merge capable of evaluating monotone Boolean expressions in Disjunctive Normal Form. The implementation represents a DNF as: (Clause0) ∨ (Clause1) ∨ ... where each Clause is a conjunction of input zippers encoded as a compact bitmask. This replaces the earlier slice-of-slices approach with a fixed universe of zippers plus clause membership masks. Highlights: * Add Clause newtype for conjunction representation. * Enforce clause validity at construction time. * Add clause![] helper macro for ergonomic DNF construction. * Share zipper traversal across clauses that reference the same input. * Perform iterative tail descent when all active clauses follow the same path. * Dispatch single-clause cases to specialized meet implementations. * Support arbitrary monotone DNF expressions, including majority and threshold functions. Example: (x ∧ y) ∨ (x ∧ z) ∨ (y ∧ z) can now be expressed as: [ clause![0, 1], clause![0, 2], clause![1, 2], ] and evaluated directly via zipper_merge_dnf(). 
--- src/experimental/zipper_algebra.rs | 436 ++++++++++++++++++++++++++--- 1 file changed, 404 insertions(+), 32 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index a04d258..a0effde 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1690,9 +1690,358 @@ where } } +/// A conjunction clause in a Disjunctive Normal Form (DNF) expression. +/// +/// A clause is represented as a bitmask over a fixed universe of `N` input +/// zippers. Bit `i` is set iff zipper `i` participates in the conjunction. +/// +/// For example, for `N = 4`: +/// +/// ```text +/// {0,2} => 0b0101 +/// {1,3} => 0b1010 +/// ``` +/// +/// The DNF expression +/// +/// ```text +/// (x₀ ∧ x₂) ∨ (x₁ ∧ x₃) +/// ``` +/// +/// can therefore be represented as: +/// +/// ```ignore +/// [ +/// Clause::<4>::from_indices([0, 2]), +/// Clause::<4>::from_indices([1, 3]), +/// ] +/// ``` +/// +/// All indices are validated at construction time, guaranteeing that no bit +/// outside the range `[0, N)` is ever set. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub struct Clause { + members: u64, +} + +impl Clause { + pub const EMPTY: Self = Self { members: 0 }; + pub const FULL: Self = Self { + members: if N == 64 { u64::MAX } else { (1u64 << N) - 1 }, + }; + + /// Creates a clause from a raw bitmask. + /// + /// # Panics + /// + /// Panics if the mask references a zipper index greater than or equal to `N`. + /// + /// # Examples + /// + /// ```ignore + /// let clause = Clause::<4>::from_mask(0b0101); + /// ``` + #[inline] + pub const fn from_mask(mask: u64) -> Self { + assert!(N > 0 && N <= 64); + assert!(mask >> N == 0); + + Self { members: mask } + } + + /// Creates a clause containing exactly one zipper. 
+ /// + /// # Examples + /// + /// ```ignore + /// let clause = Clause::<8>::singleton(3); + /// ``` + /// + /// corresponds to: + /// + /// ```text + /// x₃ + /// ``` + #[inline] + pub const fn singleton(i: usize) -> Self { + assert!(i < N); + Self::from_mask(1 << i) + } + + /// Creates a clause from a collection of zipper indices. + /// + /// # Examples + /// + /// ```ignore + /// let clause = Clause::<5>::new(&[0, 2, 4]); + /// ``` + /// + /// corresponds to: + /// + /// ```text + /// x₀ ∧ x₂ ∧ x₄ + /// ``` + /// + /// # Panics + /// + /// Panics if any index is greater than or equal to `N`. + #[inline] + pub fn new(indices: &[usize]) -> Self { + let mut mask = 0; + + for &i in indices { + assert!(i < N); + mask |= 1u64 << i; + } + + Self::from_mask(mask) + } + + pub const fn pair(i: usize, j: usize) -> Self { + assert!(i < N); + assert!(j < N); + + Self::from_mask((1u64 << i) | (1u64 << j)) + } + + /// Creates a clause from a collection of zipper indices. + /// + /// # Examples + /// + /// ```ignore + /// let clause = Clause::<5>::from_indices([0, 2, 4]); + /// ``` + /// + /// corresponds to: + /// + /// ```text + /// x₀ ∧ x₂ ∧ x₄ + /// ``` + /// + /// # Panics + /// + /// Panics if any index is greater than or equal to `N`. + pub const fn from_indices(indices: [usize; K]) -> Self { + let mut mask = 0u64; + let mut i = 0; + + while i < K { + let idx = indices[i]; + + assert!(idx < N); + + mask |= 1u64 << idx; + i += 1; + } + + Self::from_mask(mask) + } + + /// Returns the internal membership bitmask. + /// + /// Bit `i` is set iff zipper `i` participates in this clause. + /// + /// This operation is constant-time. + #[inline(always)] + pub const fn members(self) -> u64 { + debug_assert!(self.members >> N == 0); + self.members + } + + pub const fn len(self) -> usize { + self.members.count_ones() as usize + } + + pub const fn is_empty(self) -> bool { + self.members == 0 + } +} + +/// Constructs a [`Clause`] using zipper indices. 
+/// +/// # Examples +/// +/// ```ignore +/// let clause = clause![0, 2, 4]; +/// ``` +/// +/// which is equivalent to: +/// +/// ```ignore +/// Clause::<N>::from_indices([0, 2, 4]) +/// ``` +/// +/// This macro is primarily intended for constructing DNF expressions in a +/// concise and readable form. +/// +/// ```ignore +/// let dnf = [ +/// clause![0, 1], +/// clause![0, 2], +/// clause![1, 2], +/// ]; +/// ``` +#[macro_export] +macro_rules! clause { + ($($i:expr),+ $(,)?) => { + $crate::Clause::from_indices([$($i),+]) + }; +} + +/// Constructs a DNF expression as an array of [`Clause`] values. +/// +/// Each inner bracket denotes a conjunction clause, specified by the +/// indices of the participating zippers. +/// +/// # Examples +/// +/// Majority-of-three: +/// +/// ```ignore +/// let dnf = dnf![ +/// [0, 1], +/// [0, 2], +/// [1, 2], +/// ]; +/// ``` +/// +/// corresponds to: +/// +/// ```text +/// (x₀ ∧ x₁) +/// ∨ (x₀ ∧ x₂) +/// ∨ (x₁ ∧ x₂) +/// ``` +/// +/// A larger example: +/// +/// ```ignore +/// let dnf = dnf![ +/// [0], +/// [1, 2], +/// [0, 3, 4], +/// ]; +/// ``` +/// +/// corresponds to: +/// +/// ```text +/// x₀ +/// ∨ (x₁ ∧ x₂) +/// ∨ (x₀ ∧ x₃ ∧ x₄) +/// ``` +/// +/// The resulting value can be passed directly to [`zipper_merge_dnf`]: +/// +/// ```ignore +/// let clauses = dnf![ +/// [0, 1], +/// [0, 2], +/// [1, 2], +/// ]; +/// +/// zipper_merge_dnf::<_, _, _, _, 3, 3>( +/// &mut [x, y, z], +/// clauses, +/// out, +/// ); +/// ``` +/// +/// # Expansion +/// +/// ```ignore +/// dnf![ +/// [0, 1], +/// [2, 3], +/// ] +/// ``` +/// +/// expands approximately to: +/// +/// ```ignore +/// [ +/// Clause::from_indices([0, 1]), +/// Clause::from_indices([2, 3]), +/// ] +/// ``` +#[macro_export] +macro_rules! dnf { + ( + $( + [$($idx:expr),*] + ),* $(,)? + ) => { + [ + $( + Clause::from_indices([$($idx),*]) + ),* + ] + }; +} + +/// Evaluates a monotone Boolean expression in Disjunctive Normal Form (DNF) +/// over a collection of input zippers.
+/// +/// Each clause represents a conjunction (`Meet`) of selected input zippers, +/// while the final result is the disjunction (`Join`) of all clauses: +/// +/// ```text +/// (Clause₀) ∨ (Clause₁) ∨ ... ∨ (Clauseₘ) +/// ``` +/// +/// where each clause is interpreted as: +/// +/// ```text +/// xᵢ ∧ xⱼ ∧ ... +/// ``` +/// +/// The algorithm traverses the input tries simultaneously using zipper +/// operations and emits the resulting trie into `out`. +/// +/// # Example +/// +/// Majority-of-three can be expressed as: +/// +/// ```text +/// (x ∧ y) ∨ (x ∧ z) ∨ (y ∧ z) +/// ``` +/// +/// ```ignore +/// let clauses = [ +/// clause![0, 1], +/// clause![0, 2], +/// clause![1, 2], +/// ]; +/// +/// zipper_merge_dnf::<_, _, _, _, 3, 3>( +/// &mut [x, y, z], +/// clauses, +/// out, +/// ); +/// ``` +/// +/// # Complexity +/// +/// The traversal is output-sensitive and explores only trie regions that are +/// reachable through at least one active clause. +/// +/// Whenever all currently active clauses descend through the same byte, the +/// algorithm performs an iterative tail descent without recursion. Recursive +/// calls occur only when the active clause set splits. +/// +/// # Optimizations +/// +/// * Single active clauses are dispatched to specialized meet +/// implementations (`zipper_meet`, `zipper_meet3`, `zipper_merge4`, ...). +/// * Zippers shared between multiple clauses are descended only once. +/// * Long common paths are traversed iteratively without recursion. +/// +/// # Panics +/// +/// Panics if `N == 0`, `M == 0`, or either exceeds 64. 
pub fn zipper_merge_dnf( zs: &mut [Z; N], - clauses: [u64; M], + clauses: [Clause; M], out: &mut Out, ) where V: Lattice + Clone + Send + Sync + Unpin, @@ -1701,14 +2050,14 @@ pub fn zipper_merge_dnf( Out: ZipperWriting, { #[inline(always)] - fn clause_mask(zs: &[Z; N], members: u64) -> ByteMask + fn clause_mask(zs: &[Z; N], clause: &Clause) -> ByteMask where Z: Zipper, { - if members == 0 { + if clause.is_empty() { return ByteMask::EMPTY; }; - only_active(zs, members) + only_active(zs, clause.members()) .try_fold(ByteMask::FULL, |mut mask, (_, z)| { mask &= z.child_mask(); if mask.is_empty_mask() { @@ -1721,17 +2070,17 @@ pub fn zipper_merge_dnf( } #[inline(always)] - fn clause_value(zs: &[Z; N], members: u64) -> Option + fn clause_value(zs: &[Z; N], clause: &Clause) -> Option where V: Lattice + Clone, Z: ZipperValues, { - Meet::combine_n(only_active(zs, members).map(|(_, z)| lift(z.val()))) + Meet::combine_n(only_active(zs, clause.members()).map(|(_, z)| lift(z.val()))) } fn active_clauses_value( zs: &[Z; N], - clauses: &[u64; M], + clauses: &[Clause; M], active: u64, ) -> Option where @@ -1739,14 +2088,15 @@ pub fn zipper_merge_dnf( Z: ZipperValues, { Join::combine_n( - only_active(clauses, active).map(|(_, i)| clause_value(zs, *i).map(Cow::Owned)), + only_active(clauses, active) + .map(|(_, clause)| clause_value(zs, clause).map(Cow::Owned)), ) } #[inline(always)] fn compute_masks( zs: &[Z; N], - clauses: &[u64; M], + clauses: &[Clause; M], active: u64, clause_masks: &mut [ByteMask; M], ) -> ByteMask @@ -1756,7 +2106,7 @@ pub fn zipper_merge_dnf( let mut global = ByteMask::EMPTY; for_each_bit(active, |i| { - let m = clause_mask(zs, clauses[i]); + let m = clause_mask(zs, &clauses[i]); clause_masks[i] = m; global |= m; @@ -1767,7 +2117,7 @@ pub fn zipper_merge_dnf( fn zipper_merge_dnf_branch( zs: &mut [Z; N], - clauses: &[u64; M], + clauses: &[Clause; M], active: u64, out: &mut Out, ) where @@ -1782,10 +2132,10 @@ pub fn zipper_merge_dnf( // Single clause fast 
path // ------------------------------------------------- if active.count_ones() == 1 { - let members = first_active(clauses, active); - match members.count_ones() { + let single_clause = first_active(clauses, active); + match single_clause.len() { 1 => { - let z0 = first_active_mut(zs, *members); + let z0 = first_active_mut(zs, single_clause.members()); if let Some(v) = z0.val() { out.set_val(v.clone()); } @@ -1793,19 +2143,19 @@ pub fn zipper_merge_dnf( return; } 2 => { - with_k::<2, _, _>(zs, *members, |[z0, z1]| { + with_k::<2, _, _>(zs, single_clause.members(), |[z0, z1]| { zipper_meet(z0, z1, out); }); return; } 3 => { - with_k::<3, _, _>(zs, *members, |[z0, z1, z2]| { + with_k::<3, _, _>(zs, single_clause.members(), |[z0, z1, z2]| { zipper_meet3(z0, z1, z2, out); }); return; } 4 => { - with_k::<4, _, _>(zs, *members, |[z0, z1, z2, z3]| { + with_k::<4, _, _>(zs, single_clause.members(), |[z0, z1, z2, z3]| { zipper_merge4::(z0, z1, z2, z3, out); }); return; @@ -1844,7 +2194,7 @@ pub fn zipper_merge_dnf( for_each_bit(active, |i| { if clause_masks[i].test_bit(byte) { sub_active |= 1 << i; - participating |= clauses[i]; + participating |= clauses[i].members(); } }); for_each_bit(participating, |i| { @@ -1890,14 +2240,14 @@ pub fn zipper_merge_dnf( break; } - let byte_from = *first_active(zs, *first_active(clauses, active)) + let byte_from = *first_active(zs, first_active(clauses, active).members()) .path() .last() .expect("non-empty path at depth > 0"); let mut active_zippers = 0; for_each_bit(active, |i| { - active_zippers |= clauses[i]; + active_zippers |= clauses[i].members(); }); for_each_bit(active_zippers, |i| { @@ -1947,13 +2297,13 @@ where Z: ZipperInfallibleSubtries + ZipperConcrete + ZipperMoving, Out: ZipperWriting, { - let clauses = [ - 0b011, // x ∧ y - 0b101, // x ∧ z - 0b110, // y ∧ z + const MAJORITY: [Clause<3>; 3] = [ + Clause::from_mask(0b011), // x ∧ y + Clause::from_mask(0b101), // x ∧ z + Clause::from_mask(0b110), // y ∧ z ]; - 
zipper_merge_dnf(&mut [x, y, z], clauses, out); + zipper_merge_dnf(&mut [x, y, z], MAJORITY, out); } // ==================== JOIN ==================== @@ -4733,7 +5083,9 @@ mod tests { } mod dnf { - use crate::experimental::zipper_algebra::{zipper_join3, zipper_meet3, zipper_merge_dnf}; + use crate::experimental::zipper_algebra::{ + Clause, zipper_join3, zipper_meet3, zipper_merge_dnf, + }; use super::*; @@ -4753,7 +5105,7 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf(&mut [&mut z1, &mut z2, &mut z3], [0b111], &mut out); + zipper_merge_dnf(&mut [&mut z1, &mut z2, &mut z3], [Clause::FULL], &mut out); let mut expected = PathMap::new(); { @@ -4780,7 +5132,11 @@ mod tests { let mut out = result.write_zipper(); zipper_merge_dnf( &mut [&mut z1, &mut z2, &mut z3], - [0b001, 0b010, 0b100], + [ + Clause::singleton(0), + Clause::singleton(1), + Clause::singleton(2), + ], &mut out, ); @@ -4807,7 +5163,11 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf(&mut [z1, z2, z3], [0b011, 0b101], &mut out); + zipper_merge_dnf( + &mut [z1, z2, z3], + [Clause::from_mask(0b011), Clause::from_mask(0b101)], + &mut out, + ); let expected = trie1.meet(&trie2.join(&trie3)); assert_trie(expected, result); } @@ -4824,7 +5184,11 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf(&mut [z1, z2, z3], [0b011, 0b110], &mut out); + zipper_merge_dnf( + &mut [z1, z2, z3], + [Clause::from_mask(0b011), Clause::from_mask(0b110)], + &mut out, + ); let expected = trie2.meet(&trie1.join(&trie3)); assert_trie(expected, result); } @@ -4841,7 +5205,11 @@ mod tests { let mut result = PathMap::new(); let mut out = result.write_zipper(); - zipper_merge_dnf(&mut [z1, z2, z3], [0b111, 0b011], &mut out); + zipper_merge_dnf( + &mut [z1, z2, z3], + [Clause::FULL, Clause::from_indices([0, 1])], + &mut out, + ); let expected = trie2.meet(&trie1); 
assert_trie(expected, result); } @@ -4884,7 +5252,11 @@ mod tests { a_shallow_chain.read_zipper(), a_branching.read_zipper(), ], - [0b001, 0b010, 0b100], + [ + Clause::singleton(0), + Clause::singleton(1), + Clause::singleton(2), + ], &mut out, ); let expected = a_deep_chain.join(&a_shallow_chain.join(&a_branching)); From 0c8ca2507bef407973eafcfe33037d0085fc358e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Wed, 24 Jun 2026 16:38:28 +0200 Subject: [PATCH 5/8] zipper_algebra: Switching `active_bits()` to iterate over set bits directly --- src/experimental/zipper_algebra.rs | 56 +++++++++++++++++------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index a0effde..8a294e2 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1362,25 +1362,33 @@ where } // small micro-helpers -#[inline(always)] -fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { - while bits != 0 { - let i = bits.trailing_zeros() as usize; - bits &= bits - 1; - f(i); +struct ActiveBits(u64); + +impl Iterator for ActiveBits { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + if self.0 == 0 { + return None; + } + + let i = self.0.trailing_zeros() as usize; + self.0 &= self.0 - 1; + Some(i) } } -#[inline] -fn active_bits(active: u64) -> impl Iterator { - (0..N).filter(move |i| (active >> i) & 1 != 0) +#[inline(always)] +fn active_bits(active: u64) -> ActiveBits { + ActiveBits(active) } fn only_active<'a, T, const N: usize>( ts: &'a [T; N], active: u64, ) -> impl Iterator { - active_bits::(active).map(|i| (i, &ts[i])) + active_bits(active).map(|i| (i, &ts[i])) } #[inline(always)] @@ -1518,7 +1526,7 @@ where let mut frontier = 0u64; let mut next = None; - for i in active_bits::(active) { + for i in active_bits(active) { if let Some(b) = bytes[i] { match min { None => { @@ -1558,7 +1566,7 @@ where out.descend_to_byte(a); 
// descend and refresh masks and indices - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { zs[i].descend_to_byte(a); }); @@ -1570,7 +1578,7 @@ where } P::on_id(z0, cnt, out); - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { zs[i].ascend_byte(); bytes[i] = masks[i].next_bit(a); }); @@ -1582,7 +1590,7 @@ where out.set_val(v); } - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { masks[i] = zs[i].child_mask(); bytes[i] = masks[i].indexed_bit::(0); }); @@ -1645,13 +1653,13 @@ where }), _ => { // descend all active in the frontier - for_each_bit(frontier, |i| zs[i].descend_to_byte(a)); + active_bits(frontier).for_each(|i| zs[i].descend_to_byte(a)); // recursive call with SAME array, smaller mask zipper_merge_n_mono::(zs, frontier, out); //ascend - for_each_bit(frontier, |i| { + active_bits(frontier).for_each(|i| { zs[i].ascend_byte(); }); } @@ -1661,7 +1669,7 @@ where } // advance indices - for_each_bit(frontier, |i| { + active_bits(frontier).for_each(|i| { bytes[i] = masks[i].next_bit(a); }); } @@ -1678,7 +1686,7 @@ where .expect("non-empty path when k > 0"); // ascend - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { let mut z = &mut zs[i]; z.ascend_byte(); masks[i] = z.child_mask(); @@ -2105,7 +2113,7 @@ pub fn zipper_merge_dnf( { let mut global = ByteMask::EMPTY; - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { let m = clause_mask(zs, &clauses[i]); clause_masks[i] = m; @@ -2191,13 +2199,13 @@ pub fn zipper_merge_dnf( let mut participating = 0u64; // descend participating clauses - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { if clause_masks[i].test_bit(byte) { sub_active |= 1 << i; participating |= clauses[i].members(); } }); - for_each_bit(participating, |i| { + active_bits(participating).for_each(|i| { zs[i].descend_to_byte(byte); }); @@ -2224,7 +2232,7 @@ pub fn zipper_merge_dnf( zipper_merge_dnf_branch(zs, clauses, sub_active, out); // ascend - 
for_each_bit(participating, |i| { + active_bits(participating).for_each(|i| { zs[i].ascend_byte(); }); @@ -2246,11 +2254,11 @@ pub fn zipper_merge_dnf( .expect("non-empty path at depth > 0"); let mut active_zippers = 0; - for_each_bit(active, |i| { + active_bits(active).for_each(|i| { active_zippers |= clauses[i].members(); }); - for_each_bit(active_zippers, |i| { + active_bits(active_zippers).for_each(|i| { zs[i].ascend_byte(); }); out.ascend_byte(); From bb6298df23886b71339680b19a60786e2b3dc3fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Wed, 24 Jun 2026 17:49:49 +0200 Subject: [PATCH 6/8] Revert "zipper_algebra: Switching `active_bits()` to iterate over set bits directly" This reverts commit 0c8ca2507bef407973eafcfe33037d0085fc358e. --- src/experimental/zipper_algebra.rs | 56 +++++++++++++----------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 8a294e2..a0effde 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1362,33 +1362,25 @@ where } // small micro-helpers -struct ActiveBits(u64); - -impl Iterator for ActiveBits { - type Item = usize; - - #[inline] - fn next(&mut self) -> Option { - if self.0 == 0 { - return None; - } - - let i = self.0.trailing_zeros() as usize; - self.0 &= self.0 - 1; - Some(i) +#[inline(always)] +fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { + while bits != 0 { + let i = bits.trailing_zeros() as usize; + bits &= bits - 1; + f(i); } } -#[inline(always)] -fn active_bits(active: u64) -> ActiveBits { - ActiveBits(active) +#[inline] +fn active_bits(active: u64) -> impl Iterator { + (0..N).filter(move |i| (active >> i) & 1 != 0) } fn only_active<'a, T, const N: usize>( ts: &'a [T; N], active: u64, ) -> impl Iterator { - active_bits(active).map(|i| (i, &ts[i])) + active_bits::(active).map(|i| (i, &ts[i])) } #[inline(always)] @@ -1526,7 +1518,7 @@ where let mut 
frontier = 0u64; let mut next = None; - for i in active_bits(active) { + for i in active_bits::(active) { if let Some(b) = bytes[i] { match min { None => { @@ -1566,7 +1558,7 @@ where out.descend_to_byte(a); // descend and refresh masks and indices - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { zs[i].descend_to_byte(a); }); @@ -1578,7 +1570,7 @@ where } P::on_id(z0, cnt, out); - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { zs[i].ascend_byte(); bytes[i] = masks[i].next_bit(a); }); @@ -1590,7 +1582,7 @@ where out.set_val(v); } - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { masks[i] = zs[i].child_mask(); bytes[i] = masks[i].indexed_bit::(0); }); @@ -1653,13 +1645,13 @@ where }), _ => { // descend all active in the frontier - active_bits(frontier).for_each(|i| zs[i].descend_to_byte(a)); + for_each_bit(frontier, |i| zs[i].descend_to_byte(a)); // recursive call with SAME array, smaller mask zipper_merge_n_mono::(zs, frontier, out); //ascend - active_bits(frontier).for_each(|i| { + for_each_bit(frontier, |i| { zs[i].ascend_byte(); }); } @@ -1669,7 +1661,7 @@ where } // advance indices - active_bits(frontier).for_each(|i| { + for_each_bit(frontier, |i| { bytes[i] = masks[i].next_bit(a); }); } @@ -1686,7 +1678,7 @@ where .expect("non-empty path when k > 0"); // ascend - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { let mut z = &mut zs[i]; z.ascend_byte(); masks[i] = z.child_mask(); @@ -2113,7 +2105,7 @@ pub fn zipper_merge_dnf( { let mut global = ByteMask::EMPTY; - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { let m = clause_mask(zs, &clauses[i]); clause_masks[i] = m; @@ -2199,13 +2191,13 @@ pub fn zipper_merge_dnf( let mut participating = 0u64; // descend participating clauses - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { if clause_masks[i].test_bit(byte) { sub_active |= 1 << i; participating |= clauses[i].members(); } }); - 
active_bits(participating).for_each(|i| { + for_each_bit(participating, |i| { zs[i].descend_to_byte(byte); }); @@ -2232,7 +2224,7 @@ pub fn zipper_merge_dnf( zipper_merge_dnf_branch(zs, clauses, sub_active, out); // ascend - active_bits(participating).for_each(|i| { + for_each_bit(participating, |i| { zs[i].ascend_byte(); }); @@ -2254,11 +2246,11 @@ pub fn zipper_merge_dnf( .expect("non-empty path at depth > 0"); let mut active_zippers = 0; - active_bits(active).for_each(|i| { + for_each_bit(active, |i| { active_zippers |= clauses[i].members(); }); - active_bits(active_zippers).for_each(|i| { + for_each_bit(active_zippers, |i| { zs[i].ascend_byte(); }); out.ascend_byte(); From 389adfcd3951accf8b7cfe3d70ea9c1c142e36fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Thu, 25 Jun 2026 16:19:55 +0200 Subject: [PATCH 7/8] One more attempt to optimize `active_bits` iterations --- src/experimental/zipper_algebra.rs | 53 +++++++++++++++++++----------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index a0effde..0f4fab2 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1371,16 +1371,30 @@ fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { } } -#[inline] -fn active_bits(active: u64) -> impl Iterator { - (0..N).filter(move |i| (active >> i) & 1 != 0) +struct ActiveRefs<'a, T, const N: usize> { + bits: u64, + xs: &'a [T; N], } -fn only_active<'a, T, const N: usize>( - ts: &'a [T; N], - active: u64, -) -> impl Iterator { - active_bits::(active).map(|i| (i, &ts[i])) +impl<'a, T, const N: usize> Iterator for ActiveRefs<'a, T, N> { + type Item = &'a T; + + #[inline(always)] + fn next(&mut self) -> Option { + if self.bits == 0 { + return None; + } + + let i = self.bits.trailing_zeros() as usize; + self.bits &= self.bits - 1; + + Some(&self.xs[i]) + } +} + +fn active_refs(xs: &[T; N], bits: u64) -> ActiveRefs { + 
assert!(bits >> N == 0); + ActiveRefs { bits, xs } } #[inline(always)] @@ -1452,14 +1466,14 @@ where V: Clone + 'a, Z: ZipperValues, { - only_active(zs, active).map(|(_, z)| lift(z.val())) + active_refs(zs, active).map(|z| lift(z.val())) } fn all_active_share(zs: &[Z; N], active: u64) -> bool where Z: ZipperConcrete, { - let mut iter = only_active(zs, active).map(|(_, z)| z.shared_node_id()); + let mut iter = active_refs(zs, active).map(|z| z.shared_node_id()); match iter.next() { Some(Some(first)) => iter.all(|next| next.is_some_and(|snid| snid == first)), _ => false, @@ -1484,10 +1498,10 @@ where let mut bytes = [None; N]; let mut masks = [ByteMask::EMPTY; N]; - for (i, z) in only_active(zs, active) { - masks[i] = z.child_mask(); + for_each_bit(active, |i| { + masks[i] = zs[i].child_mask(); bytes[i] = masks[i].indexed_bit::(0); - } + }); // At each node, the algorithm: // @@ -1518,7 +1532,7 @@ where let mut frontier = 0u64; let mut next = None; - for i in active_bits::(active) { + for_each_bit(active, |i| { if let Some(b) = bytes[i] { match min { None => { @@ -1541,7 +1555,7 @@ where } } } - } + }); debug_assert!(frontier <= active); @@ -2057,8 +2071,8 @@ pub fn zipper_merge_dnf( if clause.is_empty() { return ByteMask::EMPTY; }; - only_active(zs, clause.members()) - .try_fold(ByteMask::FULL, |mut mask, (_, z)| { + active_refs(zs, clause.members()) + .try_fold(ByteMask::FULL, |mut mask, z| { mask &= z.child_mask(); if mask.is_empty_mask() { None @@ -2075,7 +2089,7 @@ pub fn zipper_merge_dnf( V: Lattice + Clone, Z: ZipperValues, { - Meet::combine_n(only_active(zs, clause.members()).map(|(_, z)| lift(z.val()))) + Meet::combine_n(active_refs(zs, clause.members()).map(|z| lift(z.val()))) } fn active_clauses_value( @@ -2088,8 +2102,7 @@ pub fn zipper_merge_dnf( Z: ZipperValues, { Join::combine_n( - only_active(clauses, active) - .map(|(_, clause)| clause_value(zs, clause).map(Cow::Owned)), + active_refs(clauses, active).map(|clause| clause_value(zs, 
clause).map(Cow::Owned)), ) } From 1e43eee0e94c072ca0011b43952e2e7bbca99456 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Rze=C5=BAnicki?= Date: Thu, 25 Jun 2026 19:46:05 +0200 Subject: [PATCH 8/8] zipper_algebra: micro-optimization: use `assert_unchecked` --- src/experimental/zipper_algebra.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 0f4fab2..cb10802 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1557,7 +1557,10 @@ where } }); - debug_assert!(frontier <= active); + unsafe { + // SAFETY: frontier only has bits set for indices visited by `for_each_bit(active, ..)` above, so it is a bit-subset of active and hence frontier <= active + std::hint::assert_unchecked(frontier <= active); + } match min { None => { @@ -2210,6 +2213,11 @@ pub fn zipper_merge_dnf( participating |= clauses[i].members(); } }); + unsafe { + // SAFETY: `participating` is the bitwise OR c_1 | .. | c_i of clause member masks, + // each satisfying c_x >> N == 0, so no bit >= N is set (NOTE(review): `>> N` overflows when N == 64 - confirm) + std::hint::assert_unchecked(participating >> N == 0); + } for_each_bit(participating, |i| { zs[i].descend_to_byte(byte); });