Skip to content

Commit e815c23

Browse files
committed
Fix RandGridDistortiond crash when transform is skipped
When _do_transform is False, convert_to_tensor was called on the entire data dict, which fails when non-tensor values (e.g. ints, strings) are present — causing "AttributeError: 'int' object has no attribute 'numel'" in the DataLoader collate function. Convert only the keyed tensor items instead, consistent with how other dict transforms handle the no-transform case. Fixes #8604
1 parent 5b71547 commit e815c23

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

monai/transforms/spatial/dictionary.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,8 +2305,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
23052305
d = dict(data)
23062306
self.randomize(None)
23072307
if not self._do_transform:
2308-
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
2309-
return out
2308+
for key in self.key_iterator(d):
2309+
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
2310+
return d
23102311

23112312
first_key: Hashable = self.first_key(d)
23122313
if first_key == ():

tests/transforms/test_rand_grid_distortiond.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,17 @@ def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val
8686
assert_allclose(result["mask"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4)
8787

8888

89+
def test_no_transform_with_non_tensor_metadata(self):
90+
"""When _do_transform is False, non-tensor values in the dict should not cause an error."""
91+
img = np.indices([6, 6]).astype(np.float32)
92+
data = {"img": img, "extra_info": 42, "label_name": "tumor"}
93+
g = RandGridDistortiond(keys=["img"], prob=0.0) # prob=0 ensures _do_transform is False
94+
result = g(data)
95+
# non-tensor metadata should pass through unchanged
96+
self.assertEqual(result["extra_info"], 42)
97+
self.assertEqual(result["label_name"], "tumor")
98+
assert_allclose(result["img"], img, type_test=False)
99+
100+
89101
if __name__ == "__main__":
90102
unittest.main()

0 commit comments

Comments
 (0)