Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 113 additions & 15 deletions newton/_src/viewer/viewer_usd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@ def _compute_segment_xform(pos0, pos1):
return (mid, Gf.Quath(rot.GetQuat()), scale)


def _usd_add_xform(prim):
prim = UsdGeom.Xform(prim)
prim.ClearXformOpOrder()

prim.AddTranslateOp()
prim.AddOrientOp()
prim.AddScaleOp()


def _usd_set_xform(
xform,
pos: tuple | None = None,
rot: tuple | None = None,
scale: tuple | None = None,
time: float = 0.0,
):
xform = UsdGeom.Xform(xform)

xform_ops = xform.GetOrderedXformOps()

if pos is not None:
xform_ops[0].Set(Gf.Vec3d(float(pos[0]), float(pos[1]), float(pos[2])), time)
if rot is not None:
xform_ops[1].Set(Gf.Quatf(float(rot[3]), float(rot[0]), float(rot[1]), float(rot[2])), time)
if scale is not None:
xform_ops[2].Set(Gf.Vec3d(float(scale[0]), float(scale[1]), float(scale[2])), time)


class ViewerUSD(ViewerBase):
"""
USD viewer backend for Newton physics simulations.
Expand All @@ -37,15 +65,16 @@ class ViewerUSD(ViewerBase):
and visualization of simulation data.
"""

def __init__(self, output_path, fps=60, up_axis="Z", num_frames=None):
def __init__(self, output_path, fps=60, up_axis="Z", num_frames=100, scaling=1.0):
"""
Initialize the USD viewer backend for Newton physics simulations.

Args:
output_path (str): Path to the output USD file.
fps (int, optional): Frames per second for time sampling. Default is 60.
up_axis (str, optional): USD up axis, either 'Y' or 'Z'. Default is 'Z'.
num_frames (int, optional): Maximum number of frames to record. If None, recording is unlimited.
num_frames (int, optional): Maximum number of frames to record. Default is 100. If None, recording is unlimited.
scaling (float, optional): Uniform scaling applied to the scene root. Default is 1.0.

Raises:
ImportError: If the usd-core package is not installed.
Expand All @@ -65,7 +94,23 @@ def __init__(self, output_path, fps=60, up_axis="Z", num_frames=None):
self.stage.SetFramesPerSecond(fps)
self.stage.SetStartTimeCode(0)

UsdGeom.SetStageUpAxis(self.stage, UsdGeom.Tokens.z)
axis_token = {
"X": UsdGeom.Tokens.x,
"Y": UsdGeom.Tokens.y,
"Z": UsdGeom.Tokens.z,
}.get(self.up_axis.strip().upper())

UsdGeom.SetStageUpAxis(self.stage, axis_token)
UsdGeom.SetStageMetersPerUnit(self.stage, 1.0)

self.root = UsdGeom.Xform.Define(self.stage, "/root")

# apply root scaling
self.root.ClearXformOpOrder()
s = self.root.AddScaleOp()
s.Set(Gf.Vec3d(float(scaling), float(scaling), float(scaling)), 0.0)

self.stage.SetDefaultPrim(self.root.GetPrim())

# Track meshes and instancers
self._meshes = {} # mesh_name -> prototype_path
Expand Down Expand Up @@ -124,6 +169,13 @@ def close(self):
if self.output_path:
print(f"USD output saved in: {os.path.abspath(self.output_path)}")

def _get_path(self, name):
# Handle both absolute and relative paths correctly
if name.startswith("/"):
return "/root" + name
else:
return "/root/" + name

def log_mesh(
self,
name,
Expand Down Expand Up @@ -155,9 +207,9 @@ def log_mesh(
indices_np = indices.numpy().astype(np.uint32)

if name not in self._meshes:
self._ensure_scopes_for_path(self.stage, name)
self._ensure_scopes_for_path(self.stage, self._get_path(name))

mesh_prim = UsdGeom.Mesh.Define(self.stage, name)
mesh_prim = UsdGeom.Mesh.Define(self.stage, self._get_path(name))

# setup topology once (do not set every frame)
face_vertex_counts = [3] * (len(indices_np) // 3)
Expand All @@ -182,9 +234,54 @@ def log_mesh(
pass

# how to hide the prototype mesh but not the instances in USD?
# mesh_prim.GetVisibilityAttr().Set("inherited" if not hidden else "invisible", self._frame_index)
mesh_prim.GetVisibilityAttr().Set("inherited" if not hidden else "invisible", self._frame_index)

# log a set of instances as individual mesh prims, slower but makes it easier
# to do post-editing of instance materials etc. default for Newton shapes
def log_instances(self, name, mesh, xform, scale, color, material):
# Get prototype path
if mesh not in self._meshes:
msg = f"Mesh prototype '{mesh}' not found for log_instances(). Call log_mesh() first."
raise RuntimeError(msg)

self._ensure_scopes_for_path(self.stage, self._get_path(name) + "/scope")

if xform:
xform = xform.numpy()

if scale:
scale = scale.numpy()
else:
scale = np.ones((len(xform), 3), dtype=np.float32)

if color:
color = color.numpy()

for i in range(len(xform)):
instance_path = self._get_path(name) + f"/instance_{i}"
instance = self.stage.GetPrimAtPath(instance_path)

if not instance:
instance = self.stage.DefinePrim(instance_path)
instance.GetReferences().AddInternalReference(self._get_path(mesh))

UsdGeom.Imageable(instance).GetVisibilityAttr().Set("inherited")
_usd_add_xform(instance)

# update transform
if xform is not None:
pos = xform[i][:3]
rot = xform[i][3:7]

_usd_set_xform(instance, pos, rot, scale[i], self._frame_index)

# update color
if color is not None:
displayColor = UsdGeom.PrimvarsAPI(instance).GetPrimvar("displayColor")
displayColor.Set(color[i], self._frame_index)

def log_instances(self, name, mesh, xforms, scales, colors, materials, hidden=False):
# log a set of instances as a point instancer, faster but less flexible
def log_instances_point_instancer(self, name, mesh, xforms, scales, colors, materials):
"""
Create or update a PointInstancer for mesh instances.

Expand All @@ -208,17 +305,17 @@ def log_instances(self, name, mesh, xforms, scales, colors, materials, hidden=Fa

# Create instancer if it doesn't exist
if name not in self._instancers:
self._ensure_scopes_for_path(self.stage, name)
self._ensure_scopes_for_path(self.stage, self._get_path(name))

instancer = UsdGeom.PointInstancer.Define(self.stage, name)
instancer = UsdGeom.PointInstancer.Define(self.stage, self._get_path(name))
instancer.CreateIdsAttr().Set(list(range(num_instances)))
instancer.CreateProtoIndicesAttr().Set([0] * num_instances)
UsdGeom.PrimvarsAPI(instancer).CreatePrimvar(
"displayColor", Sdf.ValueTypeNames.Color3fArray, UsdGeom.Tokens.vertex, 1
)

# Set the prototype relationship
instancer.GetPrototypesRel().AddTarget(mesh)
instancer.GetPrototypesRel().AddTarget(self._get_path(mesh))

self._instancers[name] = instancer

Expand Down Expand Up @@ -283,9 +380,9 @@ def log_lines(self, name, starts, ends, colors, width: float = 0.01, hidden=Fals
"""

if name not in self._instancers:
self._ensure_scopes_for_path(self.stage, name)
self._ensure_scopes_for_path(self.stage, self._get_path(name))

instancer = UsdGeom.PointInstancer.Define(self.stage, name)
instancer = UsdGeom.PointInstancer.Define(self.stage, self._get_path(name))

# define nested capsule prim
instancer_capsule = UsdGeom.Capsule.Define(self.stage, instancer.GetPath().AppendChild("capsule"))
Expand Down Expand Up @@ -357,10 +454,11 @@ def log_points(self, name, points, radii, colors, hidden=False):
else:
color_interp = "vertex"

instancer = UsdGeom.Points.Get(self.stage, name)
path = self._get_path(name)
instancer = UsdGeom.Points.Get(self.stage, path)
if not instancer:
self._ensure_scopes_for_path(self.stage, name)
instancer = UsdGeom.Points.Define(self.stage, name)
self._ensure_scopes_for_path(self.stage, path)
instancer = UsdGeom.Points.Define(self.stage, path)

UsdGeom.Primvar(instancer.GetWidthsAttr()).SetInterpolation(radius_interp)
UsdGeom.Primvar(instancer.GetDisplayColorAttr()).SetInterpolation(color_interp)
Expand Down
2 changes: 1 addition & 1 deletion newton/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def create_parser():
help="Viewer to use (gl, usd, rerun, or null).",
)
parser.add_argument(
"--output-path", type=str, default=None, help="Path to the output USD file (required for usd viewer)."
"--output-path", type=str, default="output.usd", help="Path to the output USD file (required for usd viewer)."
)
parser.add_argument("--num-frames", type=int, default=100, help="Total number of frames.")
parser.add_argument(
Expand Down
Loading