Skip to content

Commit 90733da

Browse files
authored
Switch to kornia DISK (#291)
* Switch to kornia DISK * Bump kornia minimum version
1 parent a828176 commit 90733da

File tree

4 files changed

+16
-69
lines changed

4 files changed

+16
-69
lines changed

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,3 @@
1010
[submodule "third_party/r2d2"]
1111
path = third_party/r2d2
1212
url = https://github.com/naver/r2d2.git
13-
[submodule "third_party/disk"]
14-
path = third_party/disk
15-
url = https://github.com/cvlab-epfl/disk.git

hloc/extractors/disk.py

Lines changed: 15 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,32 @@
1-
import sys
2-
from pathlib import Path
3-
from functools import partial
4-
import torch
5-
import torch.nn.functional as F
1+
import kornia
62

73
from ..utils.base_model import BaseModel
84

9-
disk_path = Path(__file__).parent / "../../third_party/disk"
10-
sys.path.append(str(disk_path))
11-
from disk import DISK as _DISK # noqa E402
12-
135

146
class DISK(BaseModel):
157
default_conf = {
16-
'model_name': 'depth-save.pth',
8+
'weights': 'depth',
179
'max_keypoints': None,
18-
'desc_dim': 128,
19-
'mode': 'nms',
2010
'nms_window_size': 5,
11+
'detection_threshold': 0.0,
12+
'pad_if_not_divisible': True,
2113
}
2214
required_inputs = ['image']
2315

2416
def _init(self, conf):
25-
self.model = _DISK(window=8, desc_dim=conf['desc_dim'])
26-
27-
state_dict = torch.load(
28-
disk_path / conf['model_name'], map_location='cpu')
29-
if 'extractor' in state_dict:
30-
weights = state_dict['extractor']
31-
elif 'disk' in state_dict:
32-
weights = state_dict['disk']
33-
else:
34-
raise KeyError('Incompatible weight file!')
35-
self.model.load_state_dict(weights)
36-
37-
if conf['mode'] == 'nms':
38-
self.extract = partial(
39-
self.model.features,
40-
kind='nms',
41-
window_size=conf['nms_window_size'],
42-
cutoff=0.,
43-
n=conf['max_keypoints']
44-
)
45-
elif conf['mode'] == 'rng':
46-
self.extract = partial(self.model.features, kind='rng')
47-
else:
48-
raise KeyError(
49-
f'mode must be `nms` or `rng`, got `{conf["mode"]}`')
17+
self.model = kornia.feature.DISK.from_pretrained(conf['weights'])
5018

5119
def _forward(self, data):
5220
image = data['image']
53-
# make sure that the dimensions of the image are multiple of 16
54-
orig_h, orig_w = image.shape[-2:]
55-
new_h = round(orig_h / 16) * 16
56-
new_w = round(orig_w / 16) * 16
57-
image = F.pad(image, (0, new_w - orig_w, 0, new_h - orig_h))
58-
59-
batched_features = self.extract(image)
60-
61-
assert(len(batched_features) == 1)
62-
features = batched_features[0]
63-
64-
# filter points detected in the padded areas
65-
kpts = features.kp
66-
valid = torch.all(kpts <= kpts.new_tensor([orig_w, orig_h]) - 1, 1)
67-
kpts = kpts[valid]
68-
descriptors = features.desc[valid]
69-
scores = features.kp_logp[valid]
70-
71-
# order the keypoints
72-
indices = torch.argsort(scores, descending=True)
73-
kpts = kpts[indices]
74-
descriptors = descriptors[indices]
75-
scores = scores[indices]
76-
21+
features = self.model(
22+
image,
23+
n=self.conf['max_keypoints'],
24+
window_size=self.conf['nms_window_size'],
25+
score_threshold=self.conf['detection_threshold'],
26+
pad_if_not_divisible=self.conf['pad_if_not_divisible'],
27+
)
7728
return {
78-
'keypoints': kpts[None],
79-
'descriptors': descriptors.t()[None],
80-
'scores': scores[None],
29+
'keypoints': [f.keypoints for f in features],
30+
'keypoint_scores': [f.detection_scores for f in features],
31+
'descriptors': [f.descriptors.t() for f in features],
8132
}

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ plotly
88
scipy
99
h5py
1010
pycolmap>=0.3.0
11-
kornia>=0.6.7
11+
kornia>=0.6.11
1212
gdown
1313
lightglue @ git+https://github.com/cvg/LightGlue

third_party/disk

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)