55import torch
66from sklearn .metrics import accuracy_score
77from sklearn .svm import SVC
8+ from sklearn .linear_model import LogisticRegression
89
910from gnn_aid .attacks .attack_base import Attacker
1011from gnn_aid .aux .utils import move_to_same_device
1112from gnn_aid .data_structures .mi_results import MIResultsStore
1213from gnn_aid .datasets .gen_dataset import GeneralDataset
14+ from gnn_aid .models_builder import FrameworkGNNConstructor
1315from gnn_aid .models_builder .models_zoo import model_configs_zoo
1416
1517
@@ -204,7 +206,7 @@ def _train_shadow_model(
204206 optimizer .step ()
205207
206208 if epoch % 20 == 0 :
207- print (f" Shadow epoch { epoch } /{ self .shadow_epochs } , loss: { loss .item ():.4f} " )
209+ print (f"Shadow epoch { epoch } /{ self .shadow_epochs } , loss: { loss .item ():.4f} " )
208210
209211 return shadow_model
210212
@@ -296,4 +298,207 @@ def attack(
296298 inferred_membership_full = torch .tensor (all_predictions , dtype = torch .bool )
297299
298300 self .results .add (mask_tensor , inferred_membership_full )
301+ return self .results
302+
303+
304+ class ShadowModelMILinkAttacker (MIAttacker ):
305+ """
306+ Shadow model-based membership inference attack for Link Prediction.
307+ """
308+ name = "ShadowModelMILinkAttacker"
309+
310+ def __init__ (
311+ self ,
312+ shadow_edge_ratio : float = 0.2 ,
313+ shadow_train_ratio : float = 0.75 ,
314+ shadow_epochs : int = 10 ,
315+ classifier_type : str = 'linreg' ,
316+ use_embedding_features : bool = False ,
317+ ** kwargs
318+ ):
319+ super ().__init__ (** kwargs )
320+ self .shadow_edge_ratio = shadow_edge_ratio
321+ self .shadow_train_ratio = shadow_train_ratio
322+ self .shadow_epochs = shadow_epochs
323+ self .classifier_type = classifier_type
324+ self .use_embedding_features = use_embedding_features
325+ self .classifier = None
326+ self .model_name = None
327+
328+ def _prepare_shadow_edge_masks (
329+ self ,
330+ num_edges : int
331+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
332+ """
333+ Create shadow train/test masks over a random subset of edges
334+ """
335+ all_indices = torch .arange (num_edges )
336+ shadow_size = int (num_edges * self .shadow_edge_ratio )
337+ shadow_indices = all_indices [torch .randperm (num_edges )[:shadow_size ]]
338+
339+ n_train = int (shadow_size * self .shadow_train_ratio )
340+ shadow_train_indices = shadow_indices [:n_train ]
341+ shadow_test_indices = shadow_indices [n_train :]
342+
343+ shadow_train_mask = torch .zeros (num_edges , dtype = torch .bool )
344+ shadow_test_mask = torch .zeros (num_edges , dtype = torch .bool )
345+ shadow_train_mask [shadow_train_indices ] = True
346+ shadow_test_mask [shadow_test_indices ] = True
347+
348+ return shadow_train_mask , shadow_test_mask
349+
350+ def _train_shadow_model (
351+ self ,
352+ shadow_model : torch .nn .Module ,
353+ shadow_dataset : GeneralDataset ,
354+ shadow_train_mask : torch .Tensor ,
355+ device : torch .device
356+ ) -> torch .nn .Module :
357+ """
358+ Train shadow model on shadow dataset
359+ """
360+ shadow_model = shadow_model .to (device )
361+ optimizer = torch .optim .Adam (shadow_model .parameters (), lr = 0.01 )
362+ criterion = torch .nn .BCEWithLogitsLoss ()
363+
364+ for epoch in range (self .shadow_epochs ):
365+ shadow_model .train ()
366+ optimizer .zero_grad ()
367+
368+ node_emb = shadow_model (
369+ shadow_dataset .data .x .to (device ),
370+ shadow_dataset .data .edge_index .to (device )
371+ )
372+
373+ train_edge_index = shadow_dataset .edge_label_index [:, shadow_train_mask ].to (device )
374+ train_edge_labels = shadow_dataset .edge_labels [shadow_train_mask ].float ().to (device )
375+
376+ edge_logits = shadow_model .decode (node_emb [train_edge_index [0 ]], node_emb [train_edge_index [1 ]]).squeeze ()
377+
378+ loss = criterion (edge_logits , train_edge_labels )
379+ loss .backward ()
380+ optimizer .step ()
381+
382+ if epoch % 10 == 0 :
383+ print (f"Shadow epoch { epoch } /{ self .shadow_epochs } , loss: { loss .item ():.4f} " )
384+
385+ return shadow_model
386+
387+ def _extract_edge_features (
388+ self ,
389+ model : torch .nn .Module ,
390+ dataset : GeneralDataset ,
391+ edge_mask : torch .Tensor ,
392+ device : torch .device
393+ ) -> np .ndarray :
394+ model .eval ()
395+ model = model .to (device )
396+ with torch .no_grad ():
397+ node_emb = model (
398+ dataset .data .x .to (device ),
399+ dataset .data .edge_index .to (device )
400+ )
401+
402+ edge_index = dataset .edge_label_index [:, edge_mask ].to (device )
403+
404+ edge_logits = model .decode (node_emb [edge_index [0 ]], node_emb [edge_index [1 ]]).squeeze ()
405+ edge_probs = torch .sigmoid (edge_logits )
406+
407+ if self .use_embedding_features :
408+ features = torch .cat ([
409+ edge_probs .unsqueeze (1 ),
410+ node_emb [edge_index [0 ]],
411+ node_emb [edge_index [1 ]]
412+ ], dim = 1 )
413+ else :
414+ features = edge_probs .unsqueeze (1 )
415+
416+ return features .cpu ().numpy ()
417+
418+ def _train_attack_classifier (
419+ self ,
420+ shadow_model : torch .nn .Module ,
421+ shadow_dataset : GeneralDataset ,
422+ shadow_train_mask : torch .Tensor ,
423+ shadow_test_mask : torch .Tensor ,
424+ device : torch .device
425+ ):
426+ """
427+ Train attack classifier using shadow model outputs
428+ """
429+ X_train = self ._extract_edge_features (shadow_model , shadow_dataset , shadow_train_mask , device )
430+ y_train = np .ones (X_train .shape [0 ])
431+
432+ X_test = self ._extract_edge_features (shadow_model , shadow_dataset , shadow_test_mask , device )
433+ y_test = np .zeros (X_test .shape [0 ])
434+
435+ X = np .vstack ([X_train , X_test ])
436+ y = np .concatenate ([y_train , y_test ])
437+
438+ import matplotlib .pyplot as plt
439+ plt .hist (X_train [:, 0 ], bins = 50 , alpha = 0.5 , label = 'Train edges' )
440+ plt .hist (X_test [:, 0 ], bins = 50 , alpha = 0.5 , label = 'Test edges' )
441+ plt .legend ()
442+ plt .title ('Probability distributions: Train vs Test edges' )
443+ plt .savefig ('edge_prob_distributions.png' )
444+
445+ if self .classifier_type == 'svc' :
446+ self .classifier = SVC (kernel = 'rbf' , probability = True )
447+ elif self .classifier_type == 'linreg' :
448+ self .classifier = LogisticRegression (max_iter = 1000 )
449+ else :
450+ raise ValueError (f"Unsupported classifier: { self .classifier_type } " )
451+
452+ self .classifier .fit (X , y )
453+
454+ def attack (
455+ self ,
456+ model : torch .nn .Module ,
457+ gen_dataset : GeneralDataset ,
458+ mask_tensor : Union [torch .Tensor , list ],
459+ ** kwargs
460+ ):
461+ """
462+ Perform membership inference attack on target model
463+ """
464+ if isinstance (mask_tensor , str ):
465+ if mask_tensor == 'train' :
466+ mask_tensor = gen_dataset .train_mask
467+ elif mask_tensor == 'val' :
468+ mask_tensor = gen_dataset .val_mask
469+ elif mask_tensor == 'test' :
470+ mask_tensor = gen_dataset .test_mask
471+ elif mask_tensor == 'all' :
472+ mask_tensor = torch .ones (
473+ gen_dataset .edge_label_index .size (1 ),
474+ dtype = torch .bool ,
475+ device = gen_dataset .train_mask .device
476+ )
477+ else :
478+ raise ValueError (f"Unknown mask string: { mask_tensor } " )
479+
480+ num_edges = gen_dataset .edge_label_index .size (1 )
481+ device = next (model .parameters ()).device
482+
483+ shadow_train_mask , shadow_test_mask = self ._prepare_shadow_edge_masks (num_edges )
484+
485+ shadow_dataset = copy .deepcopy (gen_dataset )
486+ shadow_dataset .train_mask = shadow_train_mask
487+ shadow_dataset .test_mask = shadow_test_mask
488+
489+ shadow_model = model_configs_zoo (dataset = shadow_dataset , model_name = 'gcn_link_pred' )
490+ shadow_model = self ._train_shadow_model (shadow_model , shadow_dataset , shadow_train_mask , device )
491+
492+ self ._train_attack_classifier (
493+ shadow_model , shadow_dataset , shadow_train_mask , shadow_test_mask , device
494+ )
495+
496+ target_features = self ._extract_edge_features (model , gen_dataset , torch .ones (num_edges , dtype = torch .bool ),
497+ device )
498+
499+ all_predictions = self .classifier .predict (target_features )
500+ inferred_membership_full = torch .tensor (all_predictions , dtype = torch .bool )
501+
502+ self .results .add (mask_tensor , inferred_membership_full )
503+ members_count = inferred_membership_full .sum ().item ()
299504 return self .results
0 commit comments