Skip to content

Commit 7d8eaa1

Browse files
authored
Merge pull request #1984 from pupil-labs/support-pi-headset-worn-classifier-data
Support PI headset on/off classifier data
2 parents 48a8d54 + 22463c8 commit 7d8eaa1

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

pupil_src/shared_modules/pupil_recording/update/invisible.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
logger = logging.getLogger(__name__)
3030

31-
NEWEST_SUPPORTED_VERSION = Version("1.2")
31+
NEWEST_SUPPORTED_VERSION = Version("1.3")
3232

3333

3434
def transform_invisible_to_corresponding_new_style(rec_dir: str):
@@ -144,14 +144,15 @@ def _convert_gaze(recording: PupilRecording):
144144
"topic": "gaze.pi",
145145
"norm_pos": None,
146146
"timestamp": None,
147-
"confidence": 1.0,
147+
"confidence": None,
148148
}
149149
with fm.PLData_Writer(recording.rec_dir, "gaze") as writer:
150-
for ((x, y), ts) in pi_gaze_items(root_dir=recording.rec_dir):
150+
for ((x, y), ts, conf) in pi_gaze_items(root_dir=recording.rec_dir):
151151
template_datum["timestamp"] = ts
152152
template_datum["norm_pos"] = m.normalize(
153153
(x, y), size=(width, height), flip_y=True
154154
)
155+
template_datum["confidence"] = conf
155156
writer.append(template_datum)
156157
logger.info(f"Converted {len(writer.ts_queue)} gaze positions.")
157158

pupil_src/shared_modules/video_capture/utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,16 @@ def find_raw_path(timestamps_path):
468468
assert raw_path.exists(), f"The file does not exist at path: {raw_path}"
469469
return raw_path
470470

471+
def find_worn_path(timestamps_path):
472+
worn_name = timestamps_path.name
473+
worn_name = worn_name.replace("gaze", "worn")
474+
worn_name = worn_name.replace("_timestamps", "")
475+
worn_path = timestamps_path.with_name(worn_name).with_suffix(".raw")
476+
if worn_path.exists():
477+
return worn_path
478+
else:
479+
return None
480+
471481
def load_timestamps_data(path):
472482
timestamps = np.load(str(path))
473483
return timestamps
@@ -478,6 +488,13 @@ def load_raw_data(path):
478488
raw_data.shape = (-1, 2)
479489
return np.asarray(raw_data, dtype=raw_data_dtype)
480490

491+
def load_worn_data(path):
492+
if not (path and path.exists()):
493+
return None
494+
495+
confidences = np.fromfile(str(path), "<u1") / 255.0
496+
return np.clip(confidences, 0.0, 1.0)
497+
481498
# This pattern will match any filename that:
482499
# - starts with "gaze ps"
483500
# - is followed by one or more digits
@@ -496,4 +513,16 @@ def load_raw_data(path):
496513
size = min(len(raw_data), len(timestamps))
497514
raw_data = raw_data[:size]
498515
timestamps = timestamps[:size]
499-
yield from zip(raw_data, timestamps)
516+
517+
conf_data = load_worn_data(find_worn_path(timestamps_path))
518+
if conf_data is not None and len(conf_data) != len(timestamps):
519+
logger.warning(
520+
f"There is a mismatch between the number of confidence data ({len(conf_data)}) "
521+
f"and the number of timestamps ({len(timestamps)})! Not using confidence data."
522+
)
523+
conf_data = None
524+
525+
if conf_data is None:
526+
conf_data = (1.0 for _ in range(len(timestamps)))
527+
528+
yield from zip(raw_data, timestamps, conf_data)

0 commit comments

Comments
 (0)