|
1 |
| -import sys |
2 |
| -from pathlib import Path |
3 |
| -from functools import partial |
4 |
| -import torch |
5 |
| -import torch.nn.functional as F |
| 1 | +import kornia |
6 | 2 |
|
7 | 3 | from ..utils.base_model import BaseModel
|
8 | 4 |
|
9 |
| -disk_path = Path(__file__).parent / "../../third_party/disk" |
10 |
| -sys.path.append(str(disk_path)) |
11 |
| -from disk import DISK as _DISK # noqa E402 |
12 |
| - |
13 | 5 |
|
class DISK(BaseModel):
    """Wrapper around kornia's DISK keypoint detector and descriptor.

    Exposes the pretrained DISK model through the ``BaseModel``
    interface: ``_init`` loads the requested checkpoint, ``_forward``
    extracts keypoints, detection scores, and descriptors from a batch
    of images.
    """

    # ``weights`` names a pretrained checkpoint understood by
    # ``kornia.feature.DISK.from_pretrained`` (e.g. 'depth').
    default_conf = {
        'weights': 'depth',
        'max_keypoints': None,
        'nms_window_size': 5,
        'detection_threshold': 0.0,
        'pad_if_not_divisible': True,
    }
    required_inputs = ['image']

    def _init(self, conf):
        """Instantiate the pretrained DISK model selected by ``conf['weights']``."""
        self.model = kornia.feature.DISK.from_pretrained(conf['weights'])

    def _forward(self, data):
        """Run DISK on ``data['image']`` and repackage the per-image output.

        Returns a dict with one list entry per image in the batch:
        keypoint coordinates, their detection scores, and descriptors
        transposed to (descriptor_dim, num_keypoints).
        """
        batch = data['image']
        per_image = self.model(
            batch,
            n=self.conf['max_keypoints'],
            window_size=self.conf['nms_window_size'],
            score_threshold=self.conf['detection_threshold'],
            pad_if_not_divisible=self.conf['pad_if_not_divisible'],
        )
        return {
            'keypoints': [feats.keypoints for feats in per_image],
            'keypoint_scores': [feats.detection_scores for feats in per_image],
            'descriptors': [feats.descriptors.t() for feats in per_image],
        }
|
0 commit comments