Skip to content

Commit 18aea22

Browse files
i-onyashahab
and authored
Exposes num_parallel_reads and num_parallel_calls (tensorflow#1232)
* Exposes num_parallel_reads and num_parallel_calls — Exposes `num_parallel_reads` and `num_parallel_calls` in AvroRecordDataset and `make_avro_record_dataset`; adds parameter constraints; fixes lint issues. * Exposes num_parallel_reads and num_parallel_calls — Exposes `num_parallel_reads` and `num_parallel_calls` in AvroRecordDataset and `make_avro_record_dataset`; adds parameter constraints; fixes lint issues. * Exposes num_parallel_reads and num_parallel_calls — Exposes `num_parallel_reads` and `num_parallel_calls` in AvroRecordDataset and `make_avro_record_dataset`; adds parameter constraints; fixes lint issues. * Fixes lint issues. * Removes Optional typing for a method parameter. * Adds a test method for the _require() function — checks that a ValueError is raised when an invalid value is given for num_parallel_calls. * Uncomments the skip for macOS pytests. * Fixes lint issues. Co-authored-by: Abin Shahab <[email protected]>
1 parent 99cb778 commit 18aea22

File tree

3 files changed

+189
-54
lines changed

3 files changed

+189
-54
lines changed

tensorflow_io/core/python/experimental/avro_record_dataset_ops.py

Lines changed: 126 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@
2121
_DEFAULT_READER_SCHEMA = ""
2222
# From https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/data/ops/readers.py
2323

24+
25+
def _require(condition: bool, err_msg: str = None) -> None:
26+
"""Checks if the specified condition is true else raises exception
27+
28+
Args:
29+
condition: The condition to test
30+
err_msg: If specified, it's the error message to use if condition is not true.
31+
32+
Raises:
33+
ValueError: Raised when the condition is false
34+
35+
Returns:
36+
None
37+
"""
38+
if not condition:
39+
raise ValueError(err_msg)
40+
41+
2442
# copied from https://github.com/tensorflow/tensorflow/blob/
2543
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L36
2644
def _create_or_validate_filenames_dataset(filenames):
@@ -52,21 +70,62 @@ def _create_or_validate_filenames_dataset(filenames):
5270

5371
# copied from https://github.com/tensorflow/tensorflow/blob/
5472
# 3095681b8649d9a828afb0a14538ace7a998504d/tensorflow/python/data/ops/readers.py#L67
55-
def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None):
56-
"""create_dataset_reader"""
57-
58-
def read_one_file(filename):
59-
filename = tf.convert_to_tensor(filename, tf.string, name="filename")
60-
return dataset_creator(filename)
61-
62-
if num_parallel_reads is None:
63-
return filenames.flat_map(read_one_file)
64-
if num_parallel_reads == tf.data.experimental.AUTOTUNE:
65-
return filenames.interleave(
66-
read_one_file, num_parallel_calls=num_parallel_reads
67-
)
73+
def _create_dataset_reader(
74+
dataset_creator,
75+
filenames,
76+
cycle_length=None,
77+
num_parallel_calls=None,
78+
deterministic=None,
79+
block_length=1,
80+
):
81+
"""
82+
This creates a dataset reader which reads records from multiple files and interleaves them together
83+
```
84+
dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
85+
# NOTE: New lines indicate "block" boundaries.
86+
dataset = dataset.interleave(
87+
lambda x: Dataset.from_tensors(x).repeat(6),
88+
cycle_length=2, block_length=4)
89+
list(dataset.as_numpy_iterator())
90+
```
91+
Results in the following output:
92+
[1,1,1,1,
93+
2,2,2,2,
94+
1,1,
95+
2,2,
96+
3,3,3,3,
97+
4,4,4,4,
98+
3,4,
99+
5,5,5,5,
100+
5,5,
101+
]
102+
Args:
103+
dataset_creator: Initializer for AvroDatasetRecord
104+
filenames: A `tf.data.Dataset` iterator of filenames to read
105+
cycle_length: The number of files to be processed in parallel. This is used by `Dataset.Interleave`.
106+
We set this equal to `block_length`, so that each time n number of records are returned for each of the n
107+
files.
108+
num_parallel_calls: Number of threads spawned by the interleave call.
109+
deterministic: Sets whether the interleaved records are written in deterministic order. in tf.interleave this is default true
110+
block_length: Sets the number of output on the output tensor. Defaults to 1
111+
Returns:
112+
A dataset iterator with an interleaved list of parsed avro records.
113+
114+
"""
115+
116+
def read_many_files(filenames):
117+
filenames = tf.convert_to_tensor(filenames, tf.string, name="filename")
118+
return dataset_creator(filenames)
119+
120+
if cycle_length is None:
121+
return filenames.flat_map(read_many_files)
122+
68123
return filenames.interleave(
69-
read_one_file, cycle_length=num_parallel_reads, block_length=1
124+
read_many_files,
125+
cycle_length=cycle_length,
126+
num_parallel_calls=num_parallel_calls,
127+
block_length=block_length,
128+
deterministic=deterministic,
70129
)
71130

72131

@@ -128,10 +187,16 @@ class AvroRecordDataset(tf.data.Dataset):
128187
"""A `Dataset` comprising records from one or more AvroRecord files."""
129188

130189
def __init__(
131-
self, filenames, buffer_size=None, num_parallel_reads=None, reader_schema=None
190+
self,
191+
filenames,
192+
buffer_size=None,
193+
num_parallel_reads=None,
194+
num_parallel_calls=None,
195+
reader_schema=None,
196+
deterministic=True,
197+
block_length=1,
132198
):
133199
"""Creates a `AvroRecordDataset` to read one or more AvroRecord files.
134-
135200
Args:
136201
filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
137202
more filenames.
@@ -144,25 +209,61 @@ def __init__(
144209
files read in parallel are outputted in an interleaved order. If your
145210
input pipeline is I/O bottlenecked, consider setting this parameter to a
146211
value greater than one to parallelize the I/O. If `None`, files will be
147-
read sequentially.
212+
read sequentially. This must be set to equal or greater than `num_parallel_calls`.
213+
This constraint exists because `num_parallel_reads` becomes `cycle_length` in the
214+
underlying call to `tf.Dataset.Interleave`, and the `cycle_length` is required to be
215+
equal or higher than the number of threads(`num_parallel_calls`).
216+
`cycle_length` in tf.Dataset.Interleave will dictate how many items it will pick up to process
217+
num_parallel_calls: (Optional.) number of thread to spawn. This must be set to `None`
218+
or greater than 0. Also this must be less than or equal to `num_parallel_reads`. This defines
219+
the degree of parallelism in the underlying Dataset.interleave call.
148220
reader_schema: (Optional.) A `tf.string` scalar representing the reader
149221
schema or None.
150-
222+
deterministic: (Optional.) A boolean controlling whether determinism should be traded for performance by
223+
allowing elements to be produced out of order. Defaults to `True`
224+
block_length: Sets the number of output on the output tensor. Defaults to 1
151225
Raises:
152226
TypeError: If any argument does not have the expected type.
153227
ValueError: If any argument does not have the expected shape.
154228
"""
229+
_require(
230+
num_parallel_calls is None
231+
or num_parallel_calls == tf.data.experimental.AUTOTUNE
232+
or num_parallel_calls > 0,
233+
f"num_parallel_calls: {num_parallel_calls} must be set to None, "
234+
f"tf.data.experimental.AUTOTUNE, or greater than 0",
235+
)
236+
if num_parallel_calls is not None:
237+
_require(
238+
num_parallel_reads is not None
239+
and (
240+
num_parallel_reads >= num_parallel_calls
241+
or num_parallel_reads == tf.data.experimental.AUTOTUNE
242+
),
243+
f"num_parallel_reads: {num_parallel_reads} must be greater than or equal to "
244+
f"num_parallel_calls: {num_parallel_calls} or set to tf.data.experimental.AUTOTUNE",
245+
)
246+
155247
filenames = _create_or_validate_filenames_dataset(filenames)
156248

157249
self._filenames = filenames
158250
self._buffer_size = buffer_size
159251
self._num_parallel_reads = num_parallel_reads
252+
self._num_parallel_calls = num_parallel_calls
160253
self._reader_schema = reader_schema
254+
self._block_length = block_length
161255

162-
def creator_fn(filename):
163-
return _AvroRecordDataset(filename, buffer_size, reader_schema)
256+
def read_multiple_files(filenames):
257+
return _AvroRecordDataset(filenames, buffer_size, reader_schema)
164258

165-
self._impl = _create_dataset_reader(creator_fn, filenames, num_parallel_reads)
259+
self._impl = _create_dataset_reader(
260+
read_multiple_files,
261+
filenames,
262+
cycle_length=num_parallel_reads,
263+
num_parallel_calls=num_parallel_calls,
264+
deterministic=deterministic,
265+
block_length=block_length,
266+
)
166267
variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access
167268
super().__init__(variant_tensor)
168269

@@ -171,13 +272,17 @@ def _clone(
171272
filenames=None,
172273
buffer_size=None,
173274
num_parallel_reads=None,
275+
num_parallel_calls=None,
174276
reader_schema=None,
277+
block_length=None,
175278
):
176279
return AvroRecordDataset(
177280
filenames or self._filenames,
178281
buffer_size or self._buffer_size,
179282
num_parallel_reads or self._num_parallel_reads,
283+
num_parallel_calls or self._num_parallel_calls,
180284
reader_schema or self._reader_schema,
285+
block_length or self._block_length,
181286
)
182287

183288
def _inputs(self):

tensorflow_io/core/python/experimental/make_avro_record_dataset.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -37,60 +37,41 @@ def make_avro_record_dataset(
3737
shuffle_seed=None,
3838
prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
3939
num_parallel_reads=None,
40-
num_parallel_parser_calls=None,
4140
drop_final_batch=False,
4241
):
4342
"""Reads and (optionally) parses avro files into a dataset.
44-
4543
Provides common functionality such as batching, optional parsing, shuffling,
4644
and performing defaults.
47-
4845
Args:
4946
file_pattern: List of files or patterns of avro file paths.
5047
See `tf.io.gfile.glob` for pattern rules.
51-
5248
features: A map of feature names mapped to feature information.
53-
5449
batch_size: An int representing the number of records to combine
5550
in a single batch.
56-
5751
reader_schema: The reader schema.
58-
5952
reader_buffer_size: (Optional.) An int specifying the readers buffer
6053
size in By. If None (the default) will use the default value from
6154
AvroRecordDataset.
62-
6355
num_epochs: (Optional.) An int specifying the number of times this
6456
dataset is repeated. If None (the default), cycles through the
6557
dataset forever. If set to None drops final batch.
66-
6758
shuffle: (Optional.) A bool that indicates whether the input
6859
should be shuffled. Defaults to `True`.
69-
7060
shuffle_buffer_size: (Optional.) Buffer size to use for
7161
shuffling. A large buffer size ensures better shuffling, but
7262
increases memory usage and startup time. If not provided
7363
assumes default value of 10,000 records. Note that the shuffle
7464
size is measured in records.
75-
7665
shuffle_seed: (Optional.) Randomization seed to use for shuffling.
7766
By default uses a pseudo-random seed.
78-
7967
prefetch_buffer_size: (Optional.) An int specifying the number of
8068
feature batches to prefetch for performance improvement.
8169
Defaults to auto-tune. Set to 0 to disable prefetching.
82-
83-
num_parallel_reads: (Optional.) Number of threads used to read
84-
records from files. By default or if set to a value >1, the
85-
results will be interleaved.
86-
87-
num_parallel_parser_calls: (Optional.) Number of parallel
88-
records to parse in parallel. Defaults to an automatic selection.
89-
70+
num_parallel_reads: (Optional.) Number of parallel
71+
records to parse in parallel. Defaults to None(no parallelization).
9072
drop_final_batch: (Optional.) Whether the last batch should be
9173
dropped in case its size is smaller than `batch_size`; the
9274
default behavior is not to drop the smaller batch.
93-
9475
Returns:
9576
A dataset, where each element matches the output of `parser_fn`
9677
except it will have an additional leading `batch-size` dimension,
@@ -99,20 +80,15 @@ def make_avro_record_dataset(
9980
"""
10081
files = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle, seed=shuffle_seed)
10182

102-
if num_parallel_reads is None:
103-
# Note: We considered auto-tuning this value, but there is a concern
104-
# that this affects the mixing of records from different files, which
105-
# could affect training convergence/accuracy, so we are defaulting to
106-
# a constant for now.
107-
num_parallel_reads = 24
108-
10983
if reader_buffer_size is None:
11084
reader_buffer_size = 1024 * 1024
111-
85+
num_parallel_calls = num_parallel_reads
11286
dataset = AvroRecordDataset(
11387
files,
11488
buffer_size=reader_buffer_size,
11589
num_parallel_reads=num_parallel_reads,
90+
num_parallel_calls=num_parallel_calls,
91+
block_length=num_parallel_calls,
11692
reader_schema=reader_schema,
11793
)
11894

@@ -131,14 +107,11 @@ def make_avro_record_dataset(
131107

132108
dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
133109

134-
if num_parallel_parser_calls is None:
135-
num_parallel_parser_calls = tf.data.experimental.AUTOTUNE
136-
137110
dataset = dataset.map(
138111
lambda data: parse_avro(
139112
serialized=data, reader_schema=reader_schema, features=features
140113
),
141-
num_parallel_calls=num_parallel_parser_calls,
114+
num_parallel_calls=num_parallel_calls,
142115
)
143116

144117
if prefetch_buffer_size == 0:

tests/test_parse_avro_eager.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,63 @@ def _load_records_as_tensors(filenames, schema):
246246
),
247247
)
248248

249+
def test_inval_num_parallel_calls(self):
    """Verify ValueError is raised for invalid ``num_parallel_calls`` values.

    Two invalid configurations are exercised against AvroRecordDataset:
    a zero value, and a value greater than ``num_parallel_reads``.
    """

    NUM_PARALLEL_READS = 1
    NUM_PARALLEL_CALLS_ZERO = 0
    NUM_PARALLEL_CALLS_GREATER = 2

    writer_schema = """{
              "type": "record",
              "name": "dataTypes",
              "fields": [
                  {
                      "name":"index",
                      "type":"int"
                  },
                  {
                      "name":"string_value",
                      "type":"string"
                  }
              ]}"""

    record_data = [
        {"index": 0, "string_value": ""},
        {"index": 1, "string_value": "SpecialChars@!#$%^&*()-_=+{}[]|/`~\\'?"},
        {
            "index": 2,
            "string_value": "ABCDEFGHIJKLMNOPQRSTUVW"
            + "Zabcdefghijklmnopqrstuvwz0123456789",
        },
    ]

    filenames = AvroRecordDatasetTest._setup_files(
        writer_schema=writer_schema, records=record_data
    )

    # Both invalid settings must be rejected at construction time.
    for invalid_calls in (NUM_PARALLEL_CALLS_ZERO, NUM_PARALLEL_CALLS_GREATER):
        with pytest.raises(ValueError):
            tfio.experimental.columnar.AvroRecordDataset(
                filenames=filenames,
                num_parallel_reads=NUM_PARALLEL_READS,
                num_parallel_calls=invalid_calls,
                reader_schema="reader_schema",
            )
305+
249306
def _test_pass_dataset(self, writer_schema, record_data, **kwargs):
250307
"""test_pass_dataset"""
251308
filenames = AvroRecordDatasetTest._setup_files(

0 commit comments

Comments
 (0)