@@ -546,7 +546,7 @@ from oml import datasets as d
546546from oml.inference import inference
547547from oml.losses import TripletLossWithMiner
548548from oml.metrics import calc_retrieval_metrics_rr
549- from oml.miners import AllTripletsMiner
549+ from oml.miners import HardTripletsMiner
550550from oml.models import ViTExtractor
551551from oml.registry import get_transforms_for_pretrained
552552from oml.retrieval import RetrievalResults, AdaptiveThresholding
@@ -561,7 +561,7 @@ train = d.ImageLabeledDataset(df_train, transform=transform)
561561val = d.ImageQueryGalleryLabeledDataset(df_val, transform = transform)
562562
563563optimizer = Adam(model.parameters(), lr = 1e-4 )
564- criterion = TripletLossWithMiner(0.1 , AllTripletsMiner (), need_logs = True )
564+ criterion = TripletLossWithMiner(0.1 , HardTripletsMiner (), need_logs = True )
565565sampler = BalanceSampler(train.get_labels(), n_labels = 2 , n_instances = 2 )
566566
567567
@@ -601,7 +601,7 @@ from oml import datasets as d
601601from oml.inference import inference
602602from oml.losses import TripletLossWithMiner
603603from oml.metrics import calc_retrieval_metrics_rr
604- from oml.miners import AllTripletsMiner
604+ from oml.miners import NHardTripletsMiner
605605from oml.models import HFWrapper
606606from oml.retrieval import RetrievalResults, AdaptiveThresholding
607607from oml.samplers import BalanceSampler
@@ -615,7 +615,9 @@ train = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
615615val = d.TextQueryGalleryLabeledDataset(df_val, tokenizer = tokenizer)
616616
617617optimizer = Adam(model.parameters(), lr = 1e-4 )
618- criterion = TripletLossWithMiner(0.1 , AllTripletsMiner(), need_logs = True )
618+ criterion = TripletLossWithMiner(
619+ 0.1 , NHardTripletsMiner(n_positive = 2 , n_negative = 2 ), need_logs = True
620+ )
619621sampler = BalanceSampler(train.get_labels(), n_labels = 2 , n_instances = 2 )
620622
621623
@@ -651,9 +653,8 @@ from torch.utils.data import DataLoader
651653
652654from oml import datasets as d
653655from oml.inference import inference
654- from oml.losses import TripletLossWithMiner
656+ from oml.losses import ArcFaceLoss
655657from oml.metrics import calc_retrieval_metrics_rr
656- from oml.miners import AllTripletsMiner
657658from oml.models import ECAPATDNNExtractor
658659from oml.retrieval import AdaptiveThresholding, RetrievalResults
659660from oml.samplers import BalanceSampler
@@ -666,7 +667,7 @@ train = d.AudioLabeledDataset(df_train)
666667val = d.AudioQueryGalleryLabeledDataset(df_val)
667668
668669optimizer = Adam(model.parameters(), lr = 1e-4 )
669- criterion = TripletLossWithMiner( 0.1 , AllTripletsMiner(), need_logs = True )
670+ criterion = ArcFaceLoss( m = 0.2 , s = 30 , in_features = 192 , num_classes = 4 ) # similar to paper
670671sampler = BalanceSampler(train.get_labels(), n_labels = 2 , n_instances = 2 )
671672
672673
0 commit comments