From 91a04b0074d0aaa55e6e6847c55b163b89336d5a Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 17 Jul 2023 13:52:44 +0200 Subject: [PATCH 1/2] Switch to kornia DISK --- .gitmodules | 3 -- hloc/extractors/disk.py | 79 ++++++++--------------------------------- third_party/disk | 1 - 3 files changed, 15 insertions(+), 68 deletions(-) delete mode 160000 third_party/disk diff --git a/.gitmodules b/.gitmodules index 982d0a30..f869e77a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,6 +10,3 @@ [submodule "third_party/r2d2"] path = third_party/r2d2 url = https://github.com/naver/r2d2.git -[submodule "third_party/disk"] - path = third_party/disk - url = https://github.com/cvlab-epfl/disk.git diff --git a/hloc/extractors/disk.py b/hloc/extractors/disk.py index d668d30c..dc04280c 100644 --- a/hloc/extractors/disk.py +++ b/hloc/extractors/disk.py @@ -1,81 +1,32 @@ -import sys -from pathlib import Path -from functools import partial -import torch -import torch.nn.functional as F +import kornia from ..utils.base_model import BaseModel -disk_path = Path(__file__).parent / "../../third_party/disk" -sys.path.append(str(disk_path)) -from disk import DISK as _DISK # noqa E402 - class DISK(BaseModel): default_conf = { - 'model_name': 'depth-save.pth', + 'weights': 'depth', 'max_keypoints': None, - 'desc_dim': 128, - 'mode': 'nms', 'nms_window_size': 5, + 'detection_threshold': 0.0, + 'pad_if_not_divisible': True, } required_inputs = ['image'] def _init(self, conf): - self.model = _DISK(window=8, desc_dim=conf['desc_dim']) - - state_dict = torch.load( - disk_path / conf['model_name'], map_location='cpu') - if 'extractor' in state_dict: - weights = state_dict['extractor'] - elif 'disk' in state_dict: - weights = state_dict['disk'] - else: - raise KeyError('Incompatible weight file!') - self.model.load_state_dict(weights) - - if conf['mode'] == 'nms': - self.extract = partial( - self.model.features, - kind='nms', - window_size=conf['nms_window_size'], - cutoff=0., - n=conf['max_keypoints'] - ) - elif conf['mode'] == 'rng': - self.extract = partial(self.model.features, kind='rng') - else: - raise KeyError( - f'mode must be `nms` or `rng`, got `{conf["mode"]}`') + self.model = kornia.feature.DISK.from_pretrained(conf['weights']) def _forward(self, data): image = data['image'] - # make sure that the dimensions of the image are multiple of 16 - orig_h, orig_w = image.shape[-2:] - new_h = round(orig_h / 16) * 16 - new_w = round(orig_w / 16) * 16 - image = F.pad(image, (0, new_w - orig_w, 0, new_h - orig_h)) - - batched_features = self.extract(image) - - assert(len(batched_features) == 1) - features = batched_features[0] - - # filter points detected in the padded areas - kpts = features.kp - valid = torch.all(kpts <= kpts.new_tensor([orig_w, orig_h]) - 1, 1) - kpts = kpts[valid] - descriptors = features.desc[valid] - scores = features.kp_logp[valid] - - # order the keypoints - indices = torch.argsort(scores, descending=True) - kpts = kpts[indices] - descriptors = descriptors[indices] - scores = scores[indices] - + features = self.model( + image, + n=self.conf['max_keypoints'], + window_size=self.conf['nms_window_size'], + score_threshold=self.conf['detection_threshold'], + pad_if_not_divisible=self.conf['pad_if_not_divisible'], + ) return { - 'keypoints': kpts[None], - 'descriptors': descriptors.t()[None], - 'scores': scores[None], + 'keypoints': [f.keypoints for f in features], + 'keypoint_scores': [f.detection_scores for f in features], + 'descriptors': [f.descriptors.t() for f in features], } diff --git a/third_party/disk b/third_party/disk deleted file mode 160000 index eafa0ee8..00000000 --- a/third_party/disk +++ /dev/null @@ -1 +0,0 @@ -Subproject commit eafa0ee80d26f86d6c838b78cea8ae4c9abd4b51 From 733d575f71c56af441125696fb350683d5bd5680 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 17 Jul 2023 13:54:39 +0200 Subject: [PATCH 2/2] Bump kornia minimum version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6b6499d8..6b25438b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,6 @@ plotly scipy h5py pycolmap>=0.3.0 -kornia>=0.6.7 +kornia>=0.6.11 gdown lightglue @ git+https://github.com/cvg/LightGlue