Skip to content

Commit b8c9963

Browse files
committed
add saving of features to file
1 parent 87b112a commit b8c9963

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

features.ipynb

Lines changed: 109 additions & 0 deletions
Large diffs are not rendered by default.

requirements.txt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@ json_tricks
66
yacs
77
scikit-learn
88
pandas
9-
timm==0.4.12
10-
numpy==1.23.5
9+
timm
10+
numpy
1111
einops
1212
fvcore
13-
transformers==4.19.2
13+
transformers
1414
sentencepiece
1515
ftfy
1616
regex
1717
nltk
18-
vision-datasets==0.2.2
19-
pycocotools==2.0.4
18+
vision-datasets
19+
pycocotools
2020
diffdist
2121
pyarrow
2222
cityscapesscripts
2323
shapely
2424
scikit-image
2525
mup
26-
gradio==3.35.2
26+
gradio
2727
scann
28-
kornia==0.6.4
29-
torchmetrics==0.6.0
28+
kornia
29+
torchmetrics
3030
mpi4py
3131
progressbar
32-
pillow==9.4.0
32+
pillow

semantic_sam/architectures/interactive_mask_dino.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,13 @@ def evaluate_demo(self, batched_inputs,all_whole=None,all_part=None,mask_feature
304304
if mask_features is None or multi_scale_features is None:
305305

306306
features = self.backbone(images.tensor)
307+
# save to file
308+
torch.save(features, 'backbone_features.pt')
307309
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(
308310
features, None)
311+
torch.save(mask_features, 'mask_features.pt')
312+
torch.save(multi_scale_features, 'multi_scale_features.pt')
313+
torch.save(transformer_encoder_features, 'transformer_encoder_features.pt')
309314
outputs, mask_dict = self.sem_seg_head.predictor(multi_scale_features, mask_features, None, targets=targets,
310315
target_queries=None, target_vlp=None, task='demo', extra=prediction_switch)
311316
pred_ious=None

0 commit comments

Comments
 (0)