Skip to content

Commit c6f97e1

Browse files
authored
Merge pull request #226 from Visual-Behavior/woodscape
new: woodscape dataset
2 parents d7bc55a + 42dce6c commit c6f97e1

File tree

4 files changed

+276
-2
lines changed

4 files changed

+276
-2
lines changed

alodataset/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414
from .sintel_flow_dataset import SintelFlowDataset
1515
from .sintel_disparity_dataset import SintelDisparityDataset
1616
from .sintel_multi_dataset import SintelMultiDataset
17+
from .woodScape_dataset import WooodScapeDataset
18+
from .woodScape_split_dataset import WoodScapeSplitDataset

alodataset/woodScape_dataset.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
from aloscene import Frame, Mask, BoundingBoxes2D, Labels
2+
from alodataset import BaseDataset
3+
4+
from PIL import Image
5+
import numpy as np
6+
import torch
7+
import glob
8+
import os
9+
10+
11+
class WooodScapeDataset(BaseDataset):
12+
"""WoodScape dataset iterator
13+
14+
Paramneters
15+
-----------
16+
labels : List[str]
17+
List of labels to stick to the frame. If the list is empty all labels are loaded. By default all labels are attached.
18+
cameras : List[str]
19+
List of cameras to consider. If the list empty all cameras are loaded. By default all camera views are considered.
20+
fragment : Union[int, float]
21+
Either the portion of dataset to to consider if the arg is float or the number of samples if int.
22+
Passing a negative value will start the count from the end. By default 0.9
23+
seg_classes : List[sstr]
24+
Classes to consider for segmentation. By default all classes are considered.
25+
merge_classees : bool
26+
Assign the same classe index for all segementation classes, Default if False.
27+
rename_merged : str
28+
Name to give to merged instancee. Only if merge_classes is True. Default is "mix".
29+
30+
Raises
31+
------
32+
AssertionError
33+
One of the labels is not in ["Seg", "bxox_2d"].
34+
AssertionError
35+
One of the cameras is not in ["LV", "FV", "MVL", "MVR"].
36+
AssertionError
37+
One of the passed clases is not available/supported.
38+
39+
"""
40+
41+
CAMERAS = [
42+
"RV", # Right View
43+
"FV", # Front View
44+
"MVL", # Mirror Left View
45+
"MVR", # Mirror Right View
46+
]
47+
LABELS = [
48+
"seg",
49+
"box_2d"
50+
]
51+
SEG_CLASSES = [
52+
"void",
53+
"road",
54+
"lanemarks",
55+
"curb",
56+
"person",
57+
"rider",
58+
"vehicles",
59+
"bicycle",
60+
"motorcycle",
61+
"traffic_sign"
62+
]
63+
64+
def __init__(
65+
self,
66+
labels=[],
67+
cameras=[],
68+
fragment=0.9,
69+
name="WoodScape",
70+
seg_classes=[],
71+
merge_classes=False,
72+
rename_merged="mix",
73+
**kwargs):
74+
super().__init__(name=name, **kwargs)
75+
76+
cameras = self.CAMERAS if cameras == list() else cameras
77+
labels = self.LABELS if labels == list() else labels
78+
79+
if isinstance(fragment , int):
80+
pass
81+
elif isinstance(fragment, float):
82+
assert fragment <= 1 and fragment >= -1, "fragment of type float can not be higher than 1 of less than -1."
83+
else:
84+
raise AttributeError("Invalid type of fragment type.")
85+
86+
if seg_classes == list():
87+
seg_classes = self.SEG_CLASSES
88+
89+
assert all([c in self.SEG_CLASSES for c in seg_classes]), f"some segmentation classes are invalid, supported classes are :\n {self.SEG_CLASSES}"
90+
assert all([v in self.CAMERAS for v in cameras]), f"Some cameras are invalid, should be in {self.CAMERAS}"
91+
assert all([l in self.LABELS for l in labels]), f"Some labels are invalid, should be ib {self.LABELS}"
92+
assert isinstance(merge_classes, (int, bool)), "Invalid merge_classes argument"
93+
assert isinstance(rename_merged, str), "Invalid rename_merged argument"
94+
95+
self.items = sorted(glob.glob(os.path.join(self.dataset_dir, "rgb_images", "*")))
96+
self.items = self._filter_cameras(self.items, cameras)
97+
self.items = self._filter_non_png(self.items)
98+
99+
self.labels = labels
100+
self.cameras = cameras
101+
self.seg_classes = seg_classes
102+
self.merge_classes = merge_classes
103+
self.num_seg_classes = len(seg_classes)
104+
self.seg_classes_renamed = seg_classes if not merge_classes else [rename_merged]
105+
106+
# Encode fraction
107+
self.fragment = min(abs(fragment), len(self)) if isinstance(fragment, int) else int(abs(fragment) * len(self))
108+
109+
# Restricting the number of samples
110+
if fragment > 0:
111+
self.items = self.items[:self.fragment]
112+
else:
113+
self.items = self.items[len(self) - self.fragment:]
114+
115+
def getitem(self, idx):
116+
ipath = self.items[idx]
117+
frame = Frame(ipath, names=tuple("CHW"))
118+
119+
if "seg" in self.labels:
120+
segmentation = self._path2segLabel(ipath)
121+
frame.append_segmentation(segmentation)
122+
123+
if "box_2d" in self.labels:
124+
_, H, W = frame.shape
125+
bbox2d_path = self._path2boxLabel(ipath, frame_size=(H, W))
126+
frame.append_boxes2d(bbox2d_path)
127+
return frame
128+
129+
@staticmethod
130+
def _path2segPath(path):
131+
"""Maps rgb image path to corresponding segmentation path
132+
133+
Parameters
134+
----------
135+
path: str
136+
Path to rgb image.
137+
138+
"""
139+
path, file = os.path.split(path)
140+
path, _ = os.path.split(path)
141+
path = os.path.join(path, "semantic_annotations", "gtLabels", file)
142+
return path
143+
144+
@staticmethod
145+
def _path2boxPath(path):
146+
"""Maps rgb image path to corresponding json 2dbbox file
147+
148+
Parameters
149+
----------
150+
path: str
151+
path to rgb image
152+
153+
"""
154+
path, file = os.path.split(path)
155+
path, _ = os.path.split(path)
156+
path = os.path.join(path, "box_2d_annotations", file.replace(".png", ".txt"))
157+
return path
158+
159+
def _path2segLabel(self, path):
160+
"""Maps image path to segmentation mask
161+
162+
Parametrs
163+
---------
164+
path: str
165+
path to rgb image
166+
167+
"""
168+
path = self._path2segPath(path)
169+
mask = np.asarray(Image.open(path))
170+
mask = self.mask_2d_idx_to_3d_onehot_mask(mask)
171+
return mask
172+
173+
def mask_2d_idx_to_3d_onehot_mask(self, mask_2d):
174+
"""Converts 2d index encoding mask to 3d one hot encoding one
175+
176+
Parameters
177+
----------
178+
mask : np.ndarray
179+
Mask of size (H, W) with int values
180+
181+
"""
182+
sample_seg_classes = torch.unique(torch.Tensor(mask_2d.reshape(-1)))
183+
184+
num_sample_seg_classes = len(self.seg_classes_renamed)
185+
mask_3d = np.zeros((num_sample_seg_classes, ) + mask_2d.shape)
186+
187+
dec = 0
188+
for i, name in enumerate(self.seg_classes):
189+
if i in sample_seg_classes:
190+
mask_3d[i - dec] += (mask_2d == self.SEG_CLASSES.index(name)).astype(int)
191+
if self.merge_classes:
192+
dec += 1
193+
else:
194+
dec += 1
195+
196+
mask_3d = Mask(mask_3d, names=tuple("CHW"))
197+
mlabels = Labels(torch.arange(num_sample_seg_classes).to(torch.float32), labels_names=self.seg_classes_renamed, names=("N"), encoding="id")
198+
mask_3d.append_labels(mlabels)
199+
return mask_3d
200+
201+
def _path2boxLabel(self, path, frame_size):
202+
"""Maps image patgh to bbox2d label
203+
204+
Parameters
205+
----------
206+
path: str
207+
rgb image path
208+
209+
"""
210+
path = self._path2boxPath(path)
211+
212+
with open(path, "r") as f:
213+
content = f.readlines()
214+
content = [x.replace("\n", "") for x in content]
215+
content = [x.split(",") for x in content]
216+
bboxs2d = [[int(x) for x in c[2:]] for c in content]
217+
return BoundingBoxes2D(bboxs2d, boxes_format="xyxy", absolute=True, frame_size=frame_size)
218+
219+
@staticmethod
220+
def _filter_non_png(items):
221+
"""Filters non png files from a list of paths to files
222+
223+
Parameters
224+
----------
225+
items : List[str]
226+
list of paths to filter
227+
228+
"""
229+
return [p for p in items if p.endswith(".png")]
230+
231+
@staticmethod
232+
def _filter_cameras(items, cameras):
233+
"""Filters paths by given cameras list
234+
235+
Parameters
236+
----------
237+
238+
cameras : List[str]
239+
List of cameras
240+
241+
"""
242+
return list(filter(lambda x : any([v in x for v in cameras]), items))
243+
244+
245+
if __name__ == "__main__":
246+
ds = WooodScapeDataset(
247+
labels=[],
248+
cameras=[],
249+
fragment=1.,
250+
)
251+
idx = 222
252+
frame = ds[idx]
253+
frame.get_view().render()

alodataset/woodScape_split_dataset.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from alodataset import Split, SplitMixin, WooodScapeDataset
2+
3+
4+
class WoodScapeSplitDataset(WooodScapeDataset, SplitMixin):
5+
SPLIT_FOLDERS = {Split.VAL: -0.1, Split.TRAIN: 0.9}
6+
7+
def __init__(
8+
self,
9+
split=Split.TRAIN,
10+
**kwargs
11+
):
12+
self.split = split
13+
super().__init__(fragment=self.get_split_folder(), **kwargs)
14+
15+
16+
if __name__ == "__main__":
17+
val = WoodScapeSplitDataset(split=Split.VAL)
18+
train = WoodScapeSplitDataset(split=Split.TRAIN)
19+
20+
print("val :", len(val))
21+
print("train :", len(train))

alonet/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,3 @@
1010

1111
from . import detr_panoptic
1212
from . import deformable_detr_panoptic
13-
14-

0 commit comments

Comments
 (0)