diff --git a/docs/applications/implementations/models.md b/docs/applications/implementations/models.md
index 21bf0ad21c..4262c0edb4 100644
--- a/docs/applications/implementations/models.md
+++ b/docs/applications/implementations/models.md
@@ -48,6 +48,8 @@ def create_estimator(run_config, model_config):
 
 ## Pre-installed Packages
 
+You can import PyPI packages or your own Python packages to help create more complex models. See [Python Packages](../advanced/python-packages.md) for more details.
+
 The following packages have been pre-installed and can be used in your implementations:
 
 ```text
@@ -60,3 +62,41 @@ packaging==19.0.0
 ```
 
 You can install additional PyPI packages and import your own Python packages. See [Python Packages](../advanced/python-packages.md) for more details.
+
+
+# TensorFlow Transformations
+
+You can preprocess the input features and labels for your model by defining a `transform_tensorflow` function. It defines the tensor transformations to apply to the feature and label tensors before they are passed to the model.
+
+## Implementation
+
+```python
+def transform_tensorflow(features, labels, model_config):
+    """Define tensor transformations for the feature and label tensors.
+
+    Args:
+        features: A feature dictionary of column names to feature tensors.
+
+        labels: The label tensor.
+
+        model_config: The Cortex configuration for the model.
+            Note: nested resources are expanded (e.g. model_config["target_column"]
+            will be the configuration for the target column, rather than the
+            name of the target column).
+
+    Returns:
+        The transformed features and labels tensors.
+    """
+    return features, labels
+```
+
+## Example
+
+```python
+import tensorflow as tf
+
+def transform_tensorflow(features, labels, model_config):
+    hparams = model_config["hparams"]
+    features["image_pixels"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
+    return features, labels
+```
diff --git a/examples/mnist/implementations/models/basic.py b/examples/mnist/implementations/models/custom.py
similarity index 88%
rename from examples/mnist/implementations/models/basic.py
rename to examples/mnist/implementations/models/custom.py
index 0fe17eaed6..3583113af5
--- a/examples/mnist/implementations/models/basic.py
+++ b/examples/mnist/implementations/models/custom.py
@@ -5,10 +5,7 @@ def create_estimator(run_config, model_config):
     hparams = model_config["hparams"]
 
     def model_fn(features, labels, mode, params):
-        images = features["image_pixels"]
-        images = tf.reshape(images, [-1] + hparams["input_shape"])
-        x = images
-
+        x = features["image_pixels"]
         for i, feature_count in enumerate(hparams["hidden_units"]):
             with tf.variable_scope("layer_%d" % i):
                 if hparams["layer_type"] == "conv":
@@ -55,3 +52,10 @@ def model_fn(features, labels, mode, params):
 
     estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
     return estimator
+
+
+def transform_tensorflow(features, labels, model_config):
+    hparams = model_config["hparams"]
+
+    features["image_pixels"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
+    return features, labels
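Reviewer note (outside the patch): moving the reshape out of `model_fn` and into `transform_tensorflow` turns it from a per-batch into a per-example operation, which is why the leading `[-1]` batch dimension is dropped. A minimal sketch of the resulting input pipeline, assuming a hand-built `model_config` dict and a dummy dataset in place of the real Cortex context:

```python
import tensorflow as tf

def transform_tensorflow(features, labels, model_config):
    hparams = model_config["hparams"]
    # runs on a single unbatched example, so no [-1] batch dimension is needed
    features["image_pixels"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
    return features, labels

# stand-in for the Cortex model config (assumption for illustration)
model_config = {"hparams": {"input_shape": [28, 28, 1]}}

# dummy dataset of four flat 784-pixel images with integer labels
dataset = tf.data.Dataset.from_tensor_slices(
    ({"image_pixels": tf.zeros([4, 784])}, tf.zeros([4], dtype=tf.int64))
)

# applied before .batch(), mirroring the dataset.map() call added in train_util.py below
dataset = dataset.map(lambda f, l: transform_tensorflow(f, l, model_config))
dataset = dataset.batch(2)
```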
diff --git a/examples/mnist/implementations/models/t2t.py b/examples/mnist/implementations/models/t2t.py
new file mode 100644
index 0000000000..53ce2dfa57
--- /dev/null
+++ b/examples/mnist/implementations/models/t2t.py
@@ -0,0 +1,40 @@
+import tensorflow as tf
+from tensor2tensor.utils import trainer_lib
+from tensor2tensor import models  # pylint: disable=unused-import
+from tensor2tensor import problems  # pylint: disable=unused-import
+from tensor2tensor.data_generators import problem_hparams
+from tensor2tensor.utils import registry
+
+
+def create_estimator(run_config, model_config):
+    # t2t expects these keys in run_config
+    run_config.data_parallelism = None
+    run_config.t2t_device_info = {"num_async_replicas": 1}
+
+    # t2t has its own set of hyperparameters we can use
+    hparams = trainer_lib.create_hparams("basic_fc_small")
+    problem = registry.problem("image_mnist")
+    p_hparams = problem.get_hparams(hparams)
+    hparams.problem = problem
+    hparams.problem_hparams = p_hparams
+
+    # don't need eval_metrics
+    problem.eval_metrics = lambda: []
+
+    # t2t expects this key
+    hparams.warm_start_from = None
+
+    estimator = trainer_lib.create_estimator("basic_fc_relu", hparams, run_config)
+    return estimator
+
+
+def transform_tensorflow(features, labels, model_config):
+    hparams = model_config["hparams"]
+
+    # the t2t model performs flattening and expects this input key
+    features["inputs"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
+
+    # t2t expects this key and dimensionality
+    features["targets"] = tf.expand_dims(labels, 0)
+
+    return features, labels
diff --git a/examples/mnist/requirements.txt b/examples/mnist/requirements.txt
index 19dd3f52d4..cfbf63288b 100644
--- a/examples/mnist/requirements.txt
+++ b/examples/mnist/requirements.txt
@@ -1 +1,2 @@
 pillow==5.4.1
+tensor2tensor==1.10.0
diff --git a/examples/mnist/resources/apis.yaml b/examples/mnist/resources/apis.yaml
index b66fb6b8fa..3bab92bd2a 100644
--- a/examples/mnist/resources/apis.yaml
+++ b/examples/mnist/resources/apis.yaml
@@ -1,6 +1,6 @@
 - kind: api
-  name: dense-classifier
-  model_name: dense
+  name: dnn-classifier
+  model_name: dnn
   compute:
     replicas: 1
 
@@ -9,3 +9,9 @@
   model_name: conv
   compute:
     replicas: 1
+
+- kind: api
+  name: t2t-classifier
+  model_name: t2t
+  compute:
+    replicas: 1
diff --git a/examples/mnist/resources/models.yaml b/examples/mnist/resources/models.yaml
index ab8a95ddc5..a789ab9d0c 100644
--- a/examples/mnist/resources/models.yaml
+++ b/examples/mnist/resources/models.yaml
@@ -1,26 +1,22 @@
 - kind: model
-  name: dense
-  path: implementations/models/basic.py
+  name: dnn
+  path: implementations/models/dnn.py
   type: classification
   target_column: label
   feature_columns:
     - image_pixels
  hparams:
-    layer_type: basic
     learning_rate: 0.01
     input_shape: [784]
     output_shape: [10]
-    hidden_units: [100, 200, 10]
+    hidden_units: [100, 200]
   data_partition_ratio:
     training: 0.7
     evaluation: 0.3
-  training:
-    batch_size: 64
-    num_epochs: 5
 
 - kind: model
   name: conv
-  path: implementations/models/basic.py
+  path: implementations/models/custom.py
   type: classification
   target_column: label
   feature_columns:
@@ -30,7 +26,7 @@
     learning_rate: 0.01
     input_shape: [28, 28, 1]
     output_shape: [10]
-    kernel_size: 2
+    kernel_size: 4
     hidden_units: [10, 10, 10]
   data_partition_ratio:
     training: 0.7
@@ -39,18 +35,17 @@
     batch_size: 64
     num_epochs: 5
 
+
 - kind: model
-  name: dnn
-  path: implementations/models/dnn.py
+  name: t2t
+  path: implementations/models/t2t.py
   type: classification
   target_column: label
   feature_columns:
     - image_pixels
+  prediction_key: outputs
   hparams:
-    learning_rate: 0.01
     input_shape: [28, 28, 1]
-    output_shape: [10]
-    hidden_units: [100, 200]
   data_partition_ratio:
     training: 0.7
     evaluation: 0.3
diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py
index a164dc94bb..3387390912 100644
--- a/pkg/workloads/lib/context.py
+++ b/pkg/workloads/lib/context.py
@@ -460,7 +460,8 @@ def resource_status_key(self, resource):
 
 
 MODEL_IMPL_VALIDATION = {
-    "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}]
+    "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}],
+    "optional": [{"name": "transform_tensorflow", "args": ["features", "labels", "model_config"]}],
 }
 
 AGGREGATOR_IMPL_VALIDATION = {
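For context (again outside the patch): the new `"optional"` entry means a model implementation may omit `transform_tensorflow` entirely, and the training code guards each call with `hasattr`. A sketch of how a validation spec like this can be enforced, using a hypothetical `_validate_impl` helper rather than the actual logic in `context.py`:

```python
import inspect

MODEL_IMPL_VALIDATION = {
    "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}],
    "optional": [{"name": "transform_tensorflow", "args": ["features", "labels", "model_config"]}],
}

def _check_args(fn, expected):
    # compare the function's positional argument names against the spec
    actual = list(inspect.signature(fn).parameters)
    if actual != expected:
        raise ValueError("{}: expected args {}, got {}".format(fn.__name__, expected, actual))

def _validate_impl(module, validation):
    # hypothetical helper: required functions must exist; optional ones are
    # only checked when present (mirroring the hasattr() guard in train_util.py)
    for rule in validation.get("required", []):
        fn = getattr(module, rule["name"], None)
        if fn is None:
            raise ValueError("missing required function: {}".format(rule["name"]))
        _check_args(fn, rule["args"])
    for rule in validation.get("optional", []):
        fn = getattr(module, rule["name"], None)
        if fn is not None:
            _check_args(fn, rule["args"])
```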
diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py
index 910f660dd5..e52ed51e66 100644
--- a/pkg/workloads/tf_api/api.py
+++ b/pkg/workloads/tf_api/api.py
@@ -89,12 +89,11 @@ def transform_sample(sample):
 
 
 def create_prediction_request(transformed_sample):
     ctx = local_cache["ctx"]
-
+    signatureDef = local_cache["metadata"]["signatureDef"]
+    signature_key = list(signatureDef.keys())[0]
     prediction_request = predict_pb2.PredictRequest()
     prediction_request.model_spec.name = "default"
-    prediction_request.model_spec.signature_name = list(
-        local_cache["metadata"]["signatureDef"].keys()
-    )[0]
+    prediction_request.model_spec.signature_name = signature_key
 
     for column_name, value in transformed_sample.items():
         data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]]
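Side note on the `api.py` refactor: behavior is unchanged (the first key of the SavedModel's `signatureDef` is still used); the change just names the intermediate values. A standalone sketch of the request construction, with the per-column type lookup replaced by a hard-coded float type (an assumption made for brevity):

```python
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2

def build_prediction_request(transformed_sample, signature_key):
    # simplified sketch of create_prediction_request: tensor dtypes are
    # hard-coded to float32 instead of being looked up from ctx.columns
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "default"
    request.model_spec.signature_name = signature_key
    for column_name, value in transformed_sample.items():
        tensor_proto = tf.make_tensor_proto([value], dtype=tf.float32)
        request.inputs[column_name].CopyFrom(tensor_proto)
    return request
```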
diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py
index a45031afb0..db141494e0 100644
--- a/pkg/workloads/tf_train/train_util.py
+++ b/pkg/workloads/tf_train/train_util.py
@@ -33,6 +33,24 @@ def get_input_placeholder(model_name, ctx, training=True):
     return input_placeholder
 
 
+def get_label_placeholder(model_name, ctx):
+    model = ctx.models[model_name]
+
+    target_column_name = model["target_column"]
+    column_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[target_column_name]["type"]]
+    return tf.placeholder(shape=[None], dtype=column_type)
+
+
+def get_transform_tensor_fn(ctx, model_impl, model_name):
+    model = ctx.models[model_name]
+    model_config = ctx.model_config(model["name"])
+
+    def transform_tensor_fn_wrapper(inputs, labels):
+        return model_impl.transform_tensorflow(inputs, labels, model_config)
+
+    return transform_tensor_fn_wrapper
+
+
 def generate_example_parsing_fn(model_name, ctx, training=True):
     model = ctx.models[model_name]
 
@@ -47,7 +65,7 @@ def _parse_example(example_proto):
 
 
 # Mode must be "training" or "evaluation"
-def generate_input_fn(model_name, ctx, mode):
+def generate_input_fn(model_name, ctx, mode, model_impl):
     model = ctx.models[model_name]
 
     filenames = ctx.get_training_data_parts(model_name, mode)
@@ -66,6 +84,9 @@ def _input_fn():
         if model[mode]["shuffle"]:
             dataset = dataset.shuffle(buffer_size)
 
+        if hasattr(model_impl, "transform_tensorflow"):
+            dataset = dataset.map(get_transform_tensor_fn(ctx, model_impl, model_name))
+
         dataset = dataset.batch(model[mode]["batch_size"])
         dataset = dataset.prefetch(buffer_size)
         dataset = dataset.repeat()
@@ -77,27 +98,19 @@ def _input_fn():
 
     return _input_fn
 
 
-def generate_json_serving_input_fn(model_name, ctx):
+def generate_json_serving_input_fn(model_name, ctx, model_impl):
     def _json_serving_input_fn():
         inputs = get_input_placeholder(model_name, ctx, training=False)
-        features = {key: tf.expand_dims(tensor, -1) for key, tensor in inputs.items()}
-        return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs)
-
-    return _json_serving_input_fn
+        labels = get_label_placeholder(model_name, ctx)
+        features = {key: tensor for key, tensor in inputs.items()}
 
+        if hasattr(model_impl, "transform_tensorflow"):
+            features, _ = get_transform_tensor_fn(ctx, model_impl, model_name)(features, labels)
 
-def generate_example_serving_input_fn(model_name, ctx):
-    def _example_serving_input_fn():
-        feature_spec = tf_lib.get_feature_spec(model_name, ctx, training=False)
-        example_bytestring = tf.placeholder(shape=[None], dtype=tf.string)
-        feature_scalars = tf.parse_single_example(example_bytestring, feature_spec)
-        features = {key: tf.expand_dims(tensor, -1) for key, tensor in feature_scalars.items()}
-
-        return tf.estimator.export.ServingInputReceiver(
-            features=features, receiver_tensors={"example_proto": example_bytestring}
-        )
+        features = {key: tf.expand_dims(tensor, 0) for key, tensor in features.items()}
+        return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs)
 
-    return _example_serving_input_fn
+    return _json_serving_input_fn
 
 
 def get_regression_eval_metrics(labels, predictions):
@@ -130,9 +143,9 @@ def train(model_name, model_impl, ctx, model_dir):
         model_dir=model_dir,
     )
 
-    train_input_fn = generate_input_fn(model_name, ctx, "training")
-    eval_input_fn = generate_input_fn(model_name, ctx, "evaluation")
-    serving_input_fn = generate_json_serving_input_fn(model_name, ctx)
+    train_input_fn = generate_input_fn(model_name, ctx, "training", model_impl)
+    eval_input_fn = generate_input_fn(model_name, ctx, "evaluation", model_impl)
+    serving_input_fn = generate_json_serving_input_fn(model_name, ctx, model_impl)
 
     exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False)
     dataset_metadata = aws.read_json_from_s3(model["dataset"]["metadata_key"], ctx.bucket)
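One more reviewer note on the serving change: `transform_tensorflow` runs on unbatched tensors, so the batch dimension is now added afterwards at axis 0 (a batch of one), rather than the old `tf.expand_dims(tensor, -1)` which appended a trailing axis. A shape-level sketch, assuming the MNIST reshape transform from `custom.py`:

```python
import tensorflow as tf

# one unbatched JSON sample arrives as a flat [784] vector
pixels = tf.placeholder(shape=[784], dtype=tf.float32)
features = {"image_pixels": pixels}

# per-example transform, as in custom.py's transform_tensorflow
features["image_pixels"] = tf.reshape(features["image_pixels"], [28, 28, 1])

# add the batch dimension last: [28, 28, 1] -> [1, 28, 28, 1]
features = {key: tf.expand_dims(tensor, 0) for key, tensor in features.items()}
```

This ordering matters: applying the transform after batching (or expanding the wrong axis, as the removed `-1` variant would for reshaped images) would hand the model tensors whose rank no longer matches what `model_fn` saw during training.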