Skip to content

Commit 722c2b7

Browse files
megluyaga authored and facebook-github-bot committed
Add BlenderCamera
Summary: Adding BlenderCamera (for rendering with R2N2 Blender transformations in the next diff).

Reviewed By: nikhilaravi

Differential Revision: D22462515

fbshipit-source-id: 4b40ee9bba8b6d56788dd3c723036ec704668153
1 parent 483e538 commit 722c2b7

File tree

4 files changed

+55
-3
lines changed

4 files changed

+55
-3
lines changed

pytorch3d/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
from .r2n2 import R2N2
3+
from .r2n2 import R2N2, BlenderCamera
44
from .shapenet import ShapeNetCore
55
from .utils import collate_batched_meshes
66

pytorch3d/datasets/r2n2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
from .r2n2 import R2N2
3+
from .r2n2 import R2N2, BlenderCamera
44

55

66
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/datasets/r2n2/r2n2.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,18 @@
1111
from PIL import Image
1212
from pytorch3d.datasets.shapenet_base import ShapeNetBase
1313
from pytorch3d.io import load_obj
14+
from pytorch3d.renderer.cameras import CamerasBase
15+
from pytorch3d.transforms import Transform3d
1416
from tabulate import tabulate
1517

1618

1719
SYNSET_DICT_DIR = Path(__file__).resolve().parent
1820

21+
# Default rotation, translation and intrinsic matrices for BlenderCamera,
# each carrying a leading batch dimension of 1.
r = np.eye(3)[None, :, :]  # identity rotation, (1, 3, 3)
t = np.zeros((1, 3))  # zero translation, (1, 3)
k = np.eye(4)[None, :, :]  # identity intrinsics, (1, 4, 4)
25+
1926

2027
class R2N2(ShapeNetBase):
2128
"""
@@ -217,3 +224,27 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
217224
model["images"] = torch.stack(images)
218225

219226
return model
227+
228+
229+
class BlenderCamera(CamerasBase):
    """
    Camera for rendering objects with calibration matrices from the R2N2 dataset
    (which uses Blender for rendering the views for each model).
    """

    def __init__(self, R=None, T=None, K=None, device="cpu"):
        """
        Args:
            R: Rotation matrix of shape (N, 3, 3). Defaults to a batch of one
                identity rotation (module-level ``r``).
            T: Translation matrix of shape (N, 3). Defaults to a batch of one
                zero translation (module-level ``t``).
            K: Intrinsic matrix of shape (N, 4, 4). Defaults to a batch of one
                identity matrix (module-level ``k``).
            device: torch.device or str.
        """
        # Use None sentinels rather than the shared module-level numpy arrays
        # directly as defaults: a mutable default argument is evaluated once
        # and shared across every call, so an in-place mutation by one caller
        # would silently corrupt the defaults for all subsequent cameras.
        if R is None:
            R = r
        if T is None:
            T = t
        if K is None:
            K = k
        # The base class initializer formats all inputs to torch tensors and
        # broadcasts all the inputs to have the same batch dimension where
        # necessary.
        super().__init__(device=device, R=R, T=T, K=K)

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Build the projection transform directly from the calibration matrix K.

        Returns:
            A Transform3d whose matrix is K transposed (presumably because
            Transform3d stores row-vector-convention matrices — confirm
            against pytorch3d.transforms).
        """
        transform = Transform3d(device=self.device)
        # Writing _matrix directly bypasses Transform3d's constructor checks.
        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
        return transform

tests/test_r2n2.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
import torch
1212
from common_testing import TestCaseMixin, load_rgb_image
1313
from PIL import Image
14-
from pytorch3d.datasets import R2N2, collate_batched_meshes
14+
from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes
1515
from pytorch3d.renderer import (
1616
OpenGLPerspectiveCameras,
1717
PointLights,
1818
RasterizationSettings,
1919
look_at_view_transform,
2020
)
21+
from pytorch3d.renderer.cameras import get_world_to_view_transform
22+
from pytorch3d.transforms import Transform3d
23+
from pytorch3d.transforms.so3 import so3_exponential_map
2124
from torch.utils.data import DataLoader
2225

2326

@@ -258,3 +261,21 @@ def test_render_r2n2(self):
258261
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
259262
)
260263
self.assertClose(mixed_rgb, image_ref, atol=0.05)
264+
265+
def test_blender_camera(self):
    """
    Check BlenderCamera's world-to-view transform and camera center against
    values computed independently from random R, T.
    """
    # Random translations and rotations (rotations via the SO(3) exp map).
    trans = torch.randn(10, 3)
    rot = so3_exponential_map(torch.randn(10, 3) * 3.0)

    # The camera's world-to-view transform must match the one built directly
    # from (R, T) by the free function.
    expected_rt = get_world_to_view_transform(R=rot, T=trans)
    cam = BlenderCamera(R=rot, T=trans)
    actual_rt = cam.get_world_to_view_transform()
    self.assertTrue(torch.allclose(expected_rt.get_matrix(), actual_rt.get_matrix()))
    self.assertIsInstance(expected_rt, Transform3d)

    # Camera center should equal -R @ T for each batch element.
    center = cam.get_camera_center()
    expected_center = -torch.bmm(rot, trans[:, :, None])[:, :, 0]
    self.assertTrue(torch.allclose(center, expected_center, atol=1e-05))

0 commit comments

Comments
 (0)