Super Serial- automatically save and load TFRecords from Tensorflow datasets #1280
Merged
Changes from all commits (27 commits):
- f4bca6a super_serial automatically creates TFRecords files from dictionary-st… (markemus)
- 326c667 pep8 fixes (markemus)
- a00f789 more pep8 (undoing tensorflow 2 space tabs) (markemus)
- 7df2b7b bazel changes (markemus)
- aa5759f small change so github checks will run again (markemus)
- cbdcbec moved super_serial test to tests/ (markemus)
- 84d3ba0 bazel changes (markemus)
- 3a9b574 moved super_serial to experimental (markemus)
- b47525f refactored super_serial test to work for serial_ops (markemus)
- bee9511 bazel fixes (markemus)
- 34b191a refactored test to load from tfio instead of full import path (markemus)
- 1ce75ab licenses (markemus)
- b88a046 bazel fixes (markemus)
- cf2a2f3 fixed license dates for new files (markemus)
- 7447e94 small change so tests rerun (markemus)
- 2d8c25b small change so tests rerun (markemus)
- fd075e3 cleanup and bazel fix (markemus)
- d974c83 added test to ensure proper crash occurs when trying to save in graph… (markemus)
- 1879e4f bazel fixes (markemus)
- 22fa83c fixed imports for test (markemus)
- a75a3a2 fixed imports for test (markemus)
- 8ada254 fixed yaml imports for serial_ops (markemus)
- 94419bf fixed error path for new tf version (markemus)
- 4a82fdd prevented flaky behavior in graph mode for serial_ops.py by preemptiv… (markemus)
- 4b898e7 sanity check for graph execution in graph_save_fail() (markemus)
- d4f1f24 it should be impossible for serial_ops not to raise an exception now … (markemus)
- f9212ea moved eager execution check in serial_ops (markemus)
serial_ops.py (new file, +201 lines):
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Easily save tf.data.Datasets as tfrecord files, and restore tfrecords as Datasets.

The goal of this module is to create a SIMPLE api to tfrecords that can be used without
learning all of the underlying mechanics.

Users only need to deal with 2 functions:
    save_dataset(dataset)
    dataset = load_dataset(tfrecord, header)

It really is that easy!

To make this work, we create a .header file for each tfrecord which encodes metadata
needed to reconstruct the original dataset.

Note that the PyYAML (yaml) package must be installed to make use of this module.

Saving must be done in eager mode, but loading is compatible with both eager and
graph execution modes.

GOTCHAS:
- This module is only compatible with "dictionary-style" datasets {key: val, key2: val2, ..., keyN: valN}.
- The restored dataset will have the TFRecord dtypes {float32, int64, string} instead of the original
  tensor dtypes. This is always the case with TFRecord datasets, whether you use this module or not.
  The original dtypes are stored in the headers if you want to restore them after loading."""
import functools
import os
import tempfile

import numpy as np
import tensorflow as tf


# The three encoding functions.
def _bytes_feature(value):
    """value: list"""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _float_feature(value):
    """value: list"""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _int64_feature(value):
    """value: list"""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


# TODO use base_type() to ensure consistent conversion.
def np_value_to_feature(value):
    """Maps dataset values to tf Features.
    Only numpy types are supported since Datasets only contain tensors.
    Each datatype should only have one way of being serialized."""
    if isinstance(value, np.ndarray):
        if np.issubdtype(value.dtype, np.integer):
            feature = _int64_feature(value.flatten())
        elif np.issubdtype(value.dtype, np.floating):
            feature = _float_feature(value.flatten())
        elif np.issubdtype(value.dtype, np.bool_):
            feature = _int64_feature(value.flatten())
        else:
            raise TypeError(f"value dtype: {value.dtype} is not recognized.")
    elif isinstance(value, bytes):
        feature = _bytes_feature([value])
    elif np.issubdtype(type(value), np.integer):
        feature = _int64_feature([value])
    elif np.issubdtype(type(value), np.floating):
        feature = _float_feature([value])
    else:
        raise TypeError(
            f"value type: {type(value)} is not recognized. value must be a valid Numpy object."
        )

    return feature


def base_type(dtype):
    """Returns the TFRecords allowed type corresponding to dtype."""
    int_types = [
        tf.int8,
        tf.int16,
        tf.int32,
        tf.int64,
        tf.uint8,
        tf.uint16,
        tf.uint32,
        tf.uint64,
        tf.qint8,
        tf.qint16,
        tf.qint32,
        tf.bool,
    ]
    float_types = [tf.float16, tf.float32, tf.float64]
    byte_types = [tf.string, bytes]

    if dtype in int_types:
        new_dtype = tf.int64
    elif dtype in float_types:
        new_dtype = tf.float32
    elif dtype in byte_types:
        new_dtype = tf.string
    else:
        raise ValueError(f"dtype {dtype} is not a recognized/supported type!")

    return new_dtype


def build_header(dataset):
    """Build header dictionary of metadata for the tensors in the dataset. This will be used when loading
    the tfrecords file to reconstruct the original tensors from the raw data. Shape is stored as a list
    and dtype is stored as an enumerated value (defined by tensorflow)."""
    header = {}
    for key in dataset.element_spec.keys():
        header[key] = {
            "shape": list(dataset.element_spec[key].shape),
            "dtype": dataset.element_spec[key].dtype.as_datatype_enum,
        }

    return header


def build_feature_desc(header):
    """Build feature_desc dictionary for the tensors in the dataset. This will be used to reconstruct Examples
    from the tfrecords file.

    Assumes FixedLenFeatures.
    If you got VarLenFeatures I feel bad for you son,
    I got 115 problems but a VarLenFeature ain't one."""
    feature_desc = {}
    for key, params in header.items():
        feature_desc[key] = tf.io.FixedLenFeature(
            shape=params["shape"], dtype=base_type(int(params["dtype"]))
        )

    return feature_desc


def dataset_to_examples(ds):
    """Converts a dataset to a dataset of tf.train.Example strings. Each Example is a single observation.
    WARNING: Only compatible with "dictionary-style" datasets {key: val, key2: val2, ..., keyN: valN}.
    WARNING: Must run in eager mode!"""
    # TODO handle tuples and flat datasets as well.
    for x in ds:
        # Each individual tensor is converted to a known serializable type.
        features = {key: np_value_to_feature(value.numpy()) for key, value in x.items()}
        # All features are then packaged into a single Example object.
        example = tf.train.Example(features=tf.train.Features(feature=features))

        yield example.SerializeToString()


def save_dataset(dataset, tfrecord_path, header_path):
    """Saves a flat dataset as a tfrecord file, and builds a header file for reloading as a dataset.
    Must run in eager mode because it depends on dataset iteration and element_spec."""
    import yaml

    if not tf.executing_eagerly():
        raise ValueError("save_dataset() must run in eager mode!")

    # Header
    header = build_header(dataset)
    with open(header_path, "w") as header_file:
        yaml.dump(header, stream=header_file)

    # Dataset
    ds_examples = tf.data.Dataset.from_generator(
        lambda: dataset_to_examples(dataset), output_types=tf.string
    )
    writer = tf.data.experimental.TFRecordWriter(tfrecord_path)
    writer.write(ds_examples)


# TODO-DECIDE is this yaml loader safe?
def load_dataset(tfrecord_path, header_path):
    """Uses the header file to infer the shapes and dtypes of the tensors in the tfrecord."""
    import yaml

    with open(header_path) as header_file:
        header = yaml.load(header_file, Loader=yaml.FullLoader)

    feature_desc = build_feature_desc(header)
    parse_func = functools.partial(tf.io.parse_single_example, features=feature_desc)
    dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_func)

    return dataset
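The dtype dispatch performed by np_value_to_feature above can be illustrated with a numpy-only sketch. Here tfrecord_base_type is a hypothetical helper (not part of the module): it only reports which of the three TFRecord feature types a value would be serialized as, mirroring the branch order of np_value_to_feature:

```python
import numpy as np


def tfrecord_base_type(value):
    """Report which TFRecord feature type ('int64', 'float', 'bytes')
    a numpy value would map to, mirroring np_value_to_feature."""
    if isinstance(value, np.ndarray):
        # Integer and boolean arrays are both stored as Int64Lists.
        if np.issubdtype(value.dtype, np.integer) or np.issubdtype(
            value.dtype, np.bool_
        ):
            return "int64"
        if np.issubdtype(value.dtype, np.floating):
            return "float"
        raise TypeError(f"unsupported array dtype: {value.dtype}")
    if isinstance(value, bytes):
        return "bytes"
    # Scalars fall through to the same integer/float checks.
    if np.issubdtype(type(value), np.integer):
        return "int64"
    if np.issubdtype(type(value), np.floating):
        return "float"
    raise TypeError(f"unsupported type: {type(value)}")
```

Note that booleans survive the round trip only as int64 values, and every float width is collapsed to the single TFRecord float type.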
tests/test_serial_ops.py (new file, +89 lines):
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the serial_ops.py serialization module."""
import os
import tempfile

import numpy as np
import pytest
import tensorflow as tf

import tensorflow_io as tfio


def test_serialization():
    """Test super serial saving and loading.
    NOTE- test will only work in eager mode due to the list() dataset cast."""
    savefolder = tempfile.TemporaryDirectory()
    savepath = os.path.join(savefolder.name, "temp_dataset")
    tfrecord_path = savepath + ".tfrecord"
    header_path = savepath + ".header"

    # Data
    x = np.linspace(1, 3000, num=3000).reshape(10, 10, 10, 3)
    y = np.linspace(1, 10, num=10).astype(int)
    ds = tf.data.Dataset.from_tensor_slices({"image": x, "label": y})

    # Run
    tfio.experimental.serialization.save_dataset(
        ds, tfrecord_path=tfrecord_path, header_path=header_path
    )
    new_ds = tfio.experimental.serialization.load_dataset(
        tfrecord_path=tfrecord_path, header_path=header_path
    )

    # Test that values were saved and restored
    assert (
        list(ds)[0]["image"].numpy()[0, 0, 0]
        == list(new_ds)[0]["image"].numpy()[0, 0, 0]
    )
    assert list(ds)[0]["label"] == list(new_ds)[0]["label"]

    assert (
        list(ds)[-1]["image"].numpy()[0, 0, 0]
        == list(new_ds)[-1]["image"].numpy()[0, 0, 0]
    )
    assert list(ds)[-1]["label"] == list(new_ds)[-1]["label"]

    # Clean up- folder will disappear on crash as well.
    savefolder.cleanup()


@tf.function
def graph_save_fail():
    """serial_ops is expected to raise an exception when
    trying to save in graph mode."""
    savefolder = tempfile.TemporaryDirectory()
    savepath = os.path.join(savefolder.name, "temp_dataset")
    tfrecord_path = savepath + ".tfrecord"
    header_path = savepath + ".header"

    # Data
    x = np.linspace(1, 3000, num=3000).reshape(10, 10, 10, 3)
    y = np.linspace(1, 10, num=10).astype(int)
    ds = tf.data.Dataset.from_tensor_slices({"image": x, "label": y})

    # Run
    assert os.path.isdir(savefolder.name)
    assert not tf.executing_eagerly()
    tfio.experimental.serialization.save_dataset(
        ds, tfrecord_path=tfrecord_path, header_path=header_path
    )


def test_ensure_graph_fail():
    """Test that serial_ops fails in graph mode."""
    with pytest.raises(ValueError):
        graph_save_fail()
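For the dataset built in test_serialization above, the generated .header file would look roughly like the following. This is a hand-written sketch, not output from the module: the enum values 2 and 9 are TensorFlow's DT_DOUBLE and DT_INT64 datatype enums, and the leading batch dimension is dropped by from_tensor_slices, so "image" has per-element shape [10, 10, 3] and "label" is a scalar:

```yaml
image:
  dtype: 2        # DT_DOUBLE; parsed back as float32 via base_type()
  shape:
  - 10
  - 10
  - 3
label:
  dtype: 9        # DT_INT64
  shape: []
```

When load_dataset reads this header, build_feature_desc converts each entry into a tf.io.FixedLenFeature with the recorded shape and the collapsed TFRecord dtype.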
If only eager mode is supported, then maybe you can check whether tf is executing in eager mode and raise an exception if it isn't.
Hmm. It will raise an exception already if tf is not in eager mode, and if no exceptions are raised then it will run fine. In general I'm not in favor of explicitly adding exceptions like that because I've seen it cause unnecessary exceptions in other packages, but if that's the practice here I'm happy to add it.
Can we ensure this in a different test case? The test would switch to graph mode and save the dataset and assert that the respective exception is raised.
Sure, I'll add that.
@kvignesh1420 I added the test you requested as test_ensure_graph_fail() in tests/test_serial_ops.py.
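As a footnote to the discussion above: because np_value_to_feature flattens every array into a flat Int64List or FloatList, it is the shape recorded in the header that lets the parser restore the original tensor layout. The round trip can be mimicked with numpy alone (the reshape here stands in for what tf.io.parse_single_example does with the FixedLenFeature shape):

```python
import numpy as np

# A tensor as it might appear in a dataset element.
original = np.arange(24.0).reshape(2, 3, 4)

# Saving: build_header records the shape; the feature stores a flat list.
header_shape = list(original.shape)
flat = list(original.flatten())

# Loading: the flat values are reshaped using the header's shape.
restored = np.array(flat).reshape(header_shape)
assert (restored == original).all()
```

This is also why the module only supports fixed-shape (FixedLenFeature) tensors: without a fixed shape in the header, the flat list could not be unambiguously reshaped.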