From 4ceee6425298a563dd750f91de8b47c6fded0497 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Mon, 17 Apr 2017 22:59:02 -0700 Subject: [PATCH 1/7] support for strings --- src/main/python/tensorframes/core.py | 5 +- src/main/python/tensorframes/core_test.py | 15 +- .../preparation_inceptionv3.py | 189 ++++++++++++++++++ .../org/tensorframes/ColumnInformation.scala | 35 ++-- .../tensorframes/ExperimentalOperations.scala | 6 +- .../org/tensorframes/MetadataConstants.scala | 6 +- src/main/scala/org/tensorframes/Shape.scala | 21 +- .../scala/org/tensorframes/dsl/DslImpl.scala | 13 +- .../scala/org/tensorframes/dsl/package.scala | 6 +- .../scala/org/tensorframes/impl/DataOps.scala | 5 +- .../org/tensorframes/impl/DebugRowOps.scala | 41 +++- .../org/tensorframes/impl/DenseTensor.scala | 20 +- .../org/tensorframes/impl/TFDataOps.scala | 2 +- .../org/tensorframes/impl/TensorFlowOps.scala | 6 +- .../org/tensorframes/impl/datatypes.scala | 135 ++++++++++++- .../scala/org/tensorframes/test/dsl.scala | 26 ++- .../org/tensorframes/DebugRowOpsSuite.scala | 10 +- .../tensorframes/ExtraOperationsSuite.scala | 19 +- .../perf/ConvertBackPerformanceSuite.scala | 6 +- .../perf/ConvertPerformanceSuite.scala | 4 +- 20 files changed, 467 insertions(+), 103 deletions(-) create mode 100644 src/main/python/tensorframes_snippets/preparation_inceptionv3.py diff --git a/src/main/python/tensorframes/core.py b/src/main/python/tensorframes/core.py index c148551..17d70c4 100644 --- a/src/main/python/tensorframes/core.py +++ b/src/main/python/tensorframes/core.py @@ -43,7 +43,8 @@ def _add_graph(graph, builder, use_file=True): fname = d + "/proto.pb" builder.graphFromFile(fname) else: - gser = graph.as_graph_def().SerializeToString() + # Make sure that TF adds the shapes. + gser = graph.as_graph_def(add_shapes=True).SerializeToString() gbytes = bytearray(gser) builder.graph(gbytes) @@ -55,7 +56,7 @@ def _add_shapes(graph, builder, fetches): # dimensions are unknown ph_names = [] ph_shapes = [] - for n in graph.as_graph_def().node: + for n in graph.as_graph_def(add_shapes=True).node: # Just the input nodes: if not n.input: op_name = n.name diff --git a/src/main/python/tensorframes/core_test.py b/src/main/python/tensorframes/core_test.py index 419ebf2..cb6e8a3 100644 --- a/src/main/python/tensorframes/core_test.py +++ b/src/main/python/tensorframes/core_test.py @@ -14,7 +14,9 @@ class TestCore(object): @classmethod def setup_class(cls): print("setup ", cls) - cls.sc = SparkContext('local[1]', cls.__name__) + sc = SparkContext('local[1]', cls.__name__) + sc.setLogLevel('DEBUG') + cls.sc = sc @classmethod def teardown_class(cls): @@ -25,6 +27,7 @@ def setUp(self): self.sql = SQLContext(TestCore.sc) self.api = _java_api() self.api.initialize_logging() + TestCore.sc.setLogLevel('INFO') print("setup") @@ -126,6 +129,16 @@ def test_groupby_1(self): data2 = df2.collect() assert data2 == [Row(key='0', x=2.0), Row(key='1', x=4.0)], data2 + def test_byte_array(self): + data = [Row(x=bytearray('123', 'utf-8'))] + df = self.sql.createDataFrame(data) + with tf.Graph().as_default(): + x = tf.placeholder(tf.string, shape=[], name="x") + z = tf.string_to_number(x, tf.int32, name='z') + df2 = tfs.map_rows(z, df) + data2 = df2.collect() + assert data2[0].z == 123, data2 + if __name__ == "__main__": # Some testing stuff that should not be executed diff --git a/src/main/python/tensorframes_snippets/preparation_inceptionv3.py b/src/main/python/tensorframes_snippets/preparation_inceptionv3.py new file mode 100644 index 0000000..0276373 --- 
/dev/null +++ b/src/main/python/tensorframes_snippets/preparation_inceptionv3.py @@ -0,0 +1,189 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from datetime import datetime +import math +import os.path +import time + + +import numpy as np +import tensorflow as tf + +from preprocessing import inception_preprocessing +import datasets.imagenet as imagenet +from nets import inception +import datasets.dataset_utils as dataset_utils + +import tensorflow as tf +from tensorflow.python.training import saver as tf_saver +from tensorflow.python.framework import graph_util + +slim = tf.contrib.slim + +default_image_size = 299 + + +####### Download the network data +# The URL of the checkpointed data. +url = "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz" +# The name of the checkpoint file: +checkpoint_file = 'inception_v3.ckpt' +# Specify where you want to download the model to +checkpoints_dir = '/tmp/checkpoints' + +checkpoint_path = os.path.join(checkpoints_dir, checkpoint_file) +if not tf.gfile.Exists(checkpoints_dir): + tf.gfile.MakeDirs(checkpoints_dir) + +if not tf.gfile.Exists(checkpoint_path): + print('Downloading the model...') + dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir) + +#### TEST TO REMOVE + +s = tf.constant("This is string") +r = tf.decode_raw(s, tf.int8) +s2 = tf.as_string(r) +sess = tf.InteractiveSession() +print(s2.eval()) + +###### Building the computation graph + +# All this code can be run once. It assembles the computation graph, fills it with the checkpointed +# coefficients, and then saves it as a protocol buffer description. + +# Build the graph +g = tf.Graph() +with g.as_default(): + # Keep for now a placeholder that will eventually be filled with the content of the image. + # This code only accepts JPEG images, which is the most common image format. + image_string = tf.placeholder(tf.string, [], name="image_input") + + # Decode the string into a matrix of intensity values + image = tf.image.decode_jpeg(image_string, channels=3) + + # Resize the input image, preserving the aspect ratio + # and make a central crop of the resulting image. + # The crop will be of the size of the default image size of + # the network. + processed_image = inception_preprocessing.preprocess_image(image, + default_image_size, + default_image_size, + is_training=False) + + # Networks accept images in batches. + # The first dimension usually represents the batch size. + # In our case the batch size is one. + processed_images = tf.expand_dims(processed_image, 0) + + # Create the model, use the default arg scope to configure + # the batch norm parameters. arg_scope is a very convenient + # feature of the slim library -- you can define default + # parameters for layers -- like stride, padding etc. + # Note: like the Arabian Nights, inception defines 1001 classes + # to include a background class (the first). + with slim.arg_scope(inception.inception_v3_arg_scope()): + logits, _ = inception.inception_v3(processed_images, + num_classes=1001, + is_training=False) + + # In order to get probabilities we apply softmax on the output. + probabilities = tf.nn.softmax(logits) + + # Just focus on the top predictions + top_pred = tf.nn.top_k(tf.squeeze(probabilities), k=5, name="top_predictions") + + # These are the outputs we will be requesting from the network. 
+ output_nodes = [probabilities, top_pred.indices, top_pred.values] + +# Create the saver +with g.as_default(): + model_variables = slim.get_model_variables('InceptionV3') + saver = tf_saver.Saver(model_variables, reshape=False) + +def get_op_name(tensor): + return tensor.name.split(":")[0] + +# Export the network +with g.as_default(): + with tf.Session() as sess: + saver.restore(sess, checkpoint_path) + # The add_shapes option is important: Spark requires this extra shape information to infer the + # correct types. + input_graph_def = g.as_graph_def(add_shapes=True) + output_tensor_names = [node.name for node in output_nodes] + output_node_names = [n.split(":")[0] for n in output_tensor_names] + output_graph_def = graph_util.convert_variables_to_constants( + sess, + input_graph_def, + output_node_names, + variable_names_blacklist=[]) + +# The variable 'output_graph_def' now contains the full description of the computation. +# The variables in the 'output_nodes' list will be used to know what to output. + +####### Testing the computation graph + +# This code performs a sanity check by running the network against some image content downloaded from the internet. + +g2 = tf.Graph() +with g2.as_default(): + tf.import_graph_def(output_graph_def, name='') + +#### Download an image +import requests + +# Example picture: +# Specify where you want to download the image to +images_dir = '/tmp/image_data' + +if not tf.gfile.Exists(images_dir): + tf.gfile.MakeDirs(images_dir) + +image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/WeaverAntDefense.JPG/640px-WeaverAntDefense.JPG' +image_url = 'https://www.tensorflow.org/images/cropped_panda.jpg' +image_path = os.path.join(images_dir, image_url.split('/')[-1]) + +if not tf.gfile.Exists(image_path): + response = requests.get(image_url) + if response.status_code == 200: + with open(image_path, 'wb') as f: + f.write(response.content) + +image_data = tf.gfile.FastGFile(image_path, 'rb').read() + +with g2.as_default(): + input_node2 = g2.get_operation_by_name(get_op_name(image)) + output_nodes2 = [g2.get_tensor_by_name(n) for n in output_tensor_names] + with tf.Session() as sess: + (probabilities_, indices_, values_) = sess.run(output_nodes2, {'image_input:0':image_data}) + +names = imagenet.create_readable_names_for_imagenet_labels() +for i in range(5): + index = indices_[i] + print('Probability %d %0.2f => [%s]' % (index, values_[i], names[index])) + + +###### Perform some evaluation with TensorFrames + +# This code takes the network and a directory that contains some image content. It shows how to process the content +# using Spark dataframes and TensorFrames. 
+ +import tensorframes as tfs +sc.setLogLevel('INFO') + +raw_images_miscast = sc.binaryFiles("file:"+images_dir) +raw_images = raw_images_miscast.map(lambda x: (x[0], bytearray(x[1]))) + +df = spark.createDataFrame(raw_images).toDF('image_uri', 'image_data') +df + +with g2.as_default(): + index_output = tf.identity(g2.get_tensor_by_name('top_predictions:1'), name="index") + value_output = tf.identity(g2.get_tensor_by_name('top_predictions:0'), name="value") + pred_df = tfs.map_rows([index_output, value_output], df, feed_dict={'image_input':'image_data'}) + +pred_df.select('index', 'value').head() + diff --git a/src/main/scala/org/tensorframes/ColumnInformation.scala b/src/main/scala/org/tensorframes/ColumnInformation.scala index d06a76c..9e3402b 100644 --- a/src/main/scala/org/tensorframes/ColumnInformation.scala +++ b/src/main/scala/org/tensorframes/ColumnInformation.scala @@ -1,6 +1,7 @@ package org.tensorframes import org.apache.spark.sql.types._ +import org.tensorframes.impl.{ScalarType, SupportedOperations} class ColumnInformation private ( @@ -15,7 +16,9 @@ class ColumnInformation private ( val b = new MetadataBuilder().withMetadata(field.metadata) for (info <- stf) { b.putLongArray(shapeKey, info.shape.dims.toArray) - b.putString(tensorStructType, info.dataType.toString) + // Keep the SQL name, so that we do not leak internal details. + val dt = SupportedOperations.opsFor(info.dataType).sqlType + b.putString(tensorStructType, dt.toString) } val meta = b.build() field.copy(metadata = meta) @@ -73,15 +76,15 @@ object ColumnInformation extends Logging { * @param scalarType the data type * @param blockShape the shape of the block */ - def structField(name: String, scalarType: NumericType, blockShape: Shape): StructField = { + def structField(name: String, scalarType: ScalarType, blockShape: Shape): StructField = { val i = SparkTFColInfo(blockShape, scalarType) val f = StructField(name, sqlType(scalarType, blockShape.tail), nullable = false) ColumnInformation(f, i).merged } - private def sqlType(scalarType: NumericType, shape: Shape): DataType = { + private def sqlType(scalarType: ScalarType, shape: Shape): DataType = { if (shape.dims.isEmpty) { - scalarType + SupportedOperations.opsFor(scalarType).sqlType } else { ArrayType(sqlType(scalarType, shape.tail), containsNull = false) } @@ -102,11 +105,14 @@ object ColumnInformation extends Logging { for { s <- shape t <- tpe - } yield SparkTFColInfo(s, t) + ops <- SupportedOperations.getOps(t) + } yield SparkTFColInfo(s, ops.scalarType) } - private def getType(s: String): Option[NumericType] = { - supportedTypes.find(_.toString == s) + private def getType(s: String): Option[DataType] = { + val res = supportedTypes.find(_.toString == s) + logInfo(s"getType: $s -> $res") + res } /** @@ -115,19 +121,18 @@ object ColumnInformation extends Logging { * @return */ private def extractFromRow(dt: DataType): Option[SparkTFColInfo] = dt match { - case x: NumericType if MetadataConstants.supportedTypes.contains(dt) => - logTrace("numerictype: " + x) - // It is a basic type that we understand - Some(SparkTFColInfo(Shape(Unknown), x)) case x: ArrayType => logTrace("arraytype: " + x) // Look into the array to figure out the type. extractFromRow(x.elementType).map { info => SparkTFColInfo(info.shape.prepend(Unknown), info.dataType) } - case _ => - logTrace("not understood: " + dt) - // Not understood. 
- None + case _ => SupportedOperations.getOps(dt) match { + case Some(ops) => + logTrace("numerictype: " + ops.scalarType) + // It is a basic type that we understand + Some(SparkTFColInfo(Shape(Unknown), ops.scalarType)) + case None => None + } } } diff --git a/src/main/scala/org/tensorframes/ExperimentalOperations.scala b/src/main/scala/org/tensorframes/ExperimentalOperations.scala index a622104..87aed4e 100644 --- a/src/main/scala/org/tensorframes/ExperimentalOperations.scala +++ b/src/main/scala/org/tensorframes/ExperimentalOperations.scala @@ -3,7 +3,7 @@ package org.tensorframes import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{ArrayType, DataType, NumericType} -import org.tensorframes.impl.SupportedOperations +import org.tensorframes.impl.{ScalarType, SupportedOperations} /** * Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime). @@ -109,8 +109,8 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with DataFrameInfo(allInfo) } - private def extractBasicType(dt: DataType): Option[NumericType] = dt match { - case x: NumericType => Some(x) + private def extractBasicType(dt: DataType): Option[ScalarType] = dt match { + case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType) case x: ArrayType => extractBasicType(x.elementType) case _ => None } diff --git a/src/main/scala/org/tensorframes/MetadataConstants.scala b/src/main/scala/org/tensorframes/MetadataConstants.scala index affec0d..d0aea61 100644 --- a/src/main/scala/org/tensorframes/MetadataConstants.scala +++ b/src/main/scala/org/tensorframes/MetadataConstants.scala @@ -1,7 +1,7 @@ package org.tensorframes -import org.apache.spark.sql.types.NumericType -import org.tensorframes.impl.SupportedOperations +import org.apache.spark.sql.types.{DataType, NumericType} +import org.tensorframes.impl.{ScalarType, SupportedOperations} /** * Metadata annotations that get embedded in dataframes to express tensor information. @@ -29,5 +29,5 @@ object MetadataConstants { /** * All the SQL types supported by SparkTF. */ - val supportedTypes: Seq[NumericType] = SupportedOperations.sqlTypes + val supportedTypes: Seq[DataType] = SupportedOperations.sqlTypes } \ No newline at end of file diff --git a/src/main/scala/org/tensorframes/Shape.scala b/src/main/scala/org/tensorframes/Shape.scala index b7d9859..6eced36 100644 --- a/src/main/scala/org/tensorframes/Shape.scala +++ b/src/main/scala/org/tensorframes/Shape.scala @@ -1,9 +1,11 @@ package org.tensorframes -import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.{BinaryType, DataType, NumericType} import org.tensorflow.framework.TensorShapeProto + import scala.collection.JavaConverters._ import org.tensorframes.Shape.DimType +import org.tensorframes.impl.ScalarType import org.{tensorflow => tf} @@ -36,6 +38,11 @@ class Shape private (private val ds: Array[DimType]) extends Serializable { def prepend(x: Int): Shape = Shape(x.toLong +: ds) + /** + * Drops the most inner dimension of the shape. + */ + def dropInner: Shape = Shape(ds.dropRight(1)) + /** * A shape with the first dimension dropped. */ @@ -105,14 +112,22 @@ object Shape { /** * SparkTF information. This is the information generally required to work on a tensor. * @param shape - * @param dataType + * @param dataType the datatype of the scalar. Note that it is either NumericType or BinaryType. 
*/ // TODO(tjh) the types supported by TF are much richer (uint8, etc.) but it is not clear // if they all map to a Catalyst memory representation // TODO(tjh) support later basic structures for sparse types? case class SparkTFColInfo( shape: Shape, - dataType: NumericType) extends Serializable + dataType: ScalarType) extends Serializable { + +// // Forces a cast to a numeric type, which may fail. +// // TODO: try to use an atomic type instead? +// def numericType: NumericType = dataType match { +// case x: NumericType => x +// case _ => throw new Exception(s"$dataType cannot be cast to a numeric type") +// } +} /** * Exception thrown when the user requests tensors of high order. diff --git a/src/main/scala/org/tensorframes/dsl/DslImpl.scala b/src/main/scala/org/tensorframes/dsl/DslImpl.scala index 3795c9f..6e82c0e 100644 --- a/src/main/scala/org/tensorframes/dsl/DslImpl.scala +++ b/src/main/scala/org/tensorframes/dsl/DslImpl.scala @@ -1,13 +1,12 @@ package org.tensorframes.dsl import javax.annotation.Nullable -import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto} +import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.NumericType - -import org.tensorframes.{Logging, ColumnInformation, Shape} -import org.tensorframes.impl.DenseTensor +import org.tensorframes.{ColumnInformation, Logging, Shape} +import org.tensorframes.impl.{DenseTensor, SupportedOperations} /** @@ -75,8 +74,9 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) + val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, dtype = dt.dtype, + shape = dt.shape, dtype = dt2, extraAttrs = Map("value" -> a.build())) } @@ -100,7 +100,8 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { s"tensorframes: $schema") } val shape = if (block) { stf.shape } else { stf.shape.tail } - DslImpl.placeholder(stf.dataType, shape).named(tfName) + val dt = SupportedOperations.opsFor(stf.dataType).sqlType.asInstanceOf[NumericType] + DslImpl.placeholder(dt, shape).named(tfName) } private def commonShape(shapes: Seq[Shape]): Shape = { diff --git a/src/main/scala/org/tensorframes/dsl/package.scala b/src/main/scala/org/tensorframes/dsl/package.scala index adfa39a..2a787d6 100644 --- a/src/main/scala/org/tensorframes/dsl/package.scala +++ b/src/main/scala/org/tensorframes/dsl/package.scala @@ -1,10 +1,8 @@ package org.tensorframes import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.IntegerType - +import org.apache.spark.sql.types.{IntegerType, NumericType} import org.tensorframes.impl.SupportedOperations /** @@ -45,7 +43,7 @@ package object dsl { def placeholder[T : Numeric : TypeTag](shape: Int*): Operation = { val ops = SupportedOperations.getOps[T]() - DslImpl.placeholder(ops.sqlType, Shape(shape: _*)) + DslImpl.placeholder(ops.sqlType.asInstanceOf[NumericType], Shape(shape: _*)) } def constant[T : ConvertibleToDenseTensor](x: T): Operation = { diff --git a/src/main/scala/org/tensorframes/impl/DataOps.scala b/src/main/scala/org/tensorframes/impl/DataOps.scala index 6e60f1a..a1e7d8d 100644 --- a/src/main/scala/org/tensorframes/impl/DataOps.scala +++ b/src/main/scala/org/tensorframes/impl/DataOps.scala @@ -2,10 +2,11 @@ package 
org.tensorframes.impl import scala.collection.mutable import scala.reflect.ClassTag +import org.{tensorflow => tf} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{StructType} import org.tensorframes.{Logging, Shape} import org.tensorframes.Shape.DimType @@ -145,7 +146,7 @@ object DataOps extends Logging { def getColumnFast0( reshapeShape: Shape, - scalaType: NumericType, + scalaType: ScalarType, allDataBuffer: mutable.WrappedArray[_]): Iterable[Any] = { reshapeShape.dims match { case Seq() => diff --git a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala index e9dbd32..db413fc 100644 --- a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala +++ b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala @@ -267,6 +267,24 @@ private[impl] trait SchemaTransforms extends Logging { case _ => f // Nothing to do } } + + /** + * Checks that the data, coming with a certain shape from Spark, can be shaped into + * the given shape (taking unknown values and binary data into account) + */ + def canBeReshapedTo(dt: ScalarType, from: Shape, to: Shape): Boolean = { +// if (dt == ScalarBinaryType) { +// // Binary has to be at least an array. +// if (from.numDims == 0) { +// return false +// } +// // In that case, the spark shape should be one larger, and we should drop the bottom one. +// from.dropInner.checkMorePreciseThan(to) +// } else { +// from.checkMorePreciseThan(to) +// } + from.checkMorePreciseThan(to) + } } object SchemaTransforms extends SchemaTransforms @@ -322,17 +340,17 @@ class DebugRowOps throw new Exception( s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") } - if (! stf.shape.checkMorePreciseThan(in.shape)) { - throw new Exception( - s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + - s" ${in.shape} requested by the TF graph") - } // We do not support autocasting for now. if (stf.dataType != in.scalarType) { throw new Exception( s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + s"of the column (${in.scalarType})") } + if (! canBeReshapedTo(stf.dataType, stf.shape, in.shape)) { + throw new Exception( + s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + + s" ${in.shape} requested by the TF graph") + } // The input has to be either a constant or a placeholder if (! in.isPlaceholder) { throw new Exception( @@ -414,16 +432,16 @@ class DebugRowOps val stf = get(ColumnInformation(f).stf, s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") + check(stf.dataType == in.scalarType, + s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + + s"of the column (${in.scalarType})") + val cellShape = stf.shape.tail // No check for unknowns: we allow unknowns in the first dimension of the cell shape. - check(cellShape.checkMorePreciseThan(in.shape), + check(canBeReshapedTo(stf.dataType, cellShape, in.shape), s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + s" ${in.shape} requested by the TF graph") - check(stf.dataType == in.scalarType, - s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + - s"of the column (${in.scalarType})") - check(in.isPlaceholder, s"Invalid type for input node ${in.name}. 
It has to be a placeholder") } @@ -532,7 +550,8 @@ class DebugRowOps val f = col.field builder.append(s"$prefix-- ${f.name}: ${f.dataType.typeName} (nullable = ${f.nullable})") val stf = col.stf.map { s => - s" ${s.dataType.typeName}${s.shape}" + val dt = SupportedOperations.opsFor(s.dataType).sqlType + s" ${dt.typeName}${s.shape}" } .getOrElse(" ") builder.append(stf) builder.append("\n") diff --git a/src/main/scala/org/tensorframes/impl/DenseTensor.scala b/src/main/scala/org/tensorframes/impl/DenseTensor.scala index 7414f73..d9e30e2 100644 --- a/src/main/scala/org/tensorframes/impl/DenseTensor.scala +++ b/src/main/scala/org/tensorframes/impl/DenseTensor.scala @@ -17,27 +17,31 @@ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, NumericTy */ private[tensorframes] class DenseTensor private( val shape: Shape, - val dtype: NumericType, + val dtype: ScalarType, private val data: Array[Byte]) { override def toString(): String = s"DenseTensor($shape, $dtype, " + - s"${data.length / dtype.defaultSize} elements)" + s"${data.length} bytes)" } private[tensorframes] object DenseTensor { def apply[T](x: T)(implicit ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape.empty, ops.sqlType, convert(x)) + apply(Shape.empty, ops.sqlType.asInstanceOf[NumericType], convert(x)) } def apply[T](xs: Seq[T])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape(xs.size), ops.sqlType, convert1(xs)) + apply(Shape(xs.size), ops.sqlType.asInstanceOf[NumericType], convert1(xs)) + } + + def apply(shape: Shape, dtype: NumericType, data: Array[Byte]): DenseTensor = { + new DenseTensor(shape, SupportedOperations.opsFor(dtype).scalarType, data) } def matrix[T](xs: Seq[Seq[T]])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - new DenseTensor(Shape(xs.size, xs.head.size), ops.sqlType, convert2(xs)) + apply(Shape(xs.size, xs.head.size), ops.sqlType.asInstanceOf[NumericType], convert2(xs)) } private def convert[T](x: T)(implicit ev2: TypeTag[T]): Array[Byte] = { @@ -98,15 +102,15 @@ private[tensorframes] object DenseTensor { val shape = Shape.from(proto.getTensorShape) val data = ops.sqlType match { case DoubleType => - val coll = proto.getDoubleValList.asScala.toSeq.map(_.doubleValue()) + val coll = proto.getDoubleValList.asScala.map(_.doubleValue()) convert(coll) case IntegerType => - val coll = proto.getIntValList.asScala.toSeq.map(_.intValue()) + val coll = proto.getIntValList.asScala.map(_.intValue()) convert(coll) case _ => throw new IllegalArgumentException( s"Cannot convert type ${ops.sqlType}") } - new DenseTensor(shape, ops.sqlType, data) + new DenseTensor(shape, ops.scalarType, data) } } diff --git a/src/main/scala/org/tensorframes/impl/TFDataOps.scala b/src/main/scala/org/tensorframes/impl/TFDataOps.scala index 5ed609e..f3c4c6b 100644 --- a/src/main/scala/org/tensorframes/impl/TFDataOps.scala +++ b/src/main/scala/org/tensorframes/impl/TFDataOps.scala @@ -184,7 +184,7 @@ object TFDataOps extends Logging { */ private def getColumn( t: tf.Tensor, - scalaType: NumericType, + scalaType: ScalarType, cellShape: Shape, expectedNumRows: Option[Int], fastPath: Boolean = true): (Int, Iterable[Any]) = { diff --git a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala index 449db0e..f0aca5c 100644 --- a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala +++ 
b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala @@ -140,10 +140,10 @@ object TensorFlowOps extends Logging { } } - private def getSummaryDefault(op: tf.Operation): Seq[(NumericType, Shape)] = { + private def getSummaryDefault(op: tf.Operation): Seq[(ScalarType, Shape)] = { (0 until op.numOutputs()).map { idx => val n = op.output(idx) - val dt = SupportedOperations.opsFor(n.dataType()).sqlType + val dt = SupportedOperations.opsFor(n.dataType()).scalarType val shape = Shape.from(n.shape()) dt -> shape } @@ -164,6 +164,6 @@ case class GraphNodeSummary( isPlaceholder: Boolean, isInput: Boolean, isOutput: Boolean, - scalarType: NumericType, + scalarType: ScalarType, shape: Shape, name: String) extends Serializable diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 9be2bfa..f12f8ea 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -5,7 +5,7 @@ import java.nio._ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.{tensorflow => tf} -import org.tensorflow.framework.DataType +import org.tensorflow.framework.{DataType => ProtoDataType} import org.tensorframes.{Logging, Shape} import scala.collection.mutable.{WrappedArray => MWrappedArray} @@ -18,6 +18,39 @@ import scala.reflect.runtime.universe.TypeTag // - jvm: ??? // - protobuf: ??? +/** + * All the types of scalars supported by TensorFrames. + * + * It can be argued that the Binary type is not really a scalar, + * but it is considered as such by both Spark and TensorFlow. + */ +trait ScalarType + +/** + * Int32 + */ +case object ScalarIntType extends ScalarType + +/** + * INT64 + */ +case object ScalarLongType extends ScalarType + +/** + * FLOAT64 + */ +case object ScalarDoubleType extends ScalarType + +/** + * FLOAT32 + */ +case object ScalarFloatType extends ScalarType + +/** + * STRING / BINARY + */ +case object ScalarBinaryType extends ScalarType + /** * @param shape the shape of the element in the row (not the overall shape of the block) * @param numCells the number of cells that are going to be allocated with the given shape. @@ -79,6 +112,7 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, // The return element is just here so that the method gets specialized (otherwise it would not). final def append(row: Row, position: Int): Array[T] = { + logger.debug(s"append: position=$position row=$row") val d = shape.numDims if (d == 0) { appendRaw(row.getAs[T](position)) @@ -130,12 +164,12 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int /** * The SQL type associated with the given type. */ - val sqlType: NumericType + val sqlType: DataType /** * The TF type */ - val tfType: DataType + val tfType: ProtoDataType /** * The TF type (new style). @@ -143,6 +177,11 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int */ val tfType2: tf.DataType + /** + * The type of the scalar value. 
+ */ + val scalarType: ScalarType + /** * A zero element for this type */ @@ -222,20 +261,33 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int private[tensorframes] object SupportedOperations { private val ops: Seq[ScalarTypeOperation[_]] = - Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations) + Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations, StringOperations) val sqlTypes = ops.map(_.sqlType) + val scalarTypes = ops.map(_.scalarType) + private val tfTypes = ops.map(_.tfType) - def opsFor(t: NumericType): ScalarTypeOperation[_] = { + def getOps(t: DataType): Option[ScalarTypeOperation[_]] = { + ops.find(_.sqlType == t) + } + + def opsFor(t: DataType): ScalarTypeOperation[_] = { ops.find(_.sqlType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${sqlTypes.mkString(", ")}") } } - def opsFor(t: DataType): ScalarTypeOperation[_] = { + def opsFor(t: ScalarType): ScalarTypeOperation[_] = { + ops.find(_.scalarType == t).getOrElse { + throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + + s"supported: ${sqlTypes.mkString(", ")}") + } + } + + def opsFor(t: ProtoDataType): ScalarTypeOperation[_] = { ops.find(_.tfType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${tfTypes.mkString(", ")}") @@ -299,8 +351,9 @@ private[impl] class DoubleTensorConverter(s: Shape, numCells: Int) private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with Logging { override val sqlType = DoubleType - override val tfType = DataType.DT_DOUBLE + override val tfType = ProtoDataType.DT_DOUBLE override val tfType2 = tf.DataType.DOUBLE + override val scalarType = ScalarDoubleType final override val zero = 0.0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Double] = new DoubleTensorConverter(cellShape, numCells) @@ -358,8 +411,9 @@ private[impl] class FloatTensorConverter(s: Shape, numCells: Int) private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Logging { override val sqlType = FloatType - override val tfType = DataType.DT_FLOAT + override val tfType = ProtoDataType.DT_FLOAT override val tfType2 = tf.DataType.FLOAT + override val scalarType = ScalarFloatType final override val zero = 0.0f override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Float] = new FloatTensorConverter(cellShape, numCells) @@ -414,8 +468,9 @@ private[impl] class IntTensorConverter(s: Shape, numCells: Int) private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging { override val sqlType = IntegerType - override val tfType = DataType.DT_INT32 + override val tfType = ProtoDataType.DT_INT32 override val tfType2 = tf.DataType.INT32 + override val scalarType = ScalarIntType final override val zero = 0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Int] = new IntTensorConverter(cellShape, numCells) @@ -467,8 +522,9 @@ private[impl] class LongTensorConverter(s: Shape, numCells: Int) private[impl] object LongOperations extends ScalarTypeOperation[Long] with Logging { override val sqlType = LongType - override val tfType = DataType.DT_INT64 + override val tfType = ProtoDataType.DT_INT64 override val tfType2 = tf.DataType.INT64 + override val scalarType = ScalarLongType final override val zero = 0L override def tfConverter(cellShape: Shape, numCells: Int): 
TensorConverter[Long] = new LongTensorConverter(cellShape, numCells) @@ -488,4 +544,61 @@ private[impl] object LongOperations extends ScalarTypeOperation[Long] with Loggi logTrace(s"Extracted from buffer: ${res.toSeq}") res } -} \ No newline at end of file +} + +// ********** STRING ********* +// These are actually byte arrays, which correspond to the 'binary' type in Spark. + +// The string converter can only deal with one row at a time (the most common case). +private[impl] class StringTensorConverter(s: Shape, numCells: Int) + extends TensorConverter[Array[Byte]](s, numCells) with Logging { + private var buffer: Array[Byte] = null + + override val elementSize: Int = 1 + + { + logger.debug(s"Creating string buffer for shape $s and $numCells cells") + assert(s == Shape() && numCells == 1, s"The string buffer does not accept more than one" + + s" scalar of type binary. shape=$s numCells=$numCells") + } + + + override def reserve(): Unit = {} + + override def appendRaw(d: Array[Byte]): Unit = { + assert(buffer == null, s"The buffer has already been set with ${buffer.length} values," + + s" but ${d.length} are trying to get inserted") + buffer = d.clone() + } + + override def tensor2(): tf.Tensor = { + tf.Tensor.create(buffer) + } + + override def fillBuffer(buff: ByteBuffer): Unit = { + buff.put(buffer) + } +} + +private[impl] object StringOperations extends ScalarTypeOperation[Array[Byte]] with Logging { + override val sqlType = BinaryType + override val tfType = ProtoDataType.DT_STRING + override val tfType2 = tf.DataType.STRING + override val scalarType = ScalarBinaryType + final override val zero = Array.empty[Byte] + + override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Array[Byte]] = + new StringTensorConverter(cellShape, numCells) + + override def convertTensor(t: tf.Tensor): MWrappedArray[Array[Byte]] = { + // TODO(tjh) implement later + ??? + } + + override def convertBuffer(buff: ByteBuffer, numElements: Int): Iterable[Any] = { + // TODO(tjh) implement later + ??? 
+ } +} + + diff --git a/src/main/scala/org/tensorframes/test/dsl.scala b/src/main/scala/org/tensorframes/test/dsl.scala index 1d14503..fbe4f7a 100644 --- a/src/main/scala/org/tensorframes/test/dsl.scala +++ b/src/main/scala/org/tensorframes/test/dsl.scala @@ -2,10 +2,10 @@ package org.tensorframes.test import java.nio.file.{Files, Paths} -import org.apache.spark.sql.types.{DoubleType, NumericType} -import org.tensorflow.framework._ +import org.apache.spark.sql.types.{DataType, NumericType} +import org.tensorflow.framework.{AttrValue, GraphDef, NodeDef, TensorShapeProto, DataType => ProtoDataType} import org.tensorframes.{Logging, Shape} -import org.tensorframes.impl.{DenseTensor, SupportedOperations} +import org.tensorframes.impl.{DenseTensor, ScalarType, SupportedOperations} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe._ @@ -25,7 +25,7 @@ object dsl extends Logging { def toAttr: AttrValue = buildType(s) } - private implicit class DataTypeToAttr(dt: DataType) { + private implicit class DataTypeToAttr(dt: ProtoDataType) { def toAttr: AttrValue = dataTypeToAttrValue(dt) } @@ -66,8 +66,8 @@ object dsl extends Logging { def +(other: Node): Node = op_add(this, other) } - private[tensorframes] def placeholder(dtype: NumericType, shape: Shape): Node = { - build("Placeholder", shape=shape, dtype=dtype, isOp = false, + private[tensorframes] def placeholder(dtype: DataType, shape: Shape): Node = { + build("Placeholder", shape=shape, dtype=dtype.asInstanceOf[NumericType], isOp = false, extraAttrs = Map("shape" -> shape.toAttr)) } @@ -165,8 +165,9 @@ object dsl extends Logging { private def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) + val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, dtype = dt.dtype, + shape = dt.shape, dtype = dt2, extraAttrs = Map("value" -> a.build())) } @@ -196,7 +197,7 @@ object dsl extends Logging { dtype = parent.scalarType, shape = reduce_shape(parent.shape, Option(reduction_indices).getOrElse(Nil)), extraAttrs = Map( - "Tidx" -> AttrValue.newBuilder().setType(DataType.DT_INT32).build(), + "Tidx" -> AttrValue.newBuilder().setType(ProtoDataType.DT_INT32).build(), "keep_dims" -> AttrValue.newBuilder().setB(false).build())) } @@ -218,13 +219,16 @@ object dsl extends Logging { * Utilities to convert data back and forth between the proto descriptions and the dataframe descriptions. 
*/ object ProtoConversions { - def getDType(nodeDef: NodeDef): DataType = { + def getDType(nodeDef: NodeDef): ProtoDataType = { val opt = Option(nodeDef.getAttr.get("T")).orElse(Option(nodeDef.getAttr.get("dtype"))) val v = opt.getOrElse(throw new Exception(s"Neither 'T' no 'dtype' was found in $nodeDef")) v.getType } - def getDType(sqlType: NumericType): DataType = { + def getDType(sqlType: NumericType): ProtoDataType = { + SupportedOperations.opsFor(sqlType).tfType + } + def getDType(sqlType: ScalarType): ProtoDataType = { SupportedOperations.opsFor(sqlType).tfType } @@ -232,7 +236,7 @@ object ProtoConversions { AttrValue.newBuilder().setType(getDType(sqlType)).build() } - def dataTypeToAttrValue(dataType: DataType): AttrValue = { + def dataTypeToAttrValue(dataType: ProtoDataType): AttrValue = { AttrValue.newBuilder().setType(dataType).build() } diff --git a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala index d92e58e..cd7d50c 100644 --- a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala +++ b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala @@ -3,7 +3,7 @@ package org.tensorframes import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StructType} import org.scalatest.FunSuite -import org.tensorframes.impl.DebugRowOpsImpl +import org.tensorframes.impl.{DebugRowOpsImpl, ScalarDoubleType} import org.tensorframes.dsl._ class DebugRowOpsSuite @@ -14,10 +14,10 @@ class DebugRowOpsSuite testGraph("Simple identity") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = identity(p2) named "y" - val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) @@ -26,10 +26,10 @@ class DebugRowOpsSuite testGraph("Simple add") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = p2 + p2 named "y" - val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) diff --git a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala index 2a0a1a0..b197df7 100644 --- a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala @@ -2,6 +2,7 @@ package org.tensorframes import org.apache.spark.sql.types.{DoubleType, IntegerType} import org.scalatest.FunSuite +import org.tensorframes.impl.{ScalarDoubleType, ScalarIntType} class ExtraOperationsSuite @@ -16,7 +17,7 @@ class ExtraOperationsSuite val di = ExtraOperations.explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + 
"->" + di.toString) } @@ -26,7 +27,7 @@ class ExtraOperationsSuite val di = explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === IntegerType) + assert(s.dataType === ScalarIntType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + "->" + di.toString) } @@ -37,13 +38,13 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2, c3) = di.cols val Some(s1) = c1.stf - assert(s1.dataType === DoubleType) + assert(s1.dataType === ScalarDoubleType) assert(s1.shape === Shape(Unknown)) val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(Unknown, Unknown)) val Some(s3) = c3.stf - assert(s3.dataType === DoubleType) + assert(s3.dataType === ScalarDoubleType) assert(s3.shape === Shape(Unknown, Unknown, Unknown)) } @@ -54,7 +55,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(1)) // There is only one partition } @@ -65,7 +66,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === DoubleType) + assert(s.dataType === ScalarDoubleType) assert(s.shape === Shape(Unknown)) // There is only one partition } @@ -78,7 +79,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(2, Unknown)) // There is only one partition } @@ -92,7 +93,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === DoubleType) + assert(s2.dataType === ScalarDoubleType) assert(s2.shape === Shape(3, 2)) // There is only one partition } } diff --git a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala index a1680f9..3624e72 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala @@ -2,7 +2,7 @@ package org.tensorframes.perf import org.scalatest.FunSuite import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{SupportedOperations, TFDataOps} +import org.tensorframes.impl.{ScalarIntType, SupportedOperations, TFDataOps} import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -48,9 +48,9 @@ class ConvertBackPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. 
val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, + val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, Shape(numCells, numVals)))) - val tfSchema = StructType(Seq(ColumnInformation.structField("f2", IntegerType, + val tfSchema = StructType(Seq(ColumnInformation.structField("f2", ScalarIntType, Shape(numCells, numVals)))) val tensor = getTFTensor(IntegerType, Row(Seq.fill(numVals)(1)), Shape(numVals), numCells) println("generated data") diff --git a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala index 556c8f1..b3112e8 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala @@ -5,7 +5,7 @@ import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{DataOps, SupportedOperations, TFDataOps} +import org.tensorframes.impl.{DataOps, ScalarIntType, SupportedOperations, TFDataOps} class ConvertPerformanceSuite extends FunSuite with TensorFramesTestSparkContext with Logging { @@ -44,7 +44,7 @@ class ConvertPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, + val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, Shape(numCells, numVals)))) println("generated data") logInfo("generated data") From baf65425dbc27f9b1449de9466ece443d2850f5c Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 18 Apr 2017 08:23:10 -0700 Subject: [PATCH 2/7] removing type tags, they do not work with scala 2.10 --- .../scala/org/tensorframes/impl/datatypes.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index f12f8ea..50a0853 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -160,7 +160,7 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, * internally through casting. */ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] - (implicit ev1: TypeTag[T], ev2: ClassTag[T]) { + (implicit ev: ClassTag[T]) { /** * The SQL type associated with the given type. */ @@ -256,7 +256,7 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int res.map { arr => conv(arr.map(conv)) } } - def tag: TypeTag[_] = implicitly[TypeTag[T]] + def tag: Option[TypeTag[_]] } private[tensorframes] object SupportedOperations { @@ -304,7 +304,7 @@ private[tensorframes] object SupportedOperations { def getOps[T : TypeTag](): ScalarTypeOperation[T] = { val ev: TypeTag[_] = implicitly[TypeTag[T]] - ops.find(_.tag.tpe =:= ev.tpe).getOrElse { + ops.find(_.tag.map(_.tpe =:= ev.tpe) == Some(true)).getOrElse { val tags = ops.map(_.tag.toString()).mkString(", ") throw new IllegalArgumentException(s"Type ${ev} is not supported. 
Only the following types " + s"are supported: ${tags}") @@ -378,6 +378,8 @@ private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with L res } + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Double]]) + } // ********** FLOAT ************ @@ -435,6 +437,8 @@ private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Log t.writeTo(b) res } + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Float]]) } // ********** INT32 ************ @@ -489,6 +493,8 @@ private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging dbuff.get(res) res } + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Int]]) } // ****** INT64 (LONG) ****** @@ -544,6 +550,8 @@ private[impl] object LongOperations extends ScalarTypeOperation[Long] with Loggi logTrace(s"Extracted from buffer: ${res.toSeq}") res } + + override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Long]]) } // ********** STRING ********* @@ -599,6 +607,8 @@ private[impl] object StringOperations extends ScalarTypeOperation[Array[Byte]] w // TODO(tjh) implement later ??? } + + override def tag: Option[TypeTag[_]] = None } From ed4d39c5813ef14bb0142b066f0fdb0bd46fb3ec Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 18 Apr 2017 09:35:39 -0700 Subject: [PATCH 3/7] removing type tags, they do not work with scala 2.10 --- .../scala/org/tensorframes/impl/datatypes.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 50a0853..72088cc 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -159,8 +159,7 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, * It does not support TF's rich type collection (uint16, float128, etc.). These have to be handled * internally through casting. */ -private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] - (implicit ev: ClassTag[T]) { +private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] { /** * The SQL type associated with the given type. 
*/ @@ -256,7 +255,11 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int res.map { arr => conv(arr.map(conv)) } } + implicit def classTag: ClassTag[T] = ev + def tag: Option[TypeTag[_]] + + def ev: ClassTag[T] = null } private[tensorframes] object SupportedOperations { @@ -380,6 +383,7 @@ private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with L override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Double]]) + override def ev = ClassTag.Double } // ********** FLOAT ************ @@ -439,6 +443,8 @@ private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Log } override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Float]]) + + override def ev = ClassTag.Float } // ********** INT32 ************ @@ -495,6 +501,8 @@ private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging } override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Int]]) + + override def ev = ClassTag.Int } // ****** INT64 (LONG) ****** @@ -552,6 +560,8 @@ private[impl] object LongOperations extends ScalarTypeOperation[Long] with Loggi } override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Long]]) + + override def ev = ClassTag.Long } // ********** STRING ********* @@ -609,6 +619,8 @@ private[impl] object StringOperations extends ScalarTypeOperation[Array[Byte]] w } override def tag: Option[TypeTag[_]] = None + + override def ev = ??? } From 46ce8d0c8497f583bbd7707d31de932d3ac3bd9f Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 18 Apr 2017 09:47:08 -0700 Subject: [PATCH 4/7] removing type tags, they do not work with scala 2.10 --- src/main/scala/org/tensorframes/impl/datatypes.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 72088cc..9b976c0 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -60,7 +60,7 @@ case object ScalarBinaryType extends ScalarType private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, Float, Int, Long) T] ( val shape: Shape, val numCells: Int) - (implicit ev1: TypeTag[T], ev2: ClassTag[T]) extends Logging { + (implicit ev2: ClassTag[T]) extends Logging { final val empty = Array.empty[T] /** * Creates memory space for a given number of units of the given shape. 
From 839c9456fccd544c2e81e85ba95b48ab50d704e3 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 20 Apr 2017 14:03:24 -0700 Subject: [PATCH 5/7] changes --- .../org/tensorframes/ColumnInformation.scala | 35 ++-- .../tensorframes/ExperimentalOperations.scala | 6 +- .../org/tensorframes/MetadataConstants.scala | 6 +- src/main/scala/org/tensorframes/Shape.scala | 21 +-- .../scala/org/tensorframes/dsl/DslImpl.scala | 13 +- .../scala/org/tensorframes/dsl/package.scala | 6 +- .../scala/org/tensorframes/impl/DataOps.scala | 5 +- .../org/tensorframes/impl/DebugRowOps.scala | 41 ++--- .../org/tensorframes/impl/DenseTensor.scala | 20 +-- .../org/tensorframes/impl/TFDataOps.scala | 2 +- .../org/tensorframes/impl/TensorFlowOps.scala | 6 +- .../org/tensorframes/impl/datatypes.scala | 167 ++---------------- .../scala/org/tensorframes/test/dsl.scala | 26 ++- 13 files changed, 86 insertions(+), 268 deletions(-) diff --git a/src/main/scala/org/tensorframes/ColumnInformation.scala b/src/main/scala/org/tensorframes/ColumnInformation.scala index 9e3402b..d06a76c 100644 --- a/src/main/scala/org/tensorframes/ColumnInformation.scala +++ b/src/main/scala/org/tensorframes/ColumnInformation.scala @@ -1,7 +1,6 @@ package org.tensorframes import org.apache.spark.sql.types._ -import org.tensorframes.impl.{ScalarType, SupportedOperations} class ColumnInformation private ( @@ -16,9 +15,7 @@ class ColumnInformation private ( val b = new MetadataBuilder().withMetadata(field.metadata) for (info <- stf) { b.putLongArray(shapeKey, info.shape.dims.toArray) - // Keep the SQL name, so that we do not leak internal details. - val dt = SupportedOperations.opsFor(info.dataType).sqlType - b.putString(tensorStructType, dt.toString) + b.putString(tensorStructType, info.dataType.toString) } val meta = b.build() field.copy(metadata = meta) @@ -76,15 +73,15 @@ object ColumnInformation extends Logging { * @param scalarType the data type * @param blockShape the shape of the block */ - def structField(name: String, scalarType: ScalarType, blockShape: Shape): StructField = { + def structField(name: String, scalarType: NumericType, blockShape: Shape): StructField = { val i = SparkTFColInfo(blockShape, scalarType) val f = StructField(name, sqlType(scalarType, blockShape.tail), nullable = false) ColumnInformation(f, i).merged } - private def sqlType(scalarType: ScalarType, shape: Shape): DataType = { + private def sqlType(scalarType: NumericType, shape: Shape): DataType = { if (shape.dims.isEmpty) { - SupportedOperations.opsFor(scalarType).sqlType + scalarType } else { ArrayType(sqlType(scalarType, shape.tail), containsNull = false) } @@ -105,14 +102,11 @@ object ColumnInformation extends Logging { for { s <- shape t <- tpe - ops <- SupportedOperations.getOps(t) - } yield SparkTFColInfo(s, ops.scalarType) + } yield SparkTFColInfo(s, t) } - private def getType(s: String): Option[DataType] = { - val res = supportedTypes.find(_.toString == s) - logInfo(s"getType: $s -> $res") - res + private def getType(s: String): Option[NumericType] = { + supportedTypes.find(_.toString == s) } /** @@ -121,18 +115,19 @@ object ColumnInformation extends Logging { * @return */ private def extractFromRow(dt: DataType): Option[SparkTFColInfo] = dt match { + case x: NumericType if MetadataConstants.supportedTypes.contains(dt) => + logTrace("numerictype: " + x) + // It is a basic type that we understand + Some(SparkTFColInfo(Shape(Unknown), x)) case x: ArrayType => logTrace("arraytype: " + x) // Look into the array to figure out the type. 
extractFromRow(x.elementType).map { info => SparkTFColInfo(info.shape.prepend(Unknown), info.dataType) } - case _ => SupportedOperations.getOps(dt) match { - case Some(ops) => - logTrace("numerictype: " + ops.scalarType) - // It is a basic type that we understand - Some(SparkTFColInfo(Shape(Unknown), ops.scalarType)) - case None => None - } + case _ => + logTrace("not understood: " + dt) + // Not understood. + None } } diff --git a/src/main/scala/org/tensorframes/ExperimentalOperations.scala b/src/main/scala/org/tensorframes/ExperimentalOperations.scala index 87aed4e..a622104 100644 --- a/src/main/scala/org/tensorframes/ExperimentalOperations.scala +++ b/src/main/scala/org/tensorframes/ExperimentalOperations.scala @@ -3,7 +3,7 @@ package org.tensorframes import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{ArrayType, DataType, NumericType} -import org.tensorframes.impl.{ScalarType, SupportedOperations} +import org.tensorframes.impl.SupportedOperations /** * Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime). @@ -109,8 +109,8 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with DataFrameInfo(allInfo) } - private def extractBasicType(dt: DataType): Option[ScalarType] = dt match { - case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType) + private def extractBasicType(dt: DataType): Option[NumericType] = dt match { + case x: NumericType => Some(x) case x: ArrayType => extractBasicType(x.elementType) case _ => None } diff --git a/src/main/scala/org/tensorframes/MetadataConstants.scala b/src/main/scala/org/tensorframes/MetadataConstants.scala index d0aea61..affec0d 100644 --- a/src/main/scala/org/tensorframes/MetadataConstants.scala +++ b/src/main/scala/org/tensorframes/MetadataConstants.scala @@ -1,7 +1,7 @@ package org.tensorframes -import org.apache.spark.sql.types.{DataType, NumericType} -import org.tensorframes.impl.{ScalarType, SupportedOperations} +import org.apache.spark.sql.types.NumericType +import org.tensorframes.impl.SupportedOperations /** * Metadata annotations that get embedded in dataframes to express tensor information. @@ -29,5 +29,5 @@ object MetadataConstants { /** * All the SQL types supported by SparkTF. */ - val supportedTypes: Seq[DataType] = SupportedOperations.sqlTypes + val supportedTypes: Seq[NumericType] = SupportedOperations.sqlTypes } \ No newline at end of file diff --git a/src/main/scala/org/tensorframes/Shape.scala b/src/main/scala/org/tensorframes/Shape.scala index 6eced36..b7d9859 100644 --- a/src/main/scala/org/tensorframes/Shape.scala +++ b/src/main/scala/org/tensorframes/Shape.scala @@ -1,11 +1,9 @@ package org.tensorframes -import org.apache.spark.sql.types.{BinaryType, DataType, NumericType} +import org.apache.spark.sql.types.NumericType import org.tensorflow.framework.TensorShapeProto - import scala.collection.JavaConverters._ import org.tensorframes.Shape.DimType -import org.tensorframes.impl.ScalarType import org.{tensorflow => tf} @@ -38,11 +36,6 @@ class Shape private (private val ds: Array[DimType]) extends Serializable { def prepend(x: Int): Shape = Shape(x.toLong +: ds) - /** - * Drops the most inner dimension of the shape. - */ - def dropInner: Shape = Shape(ds.dropRight(1)) - /** * A shape with the first dimension dropped. */ @@ -112,22 +105,14 @@ object Shape { /** * SparkTF information. This is the information generally required to work on a tensor. 
* @param shape - * @param dataType the datatype of the scalar. Note that it is either NumericType or BinaryType. + * @param dataType */ // TODO(tjh) the types supported by TF are much richer (uint8, etc.) but it is not clear // if they all map to a Catalyst memory representation // TODO(tjh) support later basic structures for sparse types? case class SparkTFColInfo( shape: Shape, - dataType: ScalarType) extends Serializable { - -// // Forces a cast to a numeric type, which may fail. -// // TODO: try to use an atomic type instead? -// def numericType: NumericType = dataType match { -// case x: NumericType => x -// case _ => throw new Exception(s"$dataType cannot be cast to a numeric type") -// } -} + dataType: NumericType) extends Serializable /** * Exception thrown when the user requests tensors of high order. diff --git a/src/main/scala/org/tensorframes/dsl/DslImpl.scala b/src/main/scala/org/tensorframes/dsl/DslImpl.scala index 6e82c0e..3795c9f 100644 --- a/src/main/scala/org/tensorframes/dsl/DslImpl.scala +++ b/src/main/scala/org/tensorframes/dsl/DslImpl.scala @@ -1,12 +1,13 @@ package org.tensorframes.dsl import javax.annotation.Nullable - import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto} + import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.NumericType -import org.tensorframes.{ColumnInformation, Logging, Shape} -import org.tensorframes.impl.{DenseTensor, SupportedOperations} + +import org.tensorframes.{Logging, ColumnInformation, Shape} +import org.tensorframes.impl.DenseTensor /** @@ -74,9 +75,8 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) - val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, dtype = dt2, + shape = dt.shape, dtype = dt.dtype, extraAttrs = Map("value" -> a.build())) } @@ -100,8 +100,7 @@ private[dsl] object DslImpl extends Logging with DefaultConversions { s"tensorframes: $schema") } val shape = if (block) { stf.shape } else { stf.shape.tail } - val dt = SupportedOperations.opsFor(stf.dataType).sqlType.asInstanceOf[NumericType] - DslImpl.placeholder(dt, shape).named(tfName) + DslImpl.placeholder(stf.dataType, shape).named(tfName) } private def commonShape(shapes: Seq[Shape]): Shape = { diff --git a/src/main/scala/org/tensorframes/dsl/package.scala b/src/main/scala/org/tensorframes/dsl/package.scala index 2a787d6..adfa39a 100644 --- a/src/main/scala/org/tensorframes/dsl/package.scala +++ b/src/main/scala/org/tensorframes/dsl/package.scala @@ -1,8 +1,10 @@ package org.tensorframes import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{IntegerType, NumericType} +import org.apache.spark.sql.types.IntegerType + import org.tensorframes.impl.SupportedOperations /** @@ -43,7 +45,7 @@ package object dsl { def placeholder[T : Numeric : TypeTag](shape: Int*): Operation = { val ops = SupportedOperations.getOps[T]() - DslImpl.placeholder(ops.sqlType.asInstanceOf[NumericType], Shape(shape: _*)) + DslImpl.placeholder(ops.sqlType, Shape(shape: _*)) } def constant[T : ConvertibleToDenseTensor](x: T): Operation = { diff --git a/src/main/scala/org/tensorframes/impl/DataOps.scala b/src/main/scala/org/tensorframes/impl/DataOps.scala index a1e7d8d..6e60f1a 100644 --- a/src/main/scala/org/tensorframes/impl/DataOps.scala +++ 
b/src/main/scala/org/tensorframes/impl/DataOps.scala @@ -2,11 +2,10 @@ package org.tensorframes.impl import scala.collection.mutable import scala.reflect.ClassTag -import org.{tensorflow => tf} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{StructType} +import org.apache.spark.sql.types.{NumericType, StructType} import org.tensorframes.{Logging, Shape} import org.tensorframes.Shape.DimType @@ -146,7 +145,7 @@ object DataOps extends Logging { def getColumnFast0( reshapeShape: Shape, - scalaType: ScalarType, + scalaType: NumericType, allDataBuffer: mutable.WrappedArray[_]): Iterable[Any] = { reshapeShape.dims match { case Seq() => diff --git a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala index db413fc..e9dbd32 100644 --- a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala +++ b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala @@ -267,24 +267,6 @@ private[impl] trait SchemaTransforms extends Logging { case _ => f // Nothing to do } } - - /** - * Checks that the data, coming with a certain shape from Spark, can be shaped into - * the given shape (taking unknown values and binary data into account) - */ - def canBeReshapedTo(dt: ScalarType, from: Shape, to: Shape): Boolean = { -// if (dt == ScalarBinaryType) { -// // Binary has to be at least an array. -// if (from.numDims == 0) { -// return false -// } -// // In that case, the spark shape should be one larger, and we should drop the bottom one. -// from.dropInner.checkMorePreciseThan(to) -// } else { -// from.checkMorePreciseThan(to) -// } - from.checkMorePreciseThan(to) - } } object SchemaTransforms extends SchemaTransforms @@ -340,17 +322,17 @@ class DebugRowOps throw new Exception( s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") } + if (! stf.shape.checkMorePreciseThan(in.shape)) { + throw new Exception( + s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + + s" ${in.shape} requested by the TF graph") + } // We do not support autocasting for now. if (stf.dataType != in.scalarType) { throw new Exception( s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + s"of the column (${in.scalarType})") } - if (! canBeReshapedTo(stf.dataType, stf.shape, in.shape)) { - throw new Exception( - s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + - s" ${in.shape} requested by the TF graph") - } // The input has to be either a constant or a placeholder if (! in.isPlaceholder) { throw new Exception( @@ -432,16 +414,16 @@ class DebugRowOps val stf = get(ColumnInformation(f).stf, s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") - check(stf.dataType == in.scalarType, - s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + - s"of the column (${in.scalarType})") - val cellShape = stf.shape.tail // No check for unknowns: we allow unknowns in the first dimension of the cell shape. 
- check(canBeReshapedTo(stf.dataType, cellShape, in.shape), + check(cellShape.checkMorePreciseThan(in.shape), s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + s" ${in.shape} requested by the TF graph") + check(stf.dataType == in.scalarType, + s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + + s"of the column (${in.scalarType})") + check(in.isPlaceholder, s"Invalid type for input node ${in.name}. It has to be a placeholder") } @@ -550,8 +532,7 @@ class DebugRowOps val f = col.field builder.append(s"$prefix-- ${f.name}: ${f.dataType.typeName} (nullable = ${f.nullable})") val stf = col.stf.map { s => - val dt = SupportedOperations.opsFor(s.dataType).sqlType - s" ${dt.typeName}${s.shape}" + s" ${s.dataType.typeName}${s.shape}" } .getOrElse(" ") builder.append(stf) builder.append("\n") diff --git a/src/main/scala/org/tensorframes/impl/DenseTensor.scala b/src/main/scala/org/tensorframes/impl/DenseTensor.scala index d9e30e2..7414f73 100644 --- a/src/main/scala/org/tensorframes/impl/DenseTensor.scala +++ b/src/main/scala/org/tensorframes/impl/DenseTensor.scala @@ -17,31 +17,27 @@ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, NumericTy */ private[tensorframes] class DenseTensor private( val shape: Shape, - val dtype: ScalarType, + val dtype: NumericType, private val data: Array[Byte]) { override def toString(): String = s"DenseTensor($shape, $dtype, " + - s"${data.length} bytes)" + s"${data.length / dtype.defaultSize} elements)" } private[tensorframes] object DenseTensor { def apply[T](x: T)(implicit ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - apply(Shape.empty, ops.sqlType.asInstanceOf[NumericType], convert(x)) + new DenseTensor(Shape.empty, ops.sqlType, convert(x)) } def apply[T](xs: Seq[T])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - apply(Shape(xs.size), ops.sqlType.asInstanceOf[NumericType], convert1(xs)) - } - - def apply(shape: Shape, dtype: NumericType, data: Array[Byte]): DenseTensor = { - new DenseTensor(shape, SupportedOperations.opsFor(dtype).scalarType, data) + new DenseTensor(Shape(xs.size), ops.sqlType, convert1(xs)) } def matrix[T](xs: Seq[Seq[T]])(implicit ev1: Numeric[T], ev2: TypeTag[T]): DenseTensor = { val ops = SupportedOperations.getOps[T]() - apply(Shape(xs.size, xs.head.size), ops.sqlType.asInstanceOf[NumericType], convert2(xs)) + new DenseTensor(Shape(xs.size, xs.head.size), ops.sqlType, convert2(xs)) } private def convert[T](x: T)(implicit ev2: TypeTag[T]): Array[Byte] = { @@ -102,15 +98,15 @@ private[tensorframes] object DenseTensor { val shape = Shape.from(proto.getTensorShape) val data = ops.sqlType match { case DoubleType => - val coll = proto.getDoubleValList.asScala.map(_.doubleValue()) + val coll = proto.getDoubleValList.asScala.toSeq.map(_.doubleValue()) convert(coll) case IntegerType => - val coll = proto.getIntValList.asScala.map(_.intValue()) + val coll = proto.getIntValList.asScala.toSeq.map(_.intValue()) convert(coll) case _ => throw new IllegalArgumentException( s"Cannot convert type ${ops.sqlType}") } - new DenseTensor(shape, ops.scalarType, data) + new DenseTensor(shape, ops.sqlType, data) } } diff --git a/src/main/scala/org/tensorframes/impl/TFDataOps.scala b/src/main/scala/org/tensorframes/impl/TFDataOps.scala index f3c4c6b..5ed609e 100644 --- a/src/main/scala/org/tensorframes/impl/TFDataOps.scala +++ b/src/main/scala/org/tensorframes/impl/TFDataOps.scala @@ 
-184,7 +184,7 @@ object TFDataOps extends Logging { */ private def getColumn( t: tf.Tensor, - scalaType: ScalarType, + scalaType: NumericType, cellShape: Shape, expectedNumRows: Option[Int], fastPath: Boolean = true): (Int, Iterable[Any]) = { diff --git a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala index f0aca5c..449db0e 100644 --- a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala +++ b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala @@ -140,10 +140,10 @@ object TensorFlowOps extends Logging { } } - private def getSummaryDefault(op: tf.Operation): Seq[(ScalarType, Shape)] = { + private def getSummaryDefault(op: tf.Operation): Seq[(NumericType, Shape)] = { (0 until op.numOutputs()).map { idx => val n = op.output(idx) - val dt = SupportedOperations.opsFor(n.dataType()).scalarType + val dt = SupportedOperations.opsFor(n.dataType()).sqlType val shape = Shape.from(n.shape()) dt -> shape } @@ -164,6 +164,6 @@ case class GraphNodeSummary( isPlaceholder: Boolean, isInput: Boolean, isOutput: Boolean, - scalarType: ScalarType, + scalarType: NumericType, shape: Shape, name: String) extends Serializable diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 9b976c0..9be2bfa 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -5,7 +5,7 @@ import java.nio._ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.{tensorflow => tf} -import org.tensorflow.framework.{DataType => ProtoDataType} +import org.tensorflow.framework.DataType import org.tensorframes.{Logging, Shape} import scala.collection.mutable.{WrappedArray => MWrappedArray} @@ -18,39 +18,6 @@ import scala.reflect.runtime.universe.TypeTag // - jvm: ??? // - protobuf: ??? -/** - * All the types of scalars supported by TensorFrames. - * - * It can be argued that the Binary type is not really a scalar, - * but it is considered as such by both Spark and TensorFlow. - */ -trait ScalarType - -/** - * Int32 - */ -case object ScalarIntType extends ScalarType - -/** - * INT64 - */ -case object ScalarLongType extends ScalarType - -/** - * FLOAT64 - */ -case object ScalarDoubleType extends ScalarType - -/** - * FLOAT32 - */ -case object ScalarFloatType extends ScalarType - -/** - * STRING / BINARY - */ -case object ScalarBinaryType extends ScalarType - /** * @param shape the shape of the element in the row (not the overall shape of the block) * @param numCells the number of cells that are going to be allocated with the given shape. @@ -60,7 +27,7 @@ case object ScalarBinaryType extends ScalarType private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, Float, Int, Long) T] ( val shape: Shape, val numCells: Int) - (implicit ev2: ClassTag[T]) extends Logging { + (implicit ev1: TypeTag[T], ev2: ClassTag[T]) extends Logging { final val empty = Array.empty[T] /** * Creates memory space for a given number of units of the given shape. @@ -112,7 +79,6 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, // The return element is just here so that the method gets specialized (otherwise it would not). 
final def append(row: Row, position: Int): Array[T] = { - logger.debug(s"append: position=$position row=$row") val d = shape.numDims if (d == 0) { appendRaw(row.getAs[T](position)) @@ -159,16 +125,17 @@ private[tensorframes] sealed abstract class TensorConverter[@specialized(Double, * It does not support TF's rich type collection (uint16, float128, etc.). These have to be handled * internally through casting. */ -private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] { +private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int, Long, Double, Float) T] + (implicit ev1: TypeTag[T], ev2: ClassTag[T]) { /** * The SQL type associated with the given type. */ - val sqlType: DataType + val sqlType: NumericType /** * The TF type */ - val tfType: ProtoDataType + val tfType: DataType /** * The TF type (new style). @@ -176,11 +143,6 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int */ val tfType2: tf.DataType - /** - * The type of the scalar value. - */ - val scalarType: ScalarType - /** * A zero element for this type */ @@ -255,42 +217,25 @@ private[tensorframes] sealed abstract class ScalarTypeOperation[@specialized(Int res.map { arr => conv(arr.map(conv)) } } - implicit def classTag: ClassTag[T] = ev - - def tag: Option[TypeTag[_]] - - def ev: ClassTag[T] = null + def tag: TypeTag[_] = implicitly[TypeTag[T]] } private[tensorframes] object SupportedOperations { private val ops: Seq[ScalarTypeOperation[_]] = - Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations, StringOperations) + Seq(DoubleOperations, FloatOperations, IntOperations, LongOperations) val sqlTypes = ops.map(_.sqlType) - val scalarTypes = ops.map(_.scalarType) - private val tfTypes = ops.map(_.tfType) - def getOps(t: DataType): Option[ScalarTypeOperation[_]] = { - ops.find(_.sqlType == t) - } - - def opsFor(t: DataType): ScalarTypeOperation[_] = { + def opsFor(t: NumericType): ScalarTypeOperation[_] = { ops.find(_.sqlType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${sqlTypes.mkString(", ")}") } } - def opsFor(t: ScalarType): ScalarTypeOperation[_] = { - ops.find(_.scalarType == t).getOrElse { - throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + - s"supported: ${sqlTypes.mkString(", ")}") - } - } - - def opsFor(t: ProtoDataType): ScalarTypeOperation[_] = { + def opsFor(t: DataType): ScalarTypeOperation[_] = { ops.find(_.tfType == t).getOrElse { throw new IllegalArgumentException(s"Type $t is not supported. Only the following types are" + s"supported: ${tfTypes.mkString(", ")}") @@ -307,7 +252,7 @@ private[tensorframes] object SupportedOperations { def getOps[T : TypeTag](): ScalarTypeOperation[T] = { val ev: TypeTag[_] = implicitly[TypeTag[T]] - ops.find(_.tag.map(_.tpe =:= ev.tpe) == Some(true)).getOrElse { + ops.find(_.tag.tpe =:= ev.tpe).getOrElse { val tags = ops.map(_.tag.toString()).mkString(", ") throw new IllegalArgumentException(s"Type ${ev} is not supported. 
Only the following types " + s"are supported: ${tags}") @@ -354,9 +299,8 @@ private[impl] class DoubleTensorConverter(s: Shape, numCells: Int) private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with Logging { override val sqlType = DoubleType - override val tfType = ProtoDataType.DT_DOUBLE + override val tfType = DataType.DT_DOUBLE override val tfType2 = tf.DataType.DOUBLE - override val scalarType = ScalarDoubleType final override val zero = 0.0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Double] = new DoubleTensorConverter(cellShape, numCells) @@ -381,9 +325,6 @@ private[impl] object DoubleOperations extends ScalarTypeOperation[Double] with L res } - override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Double]]) - - override def ev = ClassTag.Double } // ********** FLOAT ************ @@ -417,9 +358,8 @@ private[impl] class FloatTensorConverter(s: Shape, numCells: Int) private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Logging { override val sqlType = FloatType - override val tfType = ProtoDataType.DT_FLOAT + override val tfType = DataType.DT_FLOAT override val tfType2 = tf.DataType.FLOAT - override val scalarType = ScalarFloatType final override val zero = 0.0f override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Float] = new FloatTensorConverter(cellShape, numCells) @@ -441,10 +381,6 @@ private[impl] object FloatOperations extends ScalarTypeOperation[Float] with Log t.writeTo(b) res } - - override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Float]]) - - override def ev = ClassTag.Float } // ********** INT32 ************ @@ -478,9 +414,8 @@ private[impl] class IntTensorConverter(s: Shape, numCells: Int) private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging { override val sqlType = IntegerType - override val tfType = ProtoDataType.DT_INT32 + override val tfType = DataType.DT_INT32 override val tfType2 = tf.DataType.INT32 - override val scalarType = ScalarIntType final override val zero = 0 override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Int] = new IntTensorConverter(cellShape, numCells) @@ -499,10 +434,6 @@ private[impl] object IntOperations extends ScalarTypeOperation[Int] with Logging dbuff.get(res) res } - - override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Int]]) - - override def ev = ClassTag.Int } // ****** INT64 (LONG) ****** @@ -536,9 +467,8 @@ private[impl] class LongTensorConverter(s: Shape, numCells: Int) private[impl] object LongOperations extends ScalarTypeOperation[Long] with Logging { override val sqlType = LongType - override val tfType = ProtoDataType.DT_INT64 + override val tfType = DataType.DT_INT64 override val tfType2 = tf.DataType.INT64 - override val scalarType = ScalarLongType final override val zero = 0L override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Long] = new LongTensorConverter(cellShape, numCells) @@ -558,69 +488,4 @@ private[impl] object LongOperations extends ScalarTypeOperation[Long] with Loggi logTrace(s"Extracted from buffer: ${res.toSeq}") res } - - override def tag: Option[TypeTag[_]] = Option(implicitly[TypeTag[Long]]) - - override def ev = ClassTag.Long -} - -// ********** STRING ********* -// This is actually byte arrays, which corresponds to the 'binary' type in Spark. - -// The string converter can only deal with one row at a time (the most common case). 
-private[impl] class StringTensorConverter(s: Shape, numCells: Int) - extends TensorConverter[Array[Byte]](s, numCells) with Logging { - private var buffer: Array[Byte] = null - - override val elementSize: Int = 1 - - { - logger.debug(s"Creating string buffer for shape $s and $numCells cells") - assert(s == Shape() && numCells == 1, s"The string buffer does not accept more than one" + - s" scalar of type binary. shape=$s numCells=$numCells") - } - - - override def reserve(): Unit = {} - - override def appendRaw(d: Array[Byte]): Unit = { - assert(buffer == null, s"The buffer has only been set with ${buffer.length} values," + - s" but ${d.length} are trying to get inserted") - buffer = d.clone() - } - - override def tensor2(): tf.Tensor = { - tf.Tensor.create(buffer) - } - - override def fillBuffer(buff: ByteBuffer): Unit = { - buff.put(buffer) - } -} - -private[impl] object StringOperations extends ScalarTypeOperation[Array[Byte]] with Logging { - override val sqlType = BinaryType - override val tfType = ProtoDataType.DT_STRING - override val tfType2 = tf.DataType.STRING - override val scalarType = ScalarBinaryType - final override val zero = Array.empty[Byte] - - override def tfConverter(cellShape: Shape, numCells: Int): TensorConverter[Array[Byte]] = - new StringTensorConverter(cellShape, numCells) - - override def convertTensor(t: tf.Tensor): MWrappedArray[Array[Byte]] = { - // TODO(tjh) implement later - ??? - } - - override def convertBuffer(buff: ByteBuffer, numElements: Int): Iterable[Any] = { - // TODO(tjh) implement later - ??? - } - - override def tag: Option[TypeTag[_]] = None - - override def ev = ??? -} - - +} \ No newline at end of file diff --git a/src/main/scala/org/tensorframes/test/dsl.scala b/src/main/scala/org/tensorframes/test/dsl.scala index fbe4f7a..1d14503 100644 --- a/src/main/scala/org/tensorframes/test/dsl.scala +++ b/src/main/scala/org/tensorframes/test/dsl.scala @@ -2,10 +2,10 @@ package org.tensorframes.test import java.nio.file.{Files, Paths} -import org.apache.spark.sql.types.{DataType, NumericType} -import org.tensorflow.framework.{AttrValue, GraphDef, NodeDef, TensorShapeProto, DataType => ProtoDataType} +import org.apache.spark.sql.types.{DoubleType, NumericType} +import org.tensorflow.framework._ import org.tensorframes.{Logging, Shape} -import org.tensorframes.impl.{DenseTensor, ScalarType, SupportedOperations} +import org.tensorframes.impl.{DenseTensor, SupportedOperations} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe._ @@ -25,7 +25,7 @@ object dsl extends Logging { def toAttr: AttrValue = buildType(s) } - private implicit class DataTypeToAttr(dt: ProtoDataType) { + private implicit class DataTypeToAttr(dt: DataType) { def toAttr: AttrValue = dataTypeToAttrValue(dt) } @@ -66,8 +66,8 @@ object dsl extends Logging { def +(other: Node): Node = op_add(this, other) } - private[tensorframes] def placeholder(dtype: DataType, shape: Shape): Node = { - build("Placeholder", shape=shape, dtype=dtype.asInstanceOf[NumericType], isOp = false, + private[tensorframes] def placeholder(dtype: NumericType, shape: Shape): Node = { + build("Placeholder", shape=shape, dtype=dtype, isOp = false, extraAttrs = Map("shape" -> shape.toAttr)) } @@ -165,9 +165,8 @@ object dsl extends Logging { private def build_constant(dt: DenseTensor): Node = { val a = AttrValue.newBuilder().setTensor(DenseTensor.toTensorProto(dt)) - val dt2 = SupportedOperations.opsFor(dt.dtype).sqlType.asInstanceOf[NumericType] build("Const", isOp = false, - shape = dt.shape, 
dtype = dt2, + shape = dt.shape, dtype = dt.dtype, extraAttrs = Map("value" -> a.build())) } @@ -197,7 +196,7 @@ object dsl extends Logging { dtype = parent.scalarType, shape = reduce_shape(parent.shape, Option(reduction_indices).getOrElse(Nil)), extraAttrs = Map( - "Tidx" -> AttrValue.newBuilder().setType(ProtoDataType.DT_INT32).build(), + "Tidx" -> AttrValue.newBuilder().setType(DataType.DT_INT32).build(), "keep_dims" -> AttrValue.newBuilder().setB(false).build())) } @@ -219,16 +218,13 @@ object dsl extends Logging { * Utilities to convert data back and forth between the proto descriptions and the dataframe descriptions. */ object ProtoConversions { - def getDType(nodeDef: NodeDef): ProtoDataType = { + def getDType(nodeDef: NodeDef): DataType = { val opt = Option(nodeDef.getAttr.get("T")).orElse(Option(nodeDef.getAttr.get("dtype"))) val v = opt.getOrElse(throw new Exception(s"Neither 'T' no 'dtype' was found in $nodeDef")) v.getType } - def getDType(sqlType: NumericType): ProtoDataType = { - SupportedOperations.opsFor(sqlType).tfType - } - def getDType(sqlType: ScalarType): ProtoDataType = { + def getDType(sqlType: NumericType): DataType = { SupportedOperations.opsFor(sqlType).tfType } @@ -236,7 +232,7 @@ object ProtoConversions { AttrValue.newBuilder().setType(getDType(sqlType)).build() } - def dataTypeToAttrValue(dataType: ProtoDataType): AttrValue = { + def dataTypeToAttrValue(dataType: DataType): AttrValue = { AttrValue.newBuilder().setType(dataType).build() } From 2d0b49773aaf4928ea36fa42ae0effc94f9df474 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 20 Apr 2017 14:07:32 -0700 Subject: [PATCH 6/7] changes --- src/main/python/tensorframes/core.py | 5 ++--- src/main/python/tensorframes/core_test.py | 15 +-------------- .../org/tensorframes/DebugRowOpsSuite.scala | 10 +++++----- .../tensorframes/ExtraOperationsSuite.scala | 19 +++++++++---------- .../perf/ConvertBackPerformanceSuite.scala | 6 +++--- .../perf/ConvertPerformanceSuite.scala | 4 ++-- 6 files changed, 22 insertions(+), 37 deletions(-) diff --git a/src/main/python/tensorframes/core.py b/src/main/python/tensorframes/core.py index 17d70c4..c148551 100644 --- a/src/main/python/tensorframes/core.py +++ b/src/main/python/tensorframes/core.py @@ -43,8 +43,7 @@ def _add_graph(graph, builder, use_file=True): fname = d + "/proto.pb" builder.graphFromFile(fname) else: - # Make sure that TF adds the shapes. 
- gser = graph.as_graph_def(add_shapes=True).SerializeToString() + gser = graph.as_graph_def().SerializeToString() gbytes = bytearray(gser) builder.graph(gbytes) @@ -56,7 +55,7 @@ def _add_shapes(graph, builder, fetches): # dimensions are unknown ph_names = [] ph_shapes = [] - for n in graph.as_graph_def(add_shapes=True).node: + for n in graph.as_graph_def().node: # Just the input nodes: if not n.input: op_name = n.name diff --git a/src/main/python/tensorframes/core_test.py b/src/main/python/tensorframes/core_test.py index cb6e8a3..419ebf2 100644 --- a/src/main/python/tensorframes/core_test.py +++ b/src/main/python/tensorframes/core_test.py @@ -14,9 +14,7 @@ class TestCore(object): @classmethod def setup_class(cls): print("setup ", cls) - sc = SparkContext('local[1]', cls.__name__) - sc.setLogLevel('DEBUG') - cls.sc = sc + cls.sc = SparkContext('local[1]', cls.__name__) @classmethod def teardown_class(cls): @@ -27,7 +25,6 @@ def setUp(self): self.sql = SQLContext(TestCore.sc) self.api = _java_api() self.api.initialize_logging() - TestCore.sc.setLogLevel('INFO') print("setup") @@ -129,16 +126,6 @@ def test_groupby_1(self): data2 = df2.collect() assert data2 == [Row(key='0', x=2.0), Row(key='1', x=4.0)], data2 - def test_byte_array(self): - data = [Row(x=bytearray('123', 'utf-8'))] - df = self.sql.createDataFrame(data) - with tf.Graph().as_default(): - x = tf.placeholder(tf.string, shape=[], name="x") - z = tf.string_to_number(x, tf.int32, name='z') - df2 = tfs.map_rows(z, df) - data2 = df2.collect() - assert data2[0].z == 123, data2 - if __name__ == "__main__": # Some testing stuff that should not be executed diff --git a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala index cd7d50c..d92e58e 100644 --- a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala +++ b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala @@ -3,7 +3,7 @@ package org.tensorframes import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StructType} import org.scalatest.FunSuite -import org.tensorframes.impl.{DebugRowOpsImpl, ScalarDoubleType} +import org.tensorframes.impl.DebugRowOpsImpl import org.tensorframes.dsl._ class DebugRowOpsSuite @@ -14,10 +14,10 @@ class DebugRowOpsSuite testGraph("Simple identity") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = identity(p2) named "y" - val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) @@ -26,10 +26,10 @@ class DebugRowOpsSuite testGraph("Simple add") { val rows = Array(Row(1.0)) - val input = StructType(Array(structField("x", ScalarDoubleType, Shape(Unknown)))) + val input = StructType(Array(structField("x", DoubleType, Shape(Unknown)))) val p2 = placeholder[Double](1) named "x" val out = p2 + p2 named "y" - val outputSchema = StructType(Array(structField("y", ScalarDoubleType, Shape(Unknown)))) + val outputSchema = StructType(Array(structField("y", DoubleType, Shape(Unknown)))) val (g, _) = TestUtilities.analyzeGraph(out) logDebug(g.toString) val res = DebugRowOpsImpl.performMap(rows, input, Array(0), g, outputSchema) diff --git 
a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala index b197df7..2a0a1a0 100644 --- a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala @@ -2,7 +2,6 @@ package org.tensorframes import org.apache.spark.sql.types.{DoubleType, IntegerType} import org.scalatest.FunSuite -import org.tensorframes.impl.{ScalarDoubleType, ScalarIntType} class ExtraOperationsSuite @@ -17,7 +16,7 @@ class ExtraOperationsSuite val di = ExtraOperations.explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === ScalarDoubleType) + assert(s.dataType === DoubleType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + "->" + di.toString) } @@ -27,7 +26,7 @@ class ExtraOperationsSuite val di = explainDetailed(df) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === ScalarIntType) + assert(s.dataType === IntegerType) assert(s.shape === Shape(Unknown)) logDebug(df.toString() + "->" + di.toString) } @@ -38,13 +37,13 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2, c3) = di.cols val Some(s1) = c1.stf - assert(s1.dataType === ScalarDoubleType) + assert(s1.dataType === DoubleType) assert(s1.shape === Shape(Unknown)) val Some(s2) = c2.stf - assert(s2.dataType === ScalarDoubleType) + assert(s2.dataType === DoubleType) assert(s2.shape === Shape(Unknown, Unknown)) val Some(s3) = c3.stf - assert(s3.dataType === ScalarDoubleType) + assert(s3.dataType === DoubleType) assert(s3.shape === Shape(Unknown, Unknown, Unknown)) } @@ -55,7 +54,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === ScalarDoubleType) + assert(s.dataType === DoubleType) assert(s.shape === Shape(1)) // There is only one partition } @@ -66,7 +65,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1) = di.cols val Some(s) = c1.stf - assert(s.dataType === ScalarDoubleType) + assert(s.dataType === DoubleType) assert(s.shape === Shape(Unknown)) // There is only one partition } @@ -79,7 +78,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === ScalarDoubleType) + assert(s2.dataType === DoubleType) assert(s2.shape === Shape(2, Unknown)) // There is only one partition } @@ -93,7 +92,7 @@ class ExtraOperationsSuite logDebug(df.toString() + "->" + di.toString) val Seq(c1, c2) = di.cols val Some(s2) = c2.stf - assert(s2.dataType === ScalarDoubleType) + assert(s2.dataType === DoubleType) assert(s2.shape === Shape(3, 2)) // There is only one partition } } diff --git a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala index 3624e72..a1680f9 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala @@ -2,7 +2,7 @@ package org.tensorframes.perf import org.scalatest.FunSuite import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{ScalarIntType, SupportedOperations, TFDataOps} +import org.tensorframes.impl.{SupportedOperations, TFDataOps} import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -48,9 +48,9 @@ class 
ConvertBackPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, + val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, Shape(numCells, numVals)))) - val tfSchema = StructType(Seq(ColumnInformation.structField("f2", ScalarIntType, + val tfSchema = StructType(Seq(ColumnInformation.structField("f2", IntegerType, Shape(numCells, numVals)))) val tensor = getTFTensor(IntegerType, Row(Seq.fill(numVals)(1)), Shape(numVals), numCells) println("generated data") diff --git a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala index b3112e8..556c8f1 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala @@ -5,7 +5,7 @@ import org.tensorframes.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} -import org.tensorframes.impl.{DataOps, ScalarIntType, SupportedOperations, TFDataOps} +import org.tensorframes.impl.{DataOps, SupportedOperations, TFDataOps} class ConvertPerformanceSuite extends FunSuite with TensorFramesTestSparkContext with Logging { @@ -44,7 +44,7 @@ class ConvertPerformanceSuite // Creating the rows this way, because we need to respect the collection used by Spark when // unpacking the rows. val rows = sqlContext.createDataFrame(Seq.fill(numCells)(Tuple1(Seq.fill(numVals)(1)))).collect() - val schema = StructType(Seq(ColumnInformation.structField("f1", ScalarIntType, + val schema = StructType(Seq(ColumnInformation.structField("f1", IntegerType, Shape(numCells, numVals)))) println("generated data") logInfo("generated data") From 7ab76b85a5b695b52623a6ee7c386fd396ab1c1b Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Mon, 24 Apr 2017 10:09:42 -0700 Subject: [PATCH 7/7] small change --- .../python/tensorframes_snippets/preparation_inceptionv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/python/tensorframes_snippets/preparation_inceptionv3.py b/src/main/python/tensorframes_snippets/preparation_inceptionv3.py index 0276373..77f3eb9 100644 --- a/src/main/python/tensorframes_snippets/preparation_inceptionv3.py +++ b/src/main/python/tensorframes_snippets/preparation_inceptionv3.py @@ -185,5 +185,5 @@ def get_op_name(tensor): value_output = tf.identity(g2.get_tensor_by_name('top_predictions:0'), name="value") pred_df = tfs.map_rows([index_output, value_output], df, feed_dict={'image_input':'image_data'}) -pred_df.select('index', 'value').head() +pred_df.select('index', 'value').show()
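
The sketches below are not part of the patch series; they only illustrate how the APIs refactored in PATCH 5/7 are exercised, based on the signatures visible in the hunks above. With the ScalarType wrappers removed, tensor metadata is keyed directly on Spark's NumericType objects. A minimal sketch of the simplified ColumnInformation.structField signature, mirroring the updated performance suites; it is placed in the org.tensorframes package because several of these members are package-private, and the column name and shape values are made up:

package org.tensorframes

import org.apache.spark.sql.types.{IntegerType, StructType}

object StructFieldSketch {
  def main(args: Array[String]): Unit = {
    // A block of 100 cells, each cell an int32 vector of length 10, described
    // directly with Spark's IntegerType instead of the removed ScalarIntType.
    val field = ColumnInformation.structField("f1", IntegerType, Shape(100, 10))
    val schema = StructType(Seq(field))

    // The shape and the SQL type are embedded in the field metadata and can be
    // read back as a SparkTFColInfo carrying a plain NumericType.
    println(ColumnInformation(field).stf)
    println(schema.prettyJson)
  }
}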
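
Along the same lines, SupportedOperations in datatypes.scala now keys its lookups on NumericType and on the TensorFlow proto DataType, and the string/binary path is gone. A short sketch of the lookup surface, under the same package-placement assumption; the comments describe expected values rather than verified output:

package org.tensorframes

import org.apache.spark.sql.types.{DoubleType, LongType}
import org.tensorframes.impl.SupportedOperations

object OpsLookupSketch {
  def main(args: Array[String]): Unit = {
    // Look up the operations object for a Spark numeric type.
    val doubleOps = SupportedOperations.opsFor(DoubleType)
    println(doubleOps.sqlType)   // DoubleType
    println(doubleOps.tfType)    // DT_DOUBLE (protobuf enum)
    println(doubleOps.tfType2)   // DOUBLE (runtime org.tensorflow.DataType)

    // INT64 maps to Spark's LongType.
    val longOps = SupportedOperations.opsFor(LongType)
    println(longOps.tfType)      // DT_INT64

    // Note: opsFor now only accepts NumericType, so a BinaryType column
    // (the former string path) is rejected at compile time rather than at lookup.
  }
}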
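
The DenseTensor factories follow the same pattern: the element type is resolved through SupportedOperations.getOps[T] and stored directly as a NumericType. A small sketch, again under the package-placement assumption; the values are made up:

package org.tensorframes

import org.tensorframes.impl.DenseTensor

object DenseTensorSketch {
  def main(args: Array[String]): Unit = {
    // Scalars, vectors and matrices resolve their dtype from the Scala type.
    val scalar = DenseTensor(3.0)
    val vector = DenseTensor(Seq(1, 2, 3))
    val matrix = DenseTensor.matrix(Seq(Seq(1.0, 2.0), Seq(3.0, 4.0)))

    // toString now reports the element count rather than the raw byte length.
    println(scalar)
    println(vector)
    println(matrix)
  }
}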
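
Finally, the ExtraOperationsSuite hunks show the end-to-end effect: explainDetailed now reports a plain Spark type in each column's SparkTFColInfo. A sketch of that flow, assuming a local Spark 2.x session; the data and column name are illustrative:

package org.tensorframes

import org.apache.spark.sql.SparkSession

object ExplainSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("explain-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(1.0, 2.0, 3.0).toDF("x")
    val info = ExtraOperations.explainDetailed(df)

    // A plain double column is reported as DoubleType (formerly ScalarDoubleType)
    // with a single unknown (row) dimension.
    val Some(stf) = info.cols.head.stf
    println(stf.dataType)
    println(stf.shape)

    spark.stop()
  }
}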