@@ -5,9 +5,9 @@
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union
+import sys
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast

-import numpy as np
 import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -38,14 +38,14 @@
 prod = functools.partial(functools.reduce, operator.mul)


-class MNISTFileReader(IterDataPipe[np.ndarray]):
+class MNISTFileReader(IterDataPipe[torch.Tensor]):
     _DTYPE_MAP = {
-        8: "u1",  # uint8
-        9: "i1",  # int8
-        11: "i2",  # int16
-        12: "i4",  # int32
-        13: "f4",  # float32
-        14: "f8",  # float64
+        8: torch.uint8,
+        9: torch.int8,
+        11: torch.int16,
+        12: torch.int32,
+        13: torch.float32,
+        14: torch.float64,
     }

     def __init__(
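For reference, the `_DTYPE_MAP` keys above come from the IDX header layout: the third magic byte encodes the dtype and the fourth the number of axes. A worked sketch of that arithmetic, using the published magic number of the MNIST image files (not part of the diff):

```python
import codecs

# First four bytes of an MNIST image file: the magic number 0x00000803.
magic = int(codecs.encode(b"\x00\x00\x08\x03", "hex"), 16)  # 2051
dtype_code = magic // 256  # 8 -> torch.uint8 in _DTYPE_MAP
ndim = magic % 256 - 1     # 2; the first of the 3 axes (num_samples) is read separately
print(dtype_code, ndim)    # 8 2
```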
@@ -59,30 +59,36 @@ def __init__(
     def _decode(bytes: bytes) -> int:
         return int(codecs.encode(bytes, "hex"), 16)

-    def __iter__(self) -> Iterator[np.ndarray]:
+    def __iter__(self) -> Iterator[torch.Tensor]:
         for _, file in self.datapipe:
             magic = self._decode(file.read(4))
-            dtype_type = self._DTYPE_MAP[magic // 256]
+            dtype = self._DTYPE_MAP[magic // 256]
             ndim = magic % 256 - 1

             num_samples = self._decode(file.read(4))
             shape = [self._decode(file.read(4)) for _ in range(ndim)]

-            in_dtype = np.dtype(f">{dtype_type}")
-            out_dtype = np.dtype(dtype_type)
-            chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
+            num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+            # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
+            # we need to reverse the bytes before we can read them with torch.frombuffer().
+            needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
+            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value

             start = self.start or 0
             stop = self.stop or num_samples

             file.seek(start * chunk_size, 1)
             for _ in range(stop - start):
                 chunk = file.read(chunk_size)
-                yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
+                if not needs_byte_reversal:
+                    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
+                    continue
+                chunk = bytearray(chunk)
+                chunk.reverse()
+                yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)


 class _MNISTBase(Dataset):
-    _FORMAT = "png"
     _URL_BASE: str

     @abc.abstractmethod
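The byte-reversal branch above can be checked in isolation. A minimal sketch, assuming a little-endian host: reversing the whole chunk makes each value little endian but also reverses the element order, which the trailing `flip(0)` undoes.

```python
import sys
import torch

big_endian = (1).to_bytes(4, "big") + (2).to_bytes(4, "big")  # two big-endian int32 values
chunk = bytearray(big_endian)  # a writable buffer avoids torch.frombuffer's read-only warning
if sys.byteorder == "little":
    chunk.reverse()  # values are now little endian, but in reverse order
    values = torch.frombuffer(chunk, dtype=torch.int32).flip(0)
else:
    values = torch.frombuffer(chunk, dtype=torch.int32)
print(values)  # tensor([1, 2], dtype=torch.int32)
```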
@@ -105,24 +111,23 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional

     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
+        image, label = data

-        image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
-            image = torch.from_numpy(image_array)
+            image = image.unsqueeze(0)
         else:
-            image_buffer = image_buffer_from_array(image_array)
-            image = decoder(image_buffer) if decoder else image_buffer
+            image_buffer = image_buffer_from_array(image.numpy())
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

-        label = torch.tensor(label_array, dtype=torch.int64)
         category = self.info.categories[int(label)]
+        label = label.to(torch.int64)

-        return dict(image=image, label=label, category=category)
+        return dict(image=image, category=category, label=label)

     def _make_datapipe(
         self,
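In the raw-decoder branch above, the tensor read from the IDX file is H x W, so `unsqueeze(0)` prepends a channel dimension. A minimal illustration, assuming the usual 28 x 28 MNIST image:

```python
import torch

image = torch.zeros(28, 28, dtype=torch.uint8)  # one image as read from the IDX file
print(image.unsqueeze(0).shape)  # torch.Size([1, 28, 28]), i.e. C x H x W
```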
@@ -293,12 +298,11 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->

     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig. 2 in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -308,8 +312,8 @@ def _collate_and_decode(
         # index 39 (10 digits + 26 uppercase letters + 4th lowercase letter - 1 for zero indexing)
         # in self.categories. Thus, we need to add 1 to the label to correct this.
         if config.image_set in ("Balanced", "By_Merge"):
-            label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype)
-        return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder)
+            data[1].add_(self._LABEL_OFFSETS.get(int(data[1]), 0))
+        return super()._collate_and_decode(data, config=config, decoder=decoder)

     def _make_datapipe(
         self,
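A toy illustration of the offset correction above. The category layout follows the comment; the single-entry offset table is a hypothetical stand-in for the dataset's actual `_LABEL_OFFSETS`, which holds the cumulative shifts for all merged letters.

```python
import string

# 10 digits + 26 uppercase + 26 lowercase letters, as in the unmerged EMNIST label set.
categories = list(string.digits) + list(string.ascii_uppercase) + list(string.ascii_lowercase)

# Hypothetical excerpt: 'c' is merged into 'C', so the dense label of 'd' (38)
# must be shifted by 1 to reach its index in `categories`.
label_offsets = {38: 1}
label = 38
label += label_offsets.get(label, 0)
print(categories[label])  # d  (index 39 = 10 digits + 26 uppercase + 4th lowercase - 1)
```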
@@ -379,22 +383,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional

     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
-        label_parts = label_array.tolist()
-        sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
+        image, ann = data
+        label, *extra_anns = ann
+        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)

         sample.update(
             dict(
                 zip(
                     ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
-                    label_parts[1:6],
+                    [int(value) for value in extra_anns[:5]],
                 )
             )
         )
-        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
+        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
         return sample
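To make the QMNIST annotation unpacking above concrete, a sketch with a made-up row; the eight-column layout (class label followed by seven metadata fields) mirrors the keys used in the diff, while the values are invented.

```python
import torch

ann = torch.tensor([5, 1, 203, 17, 5, 42, 0, 0])  # hypothetical annotation row
label, *extra_anns = ann

keys = ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index")
meta = dict(zip(keys, [int(value) for value in extra_anns[:5]]))
flags = dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))
print(int(label), meta, flags)
```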