Skip to content

Commit 5494f09

Browse files
Gong-airWrench-Git
authored andcommitted
feat(dipu): set pin_memory to false for unsupported vendor (DeepLink-org#898)
* pin_memory set false for unsupported vendor * python black
1 parent e94dffb commit 5494f09

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

dipu/torch_dipu/dipu/dataloader.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from torch.utils.data import DataLoader, Sampler, Dataset
3+
from torch_dipu import dipu
34

45
from typing import Any, Callable, Iterable, TypeVar, Sequence, List, Optional, Union
56

@@ -26,6 +27,8 @@
2627
# the CUDAHostAllocator can also obtain the device of the main thread and set as allocator's own device.
2728
# DIPU currently does not have such logic and supports this feature requires API similar to cuda's cuDevicePrimaryCtxGetState.
2829
class DIPUDataLoader(DataLoader):
30+
UNSUPPORTED_PINMEMORY_VENDORS = ["DROPLET"]
31+
2932
def __init__(
3033
self,
3134
dataset: Dataset[T_co],
@@ -44,9 +47,16 @@ def __init__(
4447
*,
4548
prefetch_factor: Optional[int] = None,
4649
persistent_workers: bool = False,
47-
pin_memory_device: str = ""
50+
pin_memory_device: str = "",
4851
):
49-
if pin_memory:
52+
if dipu.vendor_type in self.UNSUPPORTED_PINMEMORY_VENDORS:
53+
print(
54+
f"[DIPU] warning: {dipu.vendor_type} does not support pin_memory, continuing with pin_memory=False\n",
55+
flush=True,
56+
end=None,
57+
)
58+
pin_memory = False
59+
elif pin_memory:
5060
pin_memory_device = "cuda"
5161

5262
super().__init__(

0 commit comments

Comments
 (0)