Transforms are lost when using a dataloader with spawn #8066

Closed
@RickLuiken

🐛 Describe the bug

When using a DataLoader with multiprocessing_context="spawn", the transforms are lost. I believe this is caused by the pickling change in #7860: the transforms are not pickled, so the dataset copy that each worker receives has none.
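A minimal sketch of the suspected failure mode, in plain Python with no torch dependency. With "spawn", the DataLoader pickles the dataset to send it to each worker. If the dataset's __reduce__ rebuilds the object from its constructor arguments only, any other state (here, the transforms) does not survive the round trip. Wrapper, FixedWrapper and roundtrip are hypothetical names for illustration only; they are not torchvision's classes.

```python
import pickle


class Wrapper:
    """Hypothetical stand-in for a dataset wrapper with a lossy __reduce__."""

    def __init__(self, dataset, transforms=None):
        self.dataset = dataset
        self.transforms = transforms

    def __reduce__(self):
        # Only `dataset` is recorded, so unpickling calls Wrapper(dataset)
        # and `transforms` falls back to its default of None.
        return (Wrapper, (self.dataset,))


class FixedWrapper(Wrapper):
    def __reduce__(self):
        # Record every piece of state the worker will need.
        return (FixedWrapper, (self.dataset, self.transforms))


def roundtrip(obj):
    # Roughly what "spawn" does when handing the dataset to a worker process.
    return pickle.loads(pickle.dumps(obj))


if __name__ == "__main__":
    lost = roundtrip(Wrapper([1, 2, 3], transforms=abs))
    kept = roundtrip(FixedWrapper([1, 2, 3], transforms=abs))
    print(lost.transforms)  # None
    print(kept.transforms)  # <built-in function abs>
```

The "fork" start method copies the parent's memory instead of pickling, which would explain why the problem only shows up with "spawn" (the default on Windows).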

Minimal example:

```python
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as T
from torchvision.datasets import VOCDetection, wrap_dataset_for_transforms_v2


def _no_collate(batch):
    # Return the list of samples unchanged so their types can be inspected.
    return batch


if __name__ == "__main__":
    # With "spawn", workers re-import this module, so the setup must be guarded.
    transform = T.ToImage()
    dataset = VOCDetection(..., transforms=transform)
    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

    dataloader = DataLoader(
        wrapped_dataset,
        num_workers=2,
        multiprocessing_context="spawn",
        collate_fn=_no_collate,
    )

    # The image in the returned sample is not a tv_tensors.Image: the transform was not applied!
    next(iter(dataloader))
```
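A possible workaround, sketched under the assumption that only the v2 wrapper's own pickling drops the transforms: construct VOCDetection without transforms, wrap it, and apply the transform in a separate plain dataset class whose attributes pickle by default. TransformedDataset is a hypothetical helper, not part of torchvision, and this sketch has not been verified against torchvision 0.16.

```python
import pickle


class TransformedDataset:
    """Hypothetical helper: applies `transform` on top of an existing dataset.

    Its attributes are pickled normally, so the transform should reach
    workers started with "spawn".
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Load the sample first, then apply the transform inside the worker.
        return self.transform(self.dataset[idx])


if __name__ == "__main__":
    # Simulate "spawn": pickle the dataset, then unpickle it as a worker would.
    ds = pickle.loads(pickle.dumps(TransformedDataset([1, -2, 3], abs)))
    print(ds[1])  # 2
```

In the example above, this would mean building VOCDetection(...) without transforms and passing TransformedDataset(wrapped_dataset, T.ToImage()) to the DataLoader. v2 transforms accept arbitrary nested inputs, so they should handle the (image, target) tuple the wrapper returns.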

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Enterprise
GCC version: (Rev10, Built by MSYS2 project) 12.2.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: N/A

Python version: 3.11.3 (tags/v3.11.3:f3909b8, Apr 4 2023, 23:49:59) [MSC v.1934 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A2000 Laptop GPU
Nvidia driver version: 536.25
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2304
DeviceID=CPU0
Family=198
L2CacheSize=10240
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2304
Name=11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] mypy==1.6.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] torch==2.1.0+cu121
[pip3] torchaudio==2.1.0+cu121
[pip3] torchvision==0.16.0+cu121
[conda] Could not collect

cc @vfdev-5
