-
Notifications
You must be signed in to change notification settings - Fork 608
Tensor2Tensor Example and transform_tensorflow feature #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
eb77e7b
mnist basic t2t model
1vn b2448b8
merge master
1vn 9917cd7
add newline
1vn 8edcb58
fix prediction time shaping
1vn a6e0f47
clean reviews example
1vn df24557
if undefined shape, take the length
1vn 6202511
Merge branch 'master' into t2t-example
1vn b5c60a3
add numpy to api image
1vn 7145df8
remove numpy dep, dont restrict unspecified python pkgs
1vn e4a01c9
add TODO comment to address later
1vn f8c128d
clean up
1vn 17144a2
clean up example and transform tensor api
1vn e73d95c
transform_tensors -> transform_tensorflow
1vn 99e0b2a
add back dnn
1vn ad0be81
add back dnn
1vn e9e7c92
fix example
1vn 9be820a
remove TODO
1vn 5b5263b
add docs
1vn 71ea67f
address comments
1vn e7b62f1
merge master
1vn 9a9b6b9
remove commented code
1vn b8903ba
clean up extra line
1vn b035040
add transform_tensorflow to model_impl check
1vn 08c097b
format
1vn a1a8a4d
remove extra new line
1vn f44d9f2
update docs
1vn 0710af4
update mnist conv model
1vn 3c1d6a1
address doc comments
1vn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import tensorflow as tf | ||
from tensor2tensor.utils import trainer_lib | ||
from tensor2tensor import models # pylint: disable=unused-import | ||
from tensor2tensor import problems # pylint: disable=unused-import | ||
from tensor2tensor.data_generators import problem_hparams | ||
from tensor2tensor.utils import registry | ||
|
||
|
||
def create_estimator(run_config, model_config): | ||
# t2t expects these keys in run_config | ||
run_config.data_parallelism = None | ||
run_config.t2t_device_info = {"num_async_replicas": 1} | ||
|
||
# t2t has its own set of hyperparameters we can use | ||
hparams = trainer_lib.create_hparams("basic_fc_small") | ||
1vn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
problem = registry.problem("image_mnist") | ||
p_hparams = problem.get_hparams(hparams) | ||
hparams.problem = problem | ||
hparams.problem_hparams = p_hparams | ||
|
||
# don't need eval_metrics | ||
problem.eval_metrics = lambda: [] | ||
|
||
# t2t expects this key | ||
hparams.warm_start_from = None | ||
|
||
estimator = trainer_lib.create_estimator("basic_fc_relu", hparams, run_config) | ||
return estimator | ||
|
||
|
||
def transform_tensorflow(features, labels, model_config): | ||
hparams = model_config["hparams"] | ||
|
||
# t2t model performs flattening and expects this input key | ||
features["inputs"] = tf.reshape(features["image_pixels"], hparams["input_shape"]) | ||
|
||
# t2t expects this key and dimensionality | ||
features["targets"] = tf.expand_dims(labels, 0) | ||
|
||
return features, labels | ||
deliahu marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
pillow==5.4.1 | ||
tensor2tensor==1.10.0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.