21
21
# Fallback reader schema. NOTE(review): an empty string appears to mean
# "no explicit reader schema" (i.e. use the writer's schema) — confirm
# against the Avro dataset kernel before relying on this.
_DEFAULT_READER_SCHEMA = ""
22
22
# From https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/data/ops/readers.py
23
23
24
+
25
+ def _require (condition : bool , err_msg : str = None ) -> None :
26
+ """Checks if the specified condition is true else raises exception
27
+
28
+ Args:
29
+ condition: The condition to test
30
+ err_msg: If specified, it's the error message to use if condition is not true.
31
+
32
+ Raises:
33
+ ValueError: Raised when the condition is false
34
+
35
+ Returns:
36
+ None
37
+ """
38
+ if not condition :
39
+ raise ValueError (err_msg )
40
+
41
+
24
42
# copied from https://github.com/tensorflow/tensorflow/blob/
25
43
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36
26
44
def _create_or_validate_filenames_dataset (filenames ):
@@ -52,21 +70,62 @@ def _create_or_validate_filenames_dataset(filenames):
52
70
53
71
# copied from https://github.com/tensorflow/tensorflow/blob/
54
72
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67
55
def _create_dataset_reader(
    dataset_creator,
    filenames,
    cycle_length=None,
    num_parallel_calls=None,
    deterministic=None,
    block_length=1,
):
    """Build a dataset that reads records from several files at once.

    When ``cycle_length`` is given, the per-file datasets produced by
    ``dataset_creator`` are combined with ``tf.data.Dataset.interleave``,
    so up to ``cycle_length`` files are consumed concurrently and each
    contributes ``block_length`` consecutive records per cycle. Without a
    ``cycle_length`` the files are simply read one after another via
    ``flat_map``.

    Args:
        dataset_creator: Callable mapping a filename tensor to a
            per-file `tf.data.Dataset` of records.
        filenames: A `tf.data.Dataset` of filename strings.
        cycle_length: Number of files processed concurrently by the
            interleave; `None` means sequential reading.
        num_parallel_calls: Thread parallelism forwarded to
            `Dataset.interleave`.
        deterministic: Whether interleaved outputs must keep a
            deterministic order (forwarded to `Dataset.interleave`).
        block_length: Consecutive records taken from each file per
            interleave cycle. Defaults to 1.

    Returns:
        A `tf.data.Dataset` yielding the (possibly interleaved) records.
    """

    def _open_files(file_tensor):
        # Normalize to a string tensor before handing off to the creator.
        file_tensor = tf.convert_to_tensor(file_tensor, tf.string, name="filename")
        return dataset_creator(file_tensor)

    # No interleave cycle requested: consume the files sequentially.
    if cycle_length is None:
        return filenames.flat_map(_open_files)

    return filenames.interleave(
        _open_files,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls,
        block_length=block_length,
        deterministic=deterministic,
    )
71
130
72
131
@@ -128,10 +187,16 @@ class AvroRecordDataset(tf.data.Dataset):
128
187
"""A `Dataset` comprising records from one or more AvroRecord files."""
129
188
130
189
def __init__(
    self,
    filenames,
    buffer_size=None,
    num_parallel_reads=None,
    num_parallel_calls=None,
    reader_schema=None,
    deterministic=True,
    block_length=1,
):
    """Creates a `AvroRecordDataset` to read one or more AvroRecord files.

    Args:
        filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
            more filenames.
        buffer_size: (Optional.) Read buffer size passed to the underlying
            reader.
        num_parallel_reads: (Optional.) The number of files read in parallel;
            records from files read in parallel are output in interleaved
            order. If `None`, files are read sequentially. Must be greater
            than or equal to `num_parallel_calls` (or AUTOTUNE), because it
            becomes `cycle_length` in the underlying
            `tf.data.Dataset.interleave` call, and `cycle_length` must cover
            the number of interleave threads.
        num_parallel_calls: (Optional.) Number of threads spawned by the
            interleave. Must be `None`, `tf.data.experimental.AUTOTUNE`, or
            greater than 0, and at most `num_parallel_reads`.
        reader_schema: (Optional.) A `tf.string` scalar representing the
            reader schema, or None.
        deterministic: (Optional.) Whether determinism should be traded for
            performance by allowing interleaved elements to be produced out
            of order. Defaults to `True`.
        block_length: (Optional.) Number of consecutive records taken from
            each file per interleave cycle. Defaults to 1.

    Raises:
        TypeError: If any argument does not have the expected type.
        ValueError: If any argument does not have the expected shape, or the
            `num_parallel_*` constraints above are violated.
    """
    _require(
        num_parallel_calls is None
        or num_parallel_calls == tf.data.experimental.AUTOTUNE
        or num_parallel_calls > 0,
        f"num_parallel_calls: {num_parallel_calls} must be set to None, "
        f"tf.data.experimental.AUTOTUNE, or greater than 0",
    )
    if num_parallel_calls is not None:
        # interleave requires cycle_length >= num_parallel_calls, so the
        # read fan-out must cover the thread count (or be AUTOTUNE).
        _require(
            num_parallel_reads is not None
            and (
                num_parallel_reads >= num_parallel_calls
                or num_parallel_reads == tf.data.experimental.AUTOTUNE
            ),
            f"num_parallel_reads: {num_parallel_reads} must be greater than or equal to "
            f"num_parallel_calls: {num_parallel_calls} or set to tf.data.experimental.AUTOTUNE",
        )

    filenames = _create_or_validate_filenames_dataset(filenames)

    self._filenames = filenames
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads
    self._num_parallel_calls = num_parallel_calls
    self._reader_schema = reader_schema
    # Stored for parity with the other configuration attributes
    # (previously `deterministic` was the only option not kept on self).
    self._deterministic = deterministic
    self._block_length = block_length

    def read_multiple_files(filenames):
        # Per-file(s) dataset factory handed to the interleaving reader.
        return _AvroRecordDataset(filenames, buffer_size, reader_schema)

    self._impl = _create_dataset_reader(
        read_multiple_files,
        filenames,
        cycle_length=num_parallel_reads,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
        block_length=block_length,
    )
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super().__init__(variant_tensor)
168
269
@@ -171,13 +272,17 @@ def _clone(
171
272
filenames = None ,
172
273
buffer_size = None ,
173
274
num_parallel_reads = None ,
275
+ num_parallel_calls = None ,
174
276
reader_schema = None ,
277
+ block_length = None ,
175
278
):
176
279
return AvroRecordDataset (
177
280
filenames or self ._filenames ,
178
281
buffer_size or self ._buffer_size ,
179
282
num_parallel_reads or self ._num_parallel_reads ,
283
+ num_parallel_calls or self ._num_parallel_calls ,
180
284
reader_schema or self ._reader_schema ,
285
+ block_length or self ._block_length ,
181
286
)
182
287
183
288
def _inputs (self ):
0 commit comments