diff --git a/examples/fft_benchmark.py b/examples/fft_benchmark.py new file mode 100644 index 00000000..464ab933 --- /dev/null +++ b/examples/fft_benchmark.py @@ -0,0 +1,149 @@ +""" +Benchmark script for studying the scaling of distributed FFTs on Mesh Tensorflow +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import tensorflow.compat.v1 as tf +import mesh_tensorflow as mtf + +from tensorflow.python.tpu import tpu_config # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.tpu import tpu_estimator # pylint: disable=g-direct-tensorflow-import +from tensorflow_estimator.python.estimator import estimator as estimator_lib + +# Cloud TPU Cluster Resolver flags +tf.flags.DEFINE_string( + "tpu", default=None, + help="The Cloud TPU to use for training. This should be either the name " + "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " + "url.") +tf.flags.DEFINE_string( + "tpu_zone", default=None, + help="[Optional] GCE zone where the Cloud TPU is located in. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") +tf.flags.DEFINE_string( + "gcp_project", default=None, + help="[Optional] Project name for the Cloud TPU-enabled project. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + +tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir") + +tf.flags.DEFINE_integer("cube_size", 512, "Size of the 3D volume.") +tf.flags.DEFINE_integer("batch_size", 128, + "Mini-batch size for the training. Note that this " + "is the global batch size and not the per-shard batch.") + +tf.flags.DEFINE_string("mesh_shape", "b1:32", "mesh shape") +tf.flags.DEFINE_string("layout", "nx:b1,tny:b1", "layout rules") + +FLAGS = tf.flags.FLAGS + +def benchmark_model(mesh): + """ + Initializes a 3D volume with random noise, and execute a forward FFT + """ + batch_dim = mtf.Dimension("batch", FLAGS.batch_size) + + # Declares real space dimensions + x_dim = mtf.Dimension("nx", FLAGS.cube_size) + y_dim = mtf.Dimension("ny", FLAGS.cube_size) + z_dim = mtf.Dimension("nz", FLAGS.cube_size) + + # Declares Fourier space dimensions + tx_dim = mtf.Dimension("tnx", FLAGS.cube_size) + ty_dim = mtf.Dimension("tny", FLAGS.cube_size) + tz_dim = mtf.Dimension("tnz", FLAGS.cube_size) + + # Create field + field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim]) + + # Apply FFT + fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64), [tx_dim, ty_dim, tz_dim]) + + # Inverse FFT + rfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32) + + # Compute errors + err = mtf.reduce_max(mtf.abs(field - rfield)) + return err + +def model_fn(features, labels, mode, params): + """A model is called by TpuEstimator.""" + del labels + del features + + mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) + layout_rules = mtf.convert_to_layout_rules(FLAGS.layout) + + ctx = params['context'] + num_hosts = ctx.num_hosts + host_placement_fn = ctx.tpu_host_placement_function + device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)] + tf.logging.info('device_list = %s' % device_list,) + + mesh_devices = [''] * mesh_shape.size + mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( + mesh_shape, layout_rules, mesh_devices, ctx.device_assignment) + + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "fft_mesh") + + with mtf.utils.outside_all_rewrites(): + err = benchmark_model(mesh) + + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + + tf_err = tf.to_float(lowering.export_to_tf_tensor(err)) + + with mtf.utils.outside_all_rewrites(): + return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err) + + +def main(_): + + tf.logging.set_verbosity(tf.logging.INFO) + mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) + + # Resolve the TPU environment + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, + zone=FLAGS.tpu_zone, + project=FLAGS.gcp_project + ) + + run_config = tf.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + model_dir=FLAGS.model_dir, + save_checkpoints_steps=None, # Disable the default saver + save_checkpoints_secs=None, # Disable the default saver + log_step_count_steps=100, + save_summary_steps=100, + tpu_config=tpu_config.TPUConfig( + num_shards=mesh_shape.size, + iterations_per_loop=100, + num_cores_per_replica=1, + per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)) + + model = tpu_estimator.TPUEstimator( + use_tpu=True, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.batch_size, + eval_batch_size=FLAGS.batch_size) + + def dummy_input_fn(params): + """Dummy input function """ + return tf.zeros(shape=[params['batch_size']], dtype=tf.float32), tf.zeros(shape=[params['batch_size']], dtype=tf.float32) + + # Run evaluate loop for ever, we will be connecting to this process using a profiler + model.evaluate(input_fn=dummy_input_fn, steps=100000) + +if __name__ == "__main__": + tf.disable_v2_behavior() + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/mesh_tensorflow/ops.py b/mesh_tensorflow/ops.py index 4b21e569..af81912f 100644 --- a/mesh_tensorflow/ops.py +++ b/mesh_tensorflow/ops.py @@ -1594,8 +1594,8 @@ def to_string(self): @property def has_gradient(self): return ( - [t for t in self.inputs if t.dtype.is_floating] and - [t for t in self.outputs if t.dtype.is_floating]) + [t for t in self.inputs if t.dtype.is_floating or t.dtype.is_complex] and + [t for t in self.outputs if t.dtype.is_floating or t.dtype.is_complex]) def gradient(self, unused_grad_ys): raise NotImplementedError("Gradient not implemented") @@ -5815,7 +5815,7 @@ def random_uniform(mesh, shape, **kwargs): def random_normal(mesh, shape, **kwargs): - """Random uniform. + """Random normal. Args: mesh: a Mesh @@ -6674,3 +6674,56 @@ def reduce_first(tensor, reduced_dim): r = mtf_range(tensor.mesh, reduced_dim, dtype=tf.int32) first_element_filter = cast(equal(r, 0), tensor.dtype) return reduce_sum(tensor * first_element_filter, reduced_dim=reduced_dim) + + +def to_complex(x, complex_dim=None): + """Gathers the real and imaginary of a tensor in a complex tensor + + Args: + x: a float Tensor + complex_dim: a Dimension where both the real and imaginary parts of the + tensor are. Defaults to None, which corresponds to the last + dimension of the tensor. + Returns: + a Tensor, complex-valued + """ + if complex_dim is None: + complex_dim = x.shape[-1] + x_real, x_imag = split(x, complex_dim, 2) + x_real = cast(x_real, tf.complex64) + x_imag = cast(x_imag, tf.complex64) + x_complex = x_real + 1j * x_imag + return x_complex + + +def split_complex(x, complex_dim=None): + """Splits a complex tensor into real and imaginary, concatenated + + Args: + x: a float Tensor + complex_dim: a Dimension where you want the split to happen. + Defaults to None, which corresponds to the last dimension of the tensor. + Returns: + a Tensor, float-valued + """ + if complex_dim is None: + split_dim = x.shape.dims[-1] + split_axis = -1 + else: + split_dim = complex_dim + split_axis = x.shape.index(complex_dim) + splittable_dims = [d for d in x.shape if d != split_dim] + def tf_fn(tf_input): + tf_real = tf.math.real(tf_input) + tf_imag = tf.math.imag(tf_input) + output = tf.concat([tf_real, tf_imag], axis=split_axis) + return output + output = slicewise( + tf_fn, + [x], + output_shape=x.shape.resize_dimension(split_dim.name, split_dim.size*2), + output_dtype=tf.float32, + splittable_dims=splittable_dims, + name='split_complex', + ) + return output diff --git a/mesh_tensorflow/ops_test.py b/mesh_tensorflow/ops_test.py index 14355f5e..11e34871 100644 --- a/mesh_tensorflow/ops_test.py +++ b/mesh_tensorflow/ops_test.py @@ -671,6 +671,51 @@ def x_squared_plus_x(x): self.evaluate(expected_dx)) +class ComplexManipulationTest(tf.test.TestCase): + def setUp(self): + super(ComplexManipulationTest, self).setUp() + self.graph = mtf.Graph() + self.mesh = mtf.Mesh(self.graph, "my_mesh") + + def testToComplex(self): + tensor = tf.random.normal([1, 10, 4]) + mtf_shape = [mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)] + tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape) + outputs = mtf.to_complex(tensor_mesh) + assert outputs.dtype == tf.complex64 + assert len(outputs.shape) == 3 + assert outputs.shape[-1].size == 2 + assert [s.size for s in outputs.shape[:-1]] == [s.size for s in tensor_mesh.shape[:-1]] + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + self.assertAllEqual( + outputs_tf, + tf.complex(tensor[..., 0:2], tensor[..., 2:4]), + ) + + def testSplitComplex(self): + tensor = tf.complex( + tf.random.normal([1, 10, 2]), + tf.random.normal([1, 10, 2]), + ) + mtf_shape = [mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)] + tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape) + outputs = mtf.split_complex(tensor_mesh) + assert outputs.dtype == tf.float32 + assert len(outputs.shape) == 3 + assert outputs.shape[-1].size == 4 + assert [s.size for s in outputs.shape[:-1]] == [s.size for s in tensor_mesh.shape[:-1]] + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + self.assertAllEqual( + outputs_tf, + tf.concat([tf.math.real(tensor), tf.math.imag(tensor)], axis=-1), + ) + if __name__ == "__main__": tf.disable_v2_behavior() tf.enable_eager_execution() diff --git a/mesh_tensorflow/ops_with_redefined_builtins.py b/mesh_tensorflow/ops_with_redefined_builtins.py index 50848c6e..9b5ef108 100644 --- a/mesh_tensorflow/ops_with_redefined_builtins.py +++ b/mesh_tensorflow/ops_with_redefined_builtins.py @@ -24,7 +24,7 @@ from mesh_tensorflow.ops import mtf_pow as pow # pylint: disable=redefined-builtin,unused-import from mesh_tensorflow.ops import mtf_range as range # pylint: disable=redefined-builtin,unused-import from mesh_tensorflow.ops import mtf_slice as slice # pylint: disable=redefined-builtin,unused-import - +import mesh_tensorflow.signal_ops as signal # TODO(trandustin): Seal module. diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py new file mode 100644 index 00000000..27fdc5f3 --- /dev/null +++ b/mesh_tensorflow/signal_ops.py @@ -0,0 +1,122 @@ +"""Spectral ops for Mesh TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import tensorflow.compat.v1 as tf + +from mesh_tensorflow import ops_with_redefined_builtins as mtf + + +class FFT3DBaseOperation(mtf.Operation): + """ Base class for performing distributed FFTs. + + Handles slicewise ffts and array transpositions. Note that to save one global + transposition at the end of forward and inverse FFTs, these operations + assume a transposed fourier space with shape: + input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0] + """ + def __init__(self, inputs, dims, inverse=False, name=None): + self.inverse = inverse + if self.inverse: + self.default_name = 'IFFT3D' + self.tf_op = tf.spectral.ifft + else: + self.default_name = 'FFT3D' + self.tf_op = tf.spectral.fft + super(FFT3DBaseOperation, self).__init__([inputs], name=name or self.default_name) + self._dims = dims + if self.inverse: + dims_reordered = dims + else: + dims_reordered = [dims[1], dims[2], dims[0]] + self._output_shape = mtf.Shape(inputs.shape[:-3]+dims_reordered) + self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), inputs.dtype)] + + def gradient(self, grad_ys): + dy = grad_ys[0] + if self.inverse: + ky, kz, kx = self.inputs[0].shape[-3:] + return [fft3d(dy, [kx, ky, kz])] + else: + x = self.inputs[0] + return [ifft3d(dy, x.shape[-3:])] + + def lower(self, lowering): + mesh_impl = lowering.mesh_impl(self) + x = self.inputs[0] + naxes = len(x.shape) + slices = lowering.tensors[self.inputs[0]] + # Before performing any operations, we check the splitting + split_axes = [] + for i in range(3): + split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) + + # Perform transform followed by tranposes + for i in range(2): + # Apply FFT along last axis + slices = mesh_impl.slicewise(self.tf_op, slices) + + split_axes, slices = self._transpose( + mesh_impl, + split_axes, + slices, + naxes, + ) + + # Apply transform along last axis + slices = mesh_impl.slicewise(self.tf_op, slices) + lowering.set_tensor_lowering(self.outputs[0], slices) + + def _transpose(self, mesh_impl, split_axes, slices, naxes): + # Before transposing the array, making sure the new last dimension will + # be contiguous + if self.inverse: + if split_axes[0] is not None: + slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) + split_axes[-1] = split_axes[0] + split_axes[0] = None + split_axes = [split_axes[1], split_axes[2], split_axes[0]] + else: + if split_axes[-2] is not None: + slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) + split_axes[-1] = split_axes[-2] + split_axes[-2] = None + split_axes = [split_axes[2], split_axes[0], split_axes[1]] + + perm = np.arange(naxes) + perm[-3:] = np.roll(perm[-3:], shift=-1 if self.inverse else 1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + return split_axes, slices + +def fft3d(x, freq_dims, name=None): + """ + Computes the 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the output FFT is transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + freq_dims: List of 3 Dimensions representing the frequency dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0]`. + """ + return FFT3DBaseOperation(x, freq_dims, inverse=False, name=name).outputs[0] + +def ifft3d(x, dims, name=None): + """ + Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the input FFT is assumed transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + dims: List of 3 Dimensions representing the direct space dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + dims`. + """ + return FFT3DBaseOperation(x, dims, inverse=True, name=name).outputs[0] diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py new file mode 100644 index 00000000..a6f75f10 --- /dev/null +++ b/mesh_tensorflow/signal_ops_test.py @@ -0,0 +1,63 @@ +import mesh_tensorflow as mtf +from mesh_tensorflow.signal_ops import fft3d, ifft3d +import tensorflow as tf + + +class FFTTest(tf.test.TestCase): + def setUp(self): + super(FFTTest, self).setUp() + self.graph = mtf.Graph() + self.mesh = mtf.Mesh(self.graph, "my_mesh") + volume_size = 32 + batch_dim = mtf.Dimension("batch", 1) + slices_dim = mtf.Dimension("slices", volume_size//2) + rows_dim = mtf.Dimension("rows", volume_size) + cols_dim = mtf.Dimension("cols", volume_size) + self.shape = [batch_dim, slices_dim, rows_dim, cols_dim,] + volume_shape = [d.size for d in self.shape] + self.volume = tf.complex( + tf.random.normal(volume_shape), + tf.random.normal(volume_shape), + ) + self.volume_mesh = mtf.import_tf_tensor(self.mesh, self.volume, shape=self.shape) + + + def testFft3d(self): + outputs = fft3d(self.volume_mesh, freq_dims=self.shape[1:4]) + assert len(outputs.shape) == 4 + assert outputs.dtype == tf.complex64 + assert set(outputs.shape) == set(mtf.Shape(self.shape)) + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + expected_outputs = tf.signal.fft3d(self.volume) + expected_outputs = tf.transpose(expected_outputs, perm=[0, 2, 3, 1]) + self.assertAllClose( + outputs_tf, + expected_outputs, + rtol=1e-4, + atol=1e-4, + ) + + def testIfft3d(self): + outputs = ifft3d( + self.volume_mesh, + # ordering is not the same for ifft3d + dims=[self.shape[3], self.shape[1], self.shape[2]], + ) + assert len(outputs.shape) == 4 + assert outputs.dtype == tf.complex64 + assert set(outputs.shape) == set(mtf.Shape(self.shape)) + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + volume = tf.transpose(self.volume, perm=[0, 3, 1, 2]) + expected_outputs = tf.signal.ifft3d(volume) + self.assertAllClose( + outputs_tf, + expected_outputs, + rtol=1e-4, + atol=1e-4, + )