Commit 29aa38a

replace np.frombuffer with torch.frombuffer in MNIST prototype (#4651)
* replace np.frombuffer with torch.frombuffer in MNIST prototype
* cleanup
* appease mypy
* more cleanup
* clarify inplace offset
* fix num bytes for floating point data
1 parent 979ecac commit 29aa38a
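
Context for the change: unlike np.frombuffer(), torch.frombuffer() has no notion of byte order and always interprets the buffer in the system's native order, which is why the patch reverses the raw bytes and flips the resulting tensor when reading the big-endian MNIST data on a little-endian machine. The sketch below only illustrates that trick; the helper name frombuffer_big_endian is invented for this example and is not part of the commit.

import struct
import sys

import torch


def frombuffer_big_endian(buffer: bytes, dtype: torch.dtype) -> torch.Tensor:
    # Number of bytes per value, for both integer and floating point dtypes.
    num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
    if sys.byteorder == "big" or num_bytes_per_value == 1:
        # Native byte order already matches, or byte order is irrelevant for 1-byte values.
        return torch.frombuffer(bytearray(buffer), dtype=dtype)
    # Reversing the whole buffer turns big-endian values into little-endian ones,
    # but also reverses the order of the values; flip(0) restores the original order.
    reversed_buffer = bytearray(buffer)
    reversed_buffer.reverse()
    return torch.frombuffer(reversed_buffer, dtype=dtype).flip(0)


# Three int32 values packed in big-endian byte order.
data = struct.pack(">3i", 1, 2, 3)
print(frombuffer_big_endian(data, torch.int32))  # tensor([1, 2, 3], dtype=torch.int32)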

File tree

1 file changed: +38 -34 lines changed

  • torchvision/prototype/datasets/_builtin

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 38 additions & 34 deletions
@@ -5,9 +5,9 @@
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union
+import sys
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
 
-import numpy as np
 import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -38,14 +38,14 @@
 prod = functools.partial(functools.reduce, operator.mul)
 
 
-class MNISTFileReader(IterDataPipe[np.ndarray]):
+class MNISTFileReader(IterDataPipe[torch.Tensor]):
     _DTYPE_MAP = {
-        8: "u1",  # uint8
-        9: "i1",  # int8
-        11: "i2",  # int16
-        12: "i4",  # int32
-        13: "f4",  # float32
-        14: "f8",  # float64
+        8: torch.uint8,
+        9: torch.int8,
+        11: torch.int16,
+        12: torch.int32,
+        13: torch.float32,
+        14: torch.float64,
     }
 
     def __init__(
@@ -59,30 +59,36 @@ def __init__(
     def _decode(bytes: bytes) -> int:
         return int(codecs.encode(bytes, "hex"), 16)
 
-    def __iter__(self) -> Iterator[np.ndarray]:
+    def __iter__(self) -> Iterator[torch.Tensor]:
         for _, file in self.datapipe:
             magic = self._decode(file.read(4))
-            dtype_type = self._DTYPE_MAP[magic // 256]
+            dtype = self._DTYPE_MAP[magic // 256]
             ndim = magic % 256 - 1
 
             num_samples = self._decode(file.read(4))
             shape = [self._decode(file.read(4)) for _ in range(ndim)]
 
-            in_dtype = np.dtype(f">{dtype_type}")
-            out_dtype = np.dtype(dtype_type)
-            chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
+            num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+            # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
+            # we need to reverse the bytes before we can read them with torch.frombuffer().
+            needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
+            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
 
             start = self.start or 0
             stop = self.stop or num_samples
 
             file.seek(start * chunk_size, 1)
             for _ in range(stop - start):
                 chunk = file.read(chunk_size)
-                yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
+                if not needs_byte_reversal:
+                    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
+                    continue
+                chunk = bytearray(chunk)
+                chunk.reverse()
+                yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
 
 
 class _MNISTBase(Dataset):
-    _FORMAT = "png"
     _URL_BASE: str
 
     @abc.abstractmethod
@abc.abstractmethod
@@ -105,24 +111,23 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
105111

106112
def _collate_and_decode(
107113
self,
108-
data: Tuple[np.ndarray, np.ndarray],
114+
data: Tuple[torch.Tensor, torch.Tensor],
109115
*,
110116
config: DatasetConfig,
111117
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
112118
) -> Dict[str, Any]:
113-
image_array, label_array = data
119+
image, label = data
114120

115-
image: Union[torch.Tensor, io.BytesIO]
116121
if decoder is raw:
117-
image = torch.from_numpy(image_array)
122+
image = image.unsqueeze(0)
118123
else:
119-
image_buffer = image_buffer_from_array(image_array)
120-
image = decoder(image_buffer) if decoder else image_buffer
124+
image_buffer = image_buffer_from_array(image.numpy())
125+
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
121126

122-
label = torch.tensor(label_array, dtype=torch.int64)
123127
category = self.info.categories[int(label)]
128+
label = label.to(torch.int64)
124129

125-
return dict(image=image, label=label, category=category)
130+
return dict(image=image, category=category, label=label)
126131

127132
def _make_datapipe(
128133
self,
@@ -293,12 +298,11 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example,
@@ -308,8 +312,8 @@ def _collate_and_decode(
         # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
         # in self.categories. Thus, we need to add 1 to the label to correct this.
         if config.image_set in ("Balanced", "By_Merge"):
-            label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype)
-        return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder)
+            data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)
+        return super()._collate_and_decode(data, config=config, decoder=decoder)
 
     def _make_datapipe(
         self,
@@ -379,22 +383,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
        *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
-        label_parts = label_array.tolist()
-        sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
+        image, ann = data
+        label, *extra_anns = ann
+        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
 
         sample.update(
             dict(
                 zip(
                     ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
-                    label_parts[1:6],
+                    [int(value) for value in extra_anns[:5]],
                 )
             )
         )
-        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
+        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
         return sample
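
For readers unfamiliar with the IDX format parsed in the MNISTFileReader.__iter__ hunk above: the 4-byte big-endian magic number encodes both the element dtype and the number of dimensions. Below is a rough standalone sketch of that decoding; parse_idx_magic is an invented name, and the dtype map is copied from the hunk above.

import torch

_DTYPE_MAP = {
    8: torch.uint8,
    9: torch.int8,
    11: torch.int16,
    12: torch.int32,
    13: torch.float32,
    14: torch.float64,
}


def parse_idx_magic(magic: int) -> tuple:
    # The third byte of the magic number encodes the dtype, the fourth one the total
    # number of dimensions. The reader subtracts one because the leading sample
    # dimension is read separately as num_samples.
    dtype = _DTYPE_MAP[magic // 256]
    ndim_per_sample = magic % 256 - 1
    return dtype, ndim_per_sample


# 0x00000803 is the magic number of the MNIST image files: uint8 data with
# 3 dimensions in total (samples x rows x columns), i.e. 2 dimensions per sample.
print(parse_idx_magic(0x00000803))  # (torch.uint8, 2)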

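The comments in the EMNIST hunks above explain why, for the Balanced and By_Merge splits, the dense on-disk labels need an offset before they can index self.categories. The snippet below is a hypothetical illustration of that correction; the label_offsets values are made up and are not the real _LABEL_OFFSETS mapping, which is defined outside the hunks shown here.

import torch

# Made-up offsets purely for illustration; see _LABEL_OFFSETS in mnist.py for the real values.
label_offsets = {38: 1, 39: 1, 40: 2}


def correct_label(label: torch.Tensor) -> torch.Tensor:
    # Mirrors `data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)` from the hunk above:
    # labels without an entry in the map are left untouched.
    label += label_offsets.get(int(label), 0)
    return label


print(correct_label(torch.tensor(5)))   # tensor(5)  -> no offset needed
print(correct_label(torch.tensor(38)))  # tensor(39) -> shifted past a merged lowercase letter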