From 2b1d73596b6e364f5e7f0cac57a5379c1115a79e Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Fri, 26 Jun 2026 10:15:04 +0100 Subject: [PATCH 1/2] fix(warp): apply reference-grid jitter on a float grid and isolate RNG Warp.get_reference_grid built the grid from torch.arange (integer dtype), assigned self.ref_grid = grid.to(ddf) before the jitter block, then jittered the original integer grid in place. Three defects resulted: the jitter was applied to a dead local and never returned; torch.rand_like on the Long grid raised NotImplementedError, so jitter=True crashed outright; and fork_rng(enabled=seed) disabled RNG forking whenever seed defaulted to 0, leaking the seeded state into the global RNG. Cast the grid to ddf first, jitter that float tensor, assign it to self.ref_grid after jittering, and use fork_rng() so the seeded draw is isolated from the global RNG. Signed-off-by: Soumya Snigdha Kundu --- monai/networks/blocks/warp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index ddd3a350d5..0721db1502 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -121,12 +121,13 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) - self.ref_grid = grid.to(ddf) + grid = grid.to(ddf) if jitter: # Define reference grid on non-integer values - with torch.random.fork_rng(enabled=seed): + with torch.random.fork_rng(): torch.random.manual_seed(seed) grid += torch.rand_like(grid) + self.ref_grid = grid self.ref_grid.requires_grad = False return self.ref_grid From e3d9aebf96d9cb25a3d82371bd61c3d881984768 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Fri, 26 Jun 2026 10:15:04 +0100 Subject: [PATCH 2/2] tests: add Warp reference-grid jitter regression test Asserts the jittered grid is floating point with non-integer values, the unjittered grid stays integer valued, and jitter is reproducible per seed. Fails before the fix with NotImplementedError on torch.rand_like. Signed-off-by: Soumya Snigdha Kundu --- tests/networks/blocks/warp/test_warp.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/networks/blocks/warp/test_warp.py b/tests/networks/blocks/warp/test_warp.py index 93af559790..09c6b2be4b 100644 --- a/tests/networks/blocks/warp/test_warp.py +++ b/tests/networks/blocks/warp/test_warp.py @@ -138,6 +138,21 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3)) + def test_jitter(self): + ddf = torch.zeros(1, 2, 4, 5) + grid = Warp(jitter=True).get_reference_grid(ddf, jitter=True, seed=0) + self.assertTrue(grid.is_floating_point()) + self.assertTrue(bool((grid != grid.round()).any())) + + ref = Warp().get_reference_grid(ddf, jitter=False) + self.assertTrue(bool((ref == ref.round()).all())) + + same = Warp().get_reference_grid(ddf, jitter=True, seed=7) + repeat = Warp().get_reference_grid(ddf, jitter=True, seed=7) + other = Warp().get_reference_grid(ddf, jitter=True, seed=8) + self.assertTrue(torch.equal(same, repeat)) + self.assertFalse(torch.equal(same, other)) + def test_grad(self): for b in GridSampleMode: for p in GridSamplePadMode: