Skip to content

Commit d278e25

Browse files
authored
Made examples in readme more interesting
Made examples in readme more interesting
1 parent db2b965 commit d278e25

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ from oml import datasets as d
546546
from oml.inference import inference
547547
from oml.losses import TripletLossWithMiner
548548
from oml.metrics import calc_retrieval_metrics_rr
549-
from oml.miners import AllTripletsMiner
549+
from oml.miners import HardTripletsMiner
550550
from oml.models import ViTExtractor
551551
from oml.registry import get_transforms_for_pretrained
552552
from oml.retrieval import RetrievalResults, AdaptiveThresholding
@@ -561,7 +561,7 @@ train = d.ImageLabeledDataset(df_train, transform=transform)
561561
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)
562562

563563
optimizer = Adam(model.parameters(), lr=1e-4)
564-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
564+
criterion = TripletLossWithMiner(0.1, HardTripletsMiner(), need_logs=True)
565565
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
566566

567567

@@ -601,7 +601,7 @@ from oml import datasets as d
601601
from oml.inference import inference
602602
from oml.losses import TripletLossWithMiner
603603
from oml.metrics import calc_retrieval_metrics_rr
604-
from oml.miners import AllTripletsMiner
604+
from oml.miners import NHardTripletsMiner
605605
from oml.models import HFWrapper
606606
from oml.retrieval import RetrievalResults, AdaptiveThresholding
607607
from oml.samplers import BalanceSampler
@@ -615,7 +615,9 @@ train = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
615615
val = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)
616616

617617
optimizer = Adam(model.parameters(), lr=1e-4)
618-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
618+
criterion = TripletLossWithMiner(
619+
0.1, NHardTripletsMiner(n_positive=2, n_negative=2), need_logs=True
620+
)
619621
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
620622

621623

@@ -651,9 +653,8 @@ from torch.utils.data import DataLoader
651653

652654
from oml import datasets as d
653655
from oml.inference import inference
654-
from oml.losses import TripletLossWithMiner
656+
from oml.losses import ArcFaceLoss
655657
from oml.metrics import calc_retrieval_metrics_rr
656-
from oml.miners import AllTripletsMiner
657658
from oml.models import ECAPATDNNExtractor
658659
from oml.retrieval import AdaptiveThresholding, RetrievalResults
659660
from oml.samplers import BalanceSampler
@@ -666,7 +667,7 @@ train = d.AudioLabeledDataset(df_train)
666667
val = d.AudioQueryGalleryLabeledDataset(df_val)
667668

668669
optimizer = Adam(model.parameters(), lr=1e-4)
669-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
670+
criterion = ArcFaceLoss(m=0.2, s=30, in_features=192, num_classes=4) # similar to paper
670671
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
671672

672673

docs/readme/examples_source/extractor/train_val_all_modalities.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ from oml import datasets as d
3838
from oml.inference import inference
3939
from oml.losses import TripletLossWithMiner
4040
from oml.metrics import calc_retrieval_metrics_rr
41-
from oml.miners import AllTripletsMiner
41+
from oml.miners import HardTripletsMiner
4242
from oml.models import ViTExtractor
4343
from oml.registry import get_transforms_for_pretrained
4444
from oml.retrieval import RetrievalResults, AdaptiveThresholding
@@ -53,7 +53,7 @@ train = d.ImageLabeledDataset(df_train, transform=transform)
5353
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)
5454

5555
optimizer = Adam(model.parameters(), lr=1e-4)
56-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
56+
criterion = TripletLossWithMiner(0.1, HardTripletsMiner(), need_logs=True)
5757
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
5858

5959

@@ -93,7 +93,7 @@ from oml import datasets as d
9393
from oml.inference import inference
9494
from oml.losses import TripletLossWithMiner
9595
from oml.metrics import calc_retrieval_metrics_rr
96-
from oml.miners import AllTripletsMiner
96+
from oml.miners import NHardTripletsMiner
9797
from oml.models import HFWrapper
9898
from oml.retrieval import RetrievalResults, AdaptiveThresholding
9999
from oml.samplers import BalanceSampler
@@ -107,7 +107,9 @@ train = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
107107
val = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)
108108

109109
optimizer = Adam(model.parameters(), lr=1e-4)
110-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
110+
criterion = TripletLossWithMiner(
111+
0.1, NHardTripletsMiner(n_positive=2, n_negative=2), need_logs=True
112+
)
111113
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
112114

113115

@@ -143,9 +145,8 @@ from torch.utils.data import DataLoader
143145

144146
from oml import datasets as d
145147
from oml.inference import inference
146-
from oml.losses import TripletLossWithMiner
148+
from oml.losses import ArcFaceLoss
147149
from oml.metrics import calc_retrieval_metrics_rr
148-
from oml.miners import AllTripletsMiner
149150
from oml.models import ECAPATDNNExtractor
150151
from oml.retrieval import AdaptiveThresholding, RetrievalResults
151152
from oml.samplers import BalanceSampler
@@ -158,7 +159,7 @@ train = d.AudioLabeledDataset(df_train)
158159
val = d.AudioQueryGalleryLabeledDataset(df_val)
159160

160161
optimizer = Adam(model.parameters(), lr=1e-4)
161-
criterion = TripletLossWithMiner(0.1, AllTripletsMiner(), need_logs=True)
162+
criterion = ArcFaceLoss(m=0.2, s=30, in_features=192, num_classes=4) # similar to paper
162163
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)
163164

164165

0 commit comments

Comments
 (0)