Skip to content

Commit 23b0c31

Browse files
authored
Fix dense assignment without aggregation (#243)
1 parent 589683e commit 23b0c31

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

hloc/match_dense.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787

8888
def to_cpts(kpts, ps):
8989
if ps > 0.0:
90-
cpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2)
91-
return [tuple(cpt) for cpt in cpts]
90+
kpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2)
91+
return [tuple(cpt) for cpt in kpts]
9292

9393

9494
def assign_keypoints(kpts: np.ndarray,
@@ -106,7 +106,7 @@ def assign_keypoints(kpts: np.ndarray,
106106
return kpt_ids
107107
else:
108108
ps = cell_size if cell_size is not None else max_error
109-
ps = max(cell_size, max_error)
109+
ps = max(ps, max_error)
110110
# With update we quantize and bin (optionally)
111111
assert isinstance(other_cpts, list)
112112
kpt_ids = []
@@ -270,13 +270,15 @@ def match_dense(conf: Dict,
270270
image0, image1 = image0.to(device), image1.to(device)
271271

272272
# match semi-dense
273-
if name1 in existing_refs:
274-
# flip to enable refinement in query image
273+
# for consistency with pairs_from_*: refine kpts of image0
274+
if name0 in existing_refs:
275+
# special case: flip to enable refinement in query image
275276
pred = model({'image0': image1, 'image1': image0})
276277
pred = {**pred,
277278
'keypoints0': pred['keypoints1'],
278279
'keypoints1': pred['keypoints0']}
279280
else:
281+
# usual case
280282
pred = model({'image0': image0, 'image1': image1})
281283

282284
# Rescale keypoints and move to cpu
@@ -376,16 +378,18 @@ def aggregate_matches(
376378
update1 = name1 in required_queries
377379

378380
# in localization we do not want to bin the query kp
379-
if update1 and not update0 and max_kps is None:
380-
max_error1 = 0.0
381+
# assumes that the query is name0!
382+
if update0 and not update1 and max_kps is None:
383+
max_error0 = cell_size0 = 0.0
381384
else:
382-
max_error1 = conf['max_error']
385+
max_error0 = conf['max_error']
386+
cell_size0 = conf['cell_size']
383387

384388
# Get match ids and extend query keypoints (cpdict)
385-
mkp_ids0 = assign_keypoints(kpts0, cpdict[name0], conf['max_error'],
389+
mkp_ids0 = assign_keypoints(kpts0, cpdict[name0], max_error0,
386390
update0, bindict[name0], scores,
387-
conf['cell_size'])
388-
mkp_ids1 = assign_keypoints(kpts1, cpdict[name1], max_error1,
391+
cell_size0)
392+
mkp_ids1 = assign_keypoints(kpts1, cpdict[name1], conf['max_error'],
389393
update1, bindict[name1], scores,
390394
conf['cell_size'])
391395

hloc/matchers/loftr.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,19 @@ def _init(self, conf):
2323
self.net = LoFTR_(pretrained=conf['weights'], config=cfg)
2424

2525
def _forward(self, data):
26+
# For consistency with hloc pairs, we refine kpts in image0!
27+
rename = {
28+
'keypoints0': 'keypoints1',
29+
'keypoints1': 'keypoints0',
30+
'image0': 'image1',
31+
'image1': 'image0',
32+
'mask0': 'mask1',
33+
'mask1': 'mask0',
34+
}
35+
data_ = {rename[k]: v for k, v in data.items()}
2636
with warnings.catch_warnings():
2737
warnings.simplefilter("ignore")
28-
pred = self.net(data)
38+
pred = self.net(data_)
2939

3040
scores = pred['confidence']
3141

@@ -36,6 +46,8 @@ def _forward(self, data):
3646
pred['keypoints0'][keep], pred['keypoints1'][keep]
3747
scores = scores[keep]
3848

49+
# Switch back indices
50+
pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
3951
pred['scores'] = scores
4052
del pred['confidence']
4153
return pred

0 commit comments

Comments
 (0)