Skip to content

Commit 1aa27b5

Browse files
authored
Merge pull request #204 from Visual-Behavior/scene_flow
Scene flow
2 parents 22f455e + f1463e8 commit 1aa27b5

File tree

2 files changed

+191
-4
lines changed

2 files changed

+191
-4
lines changed

aloscene/io/flow.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch
33

44

5-
65
def load_flow_flo(flo_path):
76
"""
87
Load a 2D flow map with pytorch in float32 format
@@ -31,12 +30,19 @@ def load_flow_flo(flo_path):
3130
return flow
3231

3332

34-
35-
3633
def load_flow(flow_path):
    """
    Load an optical flow file, dispatching on the file extension.

    Parameters
    ----------
    flow_path : str or os.PathLike
        Path of the flow file. Only the ``.flo`` extension is currently loadable.

    Returns
    -------
    flow
        The value returned by `load_flow_flo` for a ``.flo`` file.

    Raises
    ------
    Exception
        If the file has the (recognised but unsupported) ``.zfd`` extension.
    ValueError
        If the extension is neither ``.flo`` nor ``.zfd``.
    """
    import os

    # Accept pathlib.Path (and any os.PathLike) as well as plain strings:
    # the extension checks below rely on ``str.endswith``.
    flow_path = os.fspath(flow_path)

    if flow_path.endswith(".flo"):
        return load_flow_flo(flow_path)
    elif flow_path.endswith(".zfd"):
        raise Exception("zfd format is not supported.")
    else:
        raise ValueError(f"Unknown extension for flow file: {flow_path}")
40+
41+
42+
def load_scene_flow(path: str) -> np.ndarray:
    """
    Load a scene flow map stored as a numpy ``.npy`` file.

    Parameters
    ----------
    path : str or os.PathLike
        Path of the ``.npy`` file to read.

    Returns
    -------
    np.ndarray
        The array stored in the file, as returned by ``np.load``.

    Raises
    ------
    ValueError
        If `path` does not end with ``.npy``.
    """
    import os

    # Accept pathlib.Path (and any os.PathLike) as well as plain strings.
    path = os.fspath(path)

    if not path.endswith(".npy"):
        # splitext only looks at the last path component, so a dot in a parent
        # directory name is not mistaken for the file extension.
        extension = os.path.splitext(path)[1]
        found = f"has the extension {extension}" if extension else "has no extension"
        raise ValueError(f"Scene flow file should be of type .npy, but {path} {found}")

    with open(path, "rb") as file:
        return np.load(file)

aloscene/scene_flow.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import aloscene
2+
from aloscene import Depth, CameraIntrinsic, Mask, Flow
3+
from aloscene.io.flow import load_scene_flow
4+
from typing import Union
5+
import torch
6+
import torch.nn.functional as F
7+
8+
9+
class SceneFlow(aloscene.tensors.SpatialAugmentedTensor):
    """
    Scene flow map

    Parameters
    ----------
    x : str or tensor or ndarray
        Scene flow data. If a string is given, it is treated as the path of a
        numpy ``.npy`` file and loaded with `load_scene_flow`.
    occlusion : aloscene.Mask or None
        Optional occlusion mask, attached as the "occlusion" child.
    names : tuple of str
        Dimension names. Forced to ("C", "H", "W") when `x` is a path.
    """

    @staticmethod
    def __new__(cls, x, occlusion: Union[Mask, None] = None, *args, names=("C", "H", "W"), **kwargs):
        if isinstance(x, str):
            # load scene flow from path
            x = load_scene_flow(x)
            names = ("C", "H", "W")

        tensor = super().__new__(cls, x, *args, names=names, **kwargs)
        tensor.add_child("occlusion", occlusion, align_dim=["B", "T"], mergeable=True)
        return tensor

    def __init__(self, x, *args, **kwargs):
        super().__init__(x)

    @classmethod
    def from_optical_flow(
        cls,
        optical_flow: Flow,
        depth: Depth,
        next_depth: Depth,
        intrinsic: CameraIntrinsic,
        sampling: str = "bilinear",
    ):
        """Create scene flow from optical flow, depth at T, depth at T + 1 and the intrinsic

        Parameters
        ----------
        optical_flow: aloscene.Flow
            The optical flow at T.
        depth: aloscene.Depth
            The depth at T.
        next_depth: aloscene.Depth
            The depth at T + 1
        intrinsic : aloscene.CameraIntrinsic
            The intrinsic of the image at T.
        sampling: str
            The sampling method to use for the scene flow (passed as `mode` to
            ``torch.nn.functional.grid_sample``).

        Returns
        -------
        aloscene.SceneFlow
            Scene flow with the same names as the inputs. An occlusion mask is attached
            only if at least one of the inputs carries an occlusion mask.

        Raises
        ------
        ValueError
            If the three inputs do not share the same names, or if these names are
            neither (C, H, W) nor (B, C, H, W).
        """
        has_batch = "B" in optical_flow.names

        if optical_flow.names != depth.names or optical_flow.names != next_depth.names:
            raise ValueError("The optical flow, depth and next_depth must have the same names")

        if optical_flow.names != ("C", "H", "W") and optical_flow.names != ("B", "C", "H", "W"):
            raise ValueError("The optical flow must have the names (C, H, W) or (B, C, H, W)")

        # Artificial batch dimension
        optical_flow = optical_flow.batch()
        depth = depth.batch()
        next_depth = next_depth.batch()

        H, W = depth.HW
        B = depth.shape[0]

        # Compute the point cloud at T and T + 1, laid out as (B, 3, H, W)
        start_vector = depth.as_points3d(intrinsic).as_tensor().reshape(-1, H, W, 3).permute(0, 3, 1, 2)
        next_vector = next_depth.as_points3d(intrinsic).as_tensor().reshape(-1, H, W, 3).permute(0, 3, 1, 2)

        # Compute the position of the point cloud at T + 1.
        # The pixel grid is created on the same device as the flow so that the
        # computation also works when the inputs do not live on the CPU.
        flow_tensor = optical_flow.as_tensor()
        device = flow_tensor.device
        x_coords = torch.arange(W, device=device).view(1, 1, W)
        y_coords = torch.arange(H, device=device).view(1, H, 1)
        new_x = x_coords + flow_tensor[:, 0, :, :]
        new_y = y_coords + flow_tensor[:, 1, :, :]

        # Normalize the coordinates between -1 and 1 and create the new points coordinates.
        # grid_sample is called with align_corners=True, for which -1 / 1 are the centers of
        # the first / last pixels: pixel index i maps to 2 * i / (size - 1) - 1.
        # max(..., 1) avoids a division by zero on one-pixel-wide inputs.
        new_x = new_x / max(W - 1, 1) * 2 - 1
        new_y = new_y / max(H - 1, 1) * 2 - 1
        new_coords = torch.stack([new_x, new_y], dim=3)

        # Move the point cloud at T + 1 to the new position
        end_vector = F.grid_sample(next_vector, new_coords, mode=sampling, padding_mode="zeros", align_corners=True)

        # Compute the scene flow
        scene_flow_vector = end_vector - start_vector

        # Create the occlusion mask if needed
        occlusion = None
        if optical_flow.occlusion is not None or depth.occlusion is not None or next_depth.occlusion is not None:
            occlusion = torch.zeros(B, H, W, dtype=torch.bool, device=device)

        # Add depth and optical flow occlusion to main occlusion mask
        if optical_flow.occlusion is not None:
            occlusion = occlusion | optical_flow.occlusion.as_tensor().bool()
        if depth.occlusion is not None:
            occlusion = occlusion | depth.occlusion.as_tensor().bool()
        if next_depth.occlusion is not None:
            next_depth_tensor = next_depth.occlusion.as_tensor().bool().unsqueeze(1)

            # Use of 'not' needed because the grid_sample has padding_mode="zeros" and
            # the 0 from this function mean that the pixel is occluded
            next_depth_tensor = ~next_depth_tensor

            # Move the occlusion mask like the scene flow to check if occluded pixel are used in the calculation
            moved_occlusion = F.grid_sample(
                next_depth_tensor.float(), new_coords, mode=sampling, padding_mode="zeros", align_corners=True
            )

            # Sometimes moved_occlusion is not exactly 1 even if the pixels around are not occluded
            moved_occlusion = ~(moved_occlusion >= 0.99999)

            # Fusion of the 2 occlusion mask
            moved_occlusion = moved_occlusion.squeeze(1)
            occlusion = occlusion | moved_occlusion

        # Remove the artificial batch dimension
        if not has_batch:
            scene_flow_vector = scene_flow_vector.squeeze(0)
            # occlusion stays None when none of the inputs carries an occlusion mask
            if occlusion is not None:
                occlusion = occlusion.squeeze(0)

        # Create the scene flow object
        tensor = cls(
            scene_flow_vector,
            names=("B", "C", "H", "W") if has_batch else ("C", "H", "W"),
            occlusion=None
            if occlusion is None
            else Mask(occlusion, names=("B", "H", "W") if has_batch else ("H", "W")),
        )
        return tensor

    def append_occlusion(self, occlusion: Mask, name: Union[str, None] = None):
        """Attach an occlusion mask to the scene flow.

        Parameters
        ----------
        occlusion: aloscene.Mask
            Occlusion mask to attach to the Scene Flow
        name: str
            If none, the occlusion mask will be attached without name (if possible). Otherwise if no other unnamed
            occlusion mask are attached to the scene flow, the mask will be added to the set of mask.
        """
        self._append_child("occlusion", occlusion, name)

    def _hflip(self, **kwargs):
        """Flip scene flow horizontally.

        Returns
        -------
        flipped_scene_flow : aloscene.SceneFlow
            horizontally flipped scene flow map
        """
        flow_flipped = super()._hflip(**kwargs)
        # invert x axis of flow vector; children are detached while the data is
        # modified and re-attached afterwards
        labels = flow_flipped.drop_children()
        sl_x = flow_flipped.get_slices({"C": 0})
        flow_flipped[sl_x] = -1 * flow_flipped[sl_x]
        flow_flipped.set_children(labels)
        return flow_flipped

    def _vflip(self, **kwargs):
        """Flip scene flow vertically.

        Returns
        -------
        flipped_scene_flow : aloscene.SceneFlow
            vertically flipped scene flow map
        """
        flow_flipped = super()._vflip(**kwargs)
        # invert y axis of flow vector; children are detached while the data is
        # modified and re-attached afterwards
        labels = flow_flipped.drop_children()
        sl_y = flow_flipped.get_slices({"C": 1})
        flow_flipped[sl_y] = -1 * flow_flipped[sl_y]
        flow_flipped.set_children(labels)
        return flow_flipped

0 commit comments

Comments
 (0)