Skip to content

Commit 653e498

Browse files
committed
allow saving vertex normal in save_obj
1 parent 1af6bf4 commit 653e498

File tree

2 files changed

+186
-1
lines changed

2 files changed

+186
-1
lines changed

pytorch3d/io/obj_io.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,8 @@ def save_obj(
684684
decimal_places: Optional[int] = None,
685685
path_manager: Optional[PathManager] = None,
686686
*,
687+
verts_normals: Optional[torch.Tensor] = None,
688+
faces_normals: Optional[torch.Tensor] = None,
687689
verts_uvs: Optional[torch.Tensor] = None,
688690
faces_uvs: Optional[torch.Tensor] = None,
689691
texture_map: Optional[torch.Tensor] = None,
@@ -698,6 +700,9 @@ def save_obj(
698700
decimal_places: Number of decimal places for saving.
699701
path_manager: Optional PathManager for interpreting f if
700702
it is a str.
703+
verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
704+
faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
705+
for each vertex in the face.
701706
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
702707
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
703708
each vertex in the face.
@@ -712,6 +717,14 @@ def save_obj(
712717
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
713718
message = "'faces' should either be empty or of shape (num_faces, 3)."
714719
raise ValueError(message)
720+
721+
if faces_normals is not None and (faces_normals.dim() != 2 or faces_normals.size(1) != 3):
722+
message = "'faces_normals' should either be empty or of shape (num_faces, 3)."
723+
raise ValueError(message)
724+
725+
if verts_normals is not None and (verts_normals.dim() != 2 or verts_normals.size(1) != 3):
726+
message = "'verts_normals' should either be empty or of shape (num_verts, 3)."
727+
raise ValueError(message)
715728

716729
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
717730
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
@@ -728,6 +741,7 @@ def save_obj(
728741
if path_manager is None:
729742
path_manager = PathManager()
730743

744+
save_normals = all([n is not None for n in [verts_normals, faces_normals]])
731745
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
732746
output_path = Path(f)
733747

@@ -742,9 +756,12 @@ def save_obj(
742756
verts,
743757
faces,
744758
decimal_places,
759+
verts_normals=verts_normals,
760+
faces_normals=faces_normals,
745761
verts_uvs=verts_uvs,
746762
faces_uvs=faces_uvs,
747763
save_texture=save_texture,
764+
save_normals=save_normals,
748765
)
749766

750767
# Save the .mtl and .png files associated with the texture
@@ -777,9 +794,12 @@ def _save(
777794
faces,
778795
decimal_places: Optional[int] = None,
779796
*,
797+
verts_normals: Optional[torch.Tensor] = None,
798+
faces_normals: Optional[torch.Tensor] = None,
780799
verts_uvs: Optional[torch.Tensor] = None,
781800
faces_uvs: Optional[torch.Tensor] = None,
782801
save_texture: bool = False,
802+
save_normals: bool = False,
783803
) -> None:
784804

785805
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
@@ -809,6 +829,25 @@ def _save(
809829
vert = [float_str % verts[i, j] for j in range(D)]
810830
lines += "v %s\n" % " ".join(vert)
811831

832+
if save_normals:
833+
if faces_normals is not None and (faces_normals.dim() != 2 or faces_normals.size(1) != 3):
834+
message = "'faces_normals' should either be empty or of shape (num_faces, 3)."
835+
raise ValueError(message)
836+
837+
if verts_normals is not None and (verts_normals.dim() != 2 or verts_normals.size(1) != 3):
838+
message = "'verts_normals' should either be empty or of shape (num_verts, 3)."
839+
raise ValueError(message)
840+
841+
# pyre-fixme[16] # undefined attribute cpu
842+
verts_normals, faces_normals = verts_normals.cpu(), faces_normals.cpu()
843+
844+
# Save verts normals after verts
845+
if len(verts_normals):
846+
V, D = verts_normals.shape
847+
for i in range(V):
848+
normal = [float_str % verts_normals[i, j] for j in range(D)]
849+
lines += "vn %s\n" % " ".join(normal)
850+
812851
if save_texture:
813852
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
814853
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
@@ -834,7 +873,22 @@ def _save(
834873
if len(faces):
835874
F, P = faces.shape
836875
for i in range(F):
837-
if save_texture:
876+
if save_texture and save_normals:
877+
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
878+
face = [
879+
"%d/%d/%d" % (
880+
faces[i, j] + 1,
881+
faces_uvs[i, j] + 1,
882+
faces_normals[i, j] + 1,
883+
)
884+
for j in range(P)
885+
]
886+
elif save_normals:
887+
# Format faces as {verts_idx}//{verts_normals_idx}
888+
face = [
889+
"%d//%d" % (faces[i, j] + 1, faces_normals[i, j] + 1) for j in range(P)
890+
]
891+
elif save_texture:
838892
# Format faces as {verts_idx}/{verts_uvs_idx}
839893
face = [
840894
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)

tests/test_io_obj.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,59 @@ def check_item(x, y):
895895
with self.assertRaisesRegex(ValueError, "same type of texture"):
896896
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
897897

898+
def test_save_obj_with_normal(self):
899+
verts = torch.tensor(
900+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
901+
dtype=torch.float32,
902+
)
903+
faces = torch.tensor(
904+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
905+
)
906+
verts_normals = torch.tensor(
907+
[[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9],
908+
[0.40, 0.7, 0.19], [1.0, 0.00, 0.000], [0.00, 1.00, 0.00], [0.00, 0.00, 1.0]],
909+
dtype=torch.float32,
910+
)
911+
faces_normals = torch.tensor(
912+
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
913+
)
914+
915+
with TemporaryDirectory() as temp_dir:
916+
obj_file = os.path.join(temp_dir, "mesh.obj")
917+
save_obj(
918+
obj_file,
919+
verts,
920+
faces,
921+
decimal_places=2,
922+
verts_normals=verts_normals,
923+
faces_normals=faces_normals,
924+
)
925+
926+
expected_obj_file = "\n".join(
927+
[
928+
"v 0.01 0.20 0.30",
929+
"v 0.20 0.03 0.41",
930+
"v 0.30 0.40 0.05",
931+
"v 0.60 0.70 0.80",
932+
"vn 0.02 0.50 0.73",
933+
"vn 0.30 0.03 0.36",
934+
"vn 0.32 0.12 0.47",
935+
"vn 0.36 0.17 0.90",
936+
"vn 0.40 0.70 0.19",
937+
"vn 1.00 0.00 0.00",
938+
"vn 0.00 1.00 0.00",
939+
"vn 0.00 0.00 1.00",
940+
"f 1//1 3//2 2//3",
941+
"f 1//3 2//4 3//5",
942+
"f 4//5 3//6 2//7",
943+
"f 4//7 2//8 1//1",
944+
]
945+
)
946+
947+
# Check the obj file is saved correctly
948+
actual_file = open(obj_file, "r")
949+
self.assertEqual(actual_file.read(), expected_obj_file)
950+
898951
def test_save_obj_with_texture(self):
899952
verts = torch.tensor(
900953
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
@@ -962,6 +1015,84 @@ def test_save_obj_with_texture(self):
9621015
texture_image = load_rgb_image("mesh.png", temp_dir)
9631016
self.assertClose(texture_image, texture_map)
9641017

1018+
def test_save_obj_with_normal_and_texture(self):
1019+
verts = torch.tensor(
1020+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
1021+
dtype=torch.float32,
1022+
)
1023+
faces = torch.tensor(
1024+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
1025+
)
1026+
verts_normals = torch.tensor(
1027+
[[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9]],
1028+
dtype=torch.float32,
1029+
)
1030+
faces_normals = faces
1031+
verts_uvs = torch.tensor(
1032+
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
1033+
dtype=torch.float32,
1034+
)
1035+
faces_uvs = faces
1036+
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
1037+
1038+
with TemporaryDirectory() as temp_dir:
1039+
obj_file = os.path.join(temp_dir, "mesh.obj")
1040+
save_obj(
1041+
obj_file,
1042+
verts,
1043+
faces,
1044+
decimal_places=2,
1045+
verts_normals=verts_normals,
1046+
faces_normals=faces_normals,
1047+
verts_uvs=verts_uvs,
1048+
faces_uvs=faces_uvs,
1049+
texture_map=texture_map,
1050+
)
1051+
1052+
expected_obj_file = "\n".join(
1053+
[
1054+
"",
1055+
"mtllib mesh.mtl",
1056+
"usemtl mesh",
1057+
"",
1058+
"v 0.01 0.20 0.30",
1059+
"v 0.20 0.03 0.41",
1060+
"v 0.30 0.40 0.05",
1061+
"v 0.60 0.70 0.80",
1062+
"vn 0.02 0.50 0.73",
1063+
"vn 0.30 0.03 0.36",
1064+
"vn 0.32 0.12 0.47",
1065+
"vn 0.36 0.17 0.90",
1066+
"vt 0.02 0.50",
1067+
"vt 0.30 0.03",
1068+
"vt 0.32 0.12",
1069+
"vt 0.36 0.17",
1070+
"f 1/1/1 3/3/3 2/2/2",
1071+
"f 1/1/1 2/2/2 3/3/3",
1072+
"f 4/4/4 3/3/3 2/2/2",
1073+
"f 4/4/4 2/2/2 1/1/1",
1074+
]
1075+
)
1076+
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
1077+
1078+
# Check there are only 3 files in the temp dir
1079+
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
1080+
tempfiles_dir = os.listdir(temp_dir)
1081+
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
1082+
1083+
# Check the obj file is saved correctly
1084+
actual_file = open(obj_file, "r")
1085+
self.assertEqual(actual_file.read(), expected_obj_file)
1086+
1087+
# Check the mtl file is saved correctly
1088+
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
1089+
mtl_file = open(mtl_file_name, "r")
1090+
self.assertEqual(mtl_file.read(), expected_mtl_file)
1091+
1092+
# Check the texture image file is saved correctly
1093+
texture_image = load_rgb_image("mesh.png", temp_dir)
1094+
self.assertClose(texture_image, texture_map)
1095+
9651096
def test_save_obj_with_texture_errors(self):
9661097
verts = torch.tensor(
9671098
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],

0 commit comments

Comments
 (0)