66from gnn_aid .aux .utils import move_to_same_device
77from gnn_aid .datasets .gen_dataset import GeneralDataset
88from gnn_aid .models_builder .model_managers import GNNModelManager
9+ from gnn_aid .data_structures .configs import Task
910
1011# Nettack imports
1112from .evasion_attacks_collection .nettack .utils import NettackSurrogate , NettackAttack
2526# ReWatt imports
2627from .evasion_attacks_collection .rewatt .utils import GraphEnvironment , ReWattPolicyNet , \
2728 GraphState , ReWattAgent
28- from ..data_structures import Task
2929
3030
3131class EvasionAttacker (
@@ -65,6 +65,9 @@ def check_availability(
6565 model_manager : GNNModelManager
6666 ):
6767 """ Availability check for the given dataset and model manager. """
68+ rules = [
69+ gen_dataset .dataset_var_config .task
70+ ]
6871 return True
6972
7073 def __init__ (
@@ -86,7 +89,7 @@ def attack(
8689 model_manager : Type ,
8790 gen_dataset : GeneralDataset ,
8891 mask_tensor : torch .Tensor ,
89- task_type : str = None ,
92+ task_type : str = None , # FIXME remove
9093 ):
9194 task = gen_dataset .dataset_var_config .task
9295 device = gen_dataset .data .x .device
@@ -110,9 +113,9 @@ def attack(
110113 edge_out = model .decode (src , dst ).unsqueeze (dim = 0 ).to (device )
111114
112115 # TODO use model_manager.loss_function when BCE support
113- # loss = model_manager.loss_function(edge_out, edge_label)
114- criterion = torch .nn .BCEWithLogitsLoss ()
115- loss = criterion (edge_out , edge_label )
116+ loss = model_manager .loss_function (edge_out , edge_label )
117+ # criterion = torch.nn.BCEWithLogitsLoss()
118+ # loss = criterion(edge_out, edge_label)
116119 model .zero_grad ()
117120 loss .backward ()
118121 sign_data_grad = x .grad .sign ()
@@ -425,7 +428,7 @@ def attack(
425428 model_manager : Type ,
426429 gen_dataset : GeneralDataset ,
427430 mask_tensor : torch .Tensor ,
428- task_type : str = None ,
431+ task_type : str = None , # FIXME remove
429432 ) -> None :
430433 if task_type is None :
431434 task_type = gen_dataset .is_multi ()
@@ -662,7 +665,7 @@ def check_availability(
662665
663666 def __init__ (
664667 self ,
665- node_idx : int = 0 ,
668+ element_idx : int = 0 ,
666669 budget : Union [int , None ] = None ,
667670 perturb_features : bool = True ,
668671 perturb_structure : bool = True ,
@@ -674,7 +677,7 @@ def __init__(
674677 ):
675678 super ().__init__ ()
676679 self .attack_diff = None
677- self .node_idx = node_idx
680+ self .element_idx = element_idx
678681 self .budget = budget
679682 self .perturb_features = perturb_features
680683 self .perturb_structure = perturb_structure
@@ -700,13 +703,13 @@ def attack(
700703 # surrogate.evaluate(x, edge_index, y)
701704
702705 attacker = NettackAttack (
703- real_class = data .y [self .node_idx ].item (),
706+ real_class = data .y [self .element_idx ].item (),
704707 gnn_model = model_manager .gnn ,
705708 model = surrogate ,
706709 x = x ,
707710 edge_index = edge_index ,
708711 num_classes = num_classes ,
709- target_node = self .node_idx ,
712+ target_node = self .element_idx ,
710713 attack_diff = self .attack_diff ,
711714 direct = self .direct ,
712715 depth = self .depth ,
@@ -719,14 +722,14 @@ def attack(
719722 elif self .perturb_structure and not self .perturb_features :
720723 mode = "structure"
721724
722- # logits_before = surrogate.forward(edge_index, x)[self.node_idx ]
725+ # logits_before = surrogate.forward(edge_index, x)[self.element_idx ]
723726 # pred_before = logits_before.argmax().item()
724727 # prob_before = torch.softmax(logits_before, dim=0)[pred_before].item()
725728 # print(f"Surrogate prediction before attack: {pred_before} (confidence: {prob_before:.4f})")
726729
727730 attacker .attack (budget = self .budget , mode = mode )
728731
729- # logits_after = surrogate.forward(attacker.edge_index, attacker.x)[self.node_idx ]
732+ # logits_after = surrogate.forward(attacker.edge_index, attacker.x)[self.element_idx ]
730733 # pred_after = logits_after.argmax().item()
731734 # prob_after = torch.softmax(logits_after, dim=0)[pred_after].item()
732735 # print(f"Surrogate prediction after attack: {pred_after} (confidence: {prob_after:.4f})")
0 commit comments