diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 180f341fc..f36980b92 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -390,16 +390,16 @@ def transform_normals(self, normals) -> torch.Tensor: return normals_out def translate(self, *args, **kwargs) -> "Transform3d": - return self.compose(Translate(device=self.device, *args, **kwargs)) + return self.compose(Translate(device=self.device, dtype=self.dtype, *args, **kwargs)) def scale(self, *args, **kwargs) -> "Transform3d": - return self.compose(Scale(device=self.device, *args, **kwargs)) + return self.compose(Scale(device=self.device, dtype=self.dtype, *args, **kwargs)) def rotate(self, *args, **kwargs) -> "Transform3d": - return self.compose(Rotate(device=self.device, *args, **kwargs)) + return self.compose(Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)) def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d": - return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs)) + return self.compose(RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)) def clone(self) -> "Transform3d": """ @@ -488,7 +488,7 @@ def __init__( - A 1D torch tensor """ xyz = _handle_input(x, y, z, dtype, device, "Translate") - super().__init__(device=xyz.device) + super().__init__(device=xyz.device, dtype=dtype) N = xyz.shape[0] mat = torch.eye(4, dtype=dtype, device=self.device) @@ -532,7 +532,7 @@ def __init__( - 1D torch tensor """ xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True) - super().__init__(device=xyz.device) + super().__init__(device=xyz.device, dtype=dtype) N = xyz.shape[0] # TODO: Can we do this all in one go somehow? @@ -571,7 +571,7 @@ def __init__( """ device_ = get_device(R, device) - super().__init__(device=device_) + super().__init__(device=device_, dtype=dtype) if R.dim() == 2: R = R[None] if R.shape[-2:] != (3, 3): @@ -598,7 +598,7 @@ def __init__( angle, axis: str = "X", degrees: bool = True, - dtype: torch.dtype = torch.float64, + dtype: torch.dtype = torch.float32, device: Optional[Device] = None, ) -> None: """ @@ -629,7 +629,7 @@ def __init__( # is for transforming column vectors. Therefore we transpose this matrix. # R will always be of shape (N, 3, 3) R = _axis_angle_rotation(axis, angle).transpose(1, 2) - super().__init__(device=angle.device, R=R) + super().__init__(device=angle.device, R=R, dtype=dtype) def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor: @@ -646,8 +646,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor: c = torch.tensor(c, dtype=dtype, device=device) if c.dim() == 0: c = c.view(1) - if c.device != device: - c = c.to(device=device) + if c.device != device or c.dtype != dtype: + c = c.to(device=device, dtype=dtype) return c @@ -696,7 +696,7 @@ def _handle_input( if y is not None or z is not None: msg = "Expected y and z to be None (in %s)" % name raise ValueError(msg) - return x.to(device=device_) + return x.to(device=device_, dtype=dtype) if allow_singleton and y is None and z is None: y = x diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f4690a413..5e136112f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -87,6 +87,34 @@ def test_to(self): t = t.cuda() t = t.cpu() + def test_dtype_propagation(self): + """ + Check that a given dtype is correctly passed along to child + transformations. + """ + # Use at least two dtypes so we avoid only testing on the + # default dtype. + for dtype in [torch.float32, torch.float64]: + R = torch.tensor( + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], + dtype=dtype, + ) + tf = Transform3d(dtype=dtype) \ + .rotate(R) \ + .rotate_axis_angle( + R[0], + 'X', + ) \ + .translate(3, 2, 1) \ + .scale(0.5) + + self.assertEqual(tf.dtype, dtype) + for inner_tf in tf._transforms: + self.assertEqual(inner_tf.dtype, dtype) + + transformed = tf.transform_points(R) + self.assertEqual(transformed.dtype, dtype) + def test_clone(self): """ Check that cloned transformations contain different _matrix objects. @@ -219,8 +247,8 @@ def test_rotate_axis_angle(self): normals_out_expected = torch.tensor( [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] ).view(1, 3, 3) - self.assertTrue(torch.allclose(points_out, points_out_expected)) - self.assertTrue(torch.allclose(normals_out, normals_out_expected)) + self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7)) + self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7)) def test_transform_points_fail(self): t1 = Scale(0.1, 0.1, 0.1) @@ -951,7 +979,7 @@ def test_rotate_x_python_scalar(self): self.assertTrue( torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7) ) - self.assertTrue(torch.allclose(t._matrix, matrix)) + self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_x_torch_scalar(self): angle = torch.tensor(90.0)