Skip to content

Commit 9cd8f55

Browse files
committed
add fgsm for link pred + test
1 parent 0c8154e commit 9cd8f55

File tree

4 files changed

+88
-9
lines changed

4 files changed

+88
-9
lines changed

gnn_aid/attacks/evasion_attacks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
self.epsilon = epsilon
7979
self.grad_aggr_type = 'mean'
8080
self.attack_diff = GraphModificationArtifact()
81-
self.attack_res_misha = None
81+
# self.attack_res_misha = None
8282

8383
def attack(
8484
self,
@@ -100,7 +100,7 @@ def attack(
100100
x.requires_grad = True
101101

102102
# TODO now support only one edge, mask_tensor?
103-
edge_label_index = torch.tensor([[5], [6]]).to(device)
103+
edge_label_index = torch.tensor(self.element_idx).unsqueeze(dim=1).to(device)
104104
edge_label = ((data.edge_index == edge_label_index).all(dim=0).any()).float().unsqueeze(dim=0).to(device)
105105

106106
node_out = model(data.x, data.edge_index)
@@ -170,7 +170,7 @@ def attack(
170170
y = gen_dataset.data.y
171171
x = gen_dataset.data.x
172172

173-
edge_label_index = torch.tensor([[7], [6]]).to(device)
173+
edge_label_index = torch.tensor(self.element_idx).unsqueeze(dim=1).to(device)
174174
edge_label = ((edge_index == edge_label_index).all(dim=0).any()).float().unsqueeze(dim=0).to(device)
175175
node_idx_1 = edge_label_index[0].item()
176176
node_idx_2 = edge_label_index[1].item()
@@ -275,8 +275,8 @@ def attack(
275275
perturbed_edges = torch.cat((perturbed_edges[:, :max_index], perturbed_edges[:, max_index + 1:]),
276276
dim=1)
277277

278-
from torch_geometric.data import Data
279-
self.attack_res_misha = Data(x=x, edge_index=perturbed_edges, y=y)
278+
# from torch_geometric.data import Data
279+
# self.attack_res_misha = Data(x=x, edge_index=perturbed_edges, y=y)
280280
set_a = set(map(tuple, edge_index.T.tolist()))
281281
set_b = set(map(tuple, perturbed_edges.T.tolist()))
282282

@@ -348,7 +348,7 @@ def attack(
348348
edges_to_keep = edge_index[:, ~edge_mask]
349349
updated_edge_index = torch.cat([edges_to_keep, perturbed_edges], dim=1)
350350
gen_dataset.data.edge_index = updated_edge_index
351-
self.attack_res_misha = gen_dataset
351+
# self.attack_res_misha = gen_dataset
352352

353353
set_a = set(map(tuple, edge_index.T.tolist()))
354354
set_b = set(map(tuple, updated_edge_index.T.tolist()))

metainfo/evasion_attack_parameters.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
},
44
"FGSM": {
55
"is_feature_attack": ["is feature attack", "bool", false, {}, "If true, applies perturbations to node features; otherwise, perturbs the graph structure (edges)"],
6-
"element_idx": ["node", "int", 0, {"min": 0, "step": 1}, "Index of the element to attack"],
6+
"element_idx": ["node", "int_or_tuple", 0, {"min": 0, "step": 1}, "Index of the element to attack"],
77
"epsilon": ["epsilon", "float", 0.1, {"min": 0.0001, "step": 0.01}, "Magnitude of the perturbation — higher values result in stronger attacks"]
88
},
99
"Nettack": {

tests/attacks_test.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ConfigPattern, FeatureConfig, Task
1212
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1313
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
14-
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH
14+
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH
1515
from .utils import monkey_patch_dirs, cleanup_patches
1616

1717

@@ -75,6 +75,15 @@ def setUp(self):
7575
}
7676
)
7777

78+
# Cora for link pred
79+
dc = DatasetConfig((LibPTGDataset.data_folder, 'Homogeneous', 'Planetoid', 'Cora'))
80+
dvc = LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.EDGE_PREDICTION})
81+
82+
self.gen_dataset_sg_cora_link = DatasetManager.get_by_config(dc, dvc)
83+
self.gen_dataset_sg_cora_link.train_test_split(percent_train_class=0.85, percent_test_class=0.1)
84+
self.results_dataset_path_sg_cora_link = self.gen_dataset_sg_cora_link.prepared_dir
85+
self.gen_dataset_sg_cora_link.data.to(self.my_device)
86+
7887
monkey_patch_dirs()
7988

8089
def tearDown(self):
@@ -368,6 +377,76 @@ def test_fgsm_SG(self):
368377
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
369378
# ---------- ----------------- ----------
370379

380+
def test_fgsm_LINK(self):
381+
sage_cossim = model_configs_zoo(dataset=self.gen_dataset_sg_cora_link, model_name="sage_cossim")
382+
383+
manager_config = ConfigPattern(
384+
_config_class="ModelManagerConfig",
385+
_config_kwargs={
386+
"batch": 64,
387+
"mask_features": [],
388+
"optimizer": {
389+
"_class_name": "Adam",
390+
"_config_kwargs": {},
391+
},
392+
"loss_function": {
393+
"_config_class": "Config",
394+
"_class_name": "CrossEntropyLoss",
395+
"_import_path": FUNCTIONS_PARAMETERS_PATH,
396+
"_class_import_info": ["torch.nn"],
397+
"_config_kwargs": {},
398+
},
399+
}
400+
)
401+
402+
gnn_model_manager = FrameworkGNNModelManager(
403+
gnn=sage_cossim,
404+
dataset_path=self.gen_dataset_sg_cora_link.prepared_dir,
405+
manager_config=manager_config,
406+
modification=ModelModificationConfig(model_ver_ind=0, epochs=0)
407+
)
408+
409+
gnn_model_manager.gnn.to(self.my_device)
410+
gnn_model_manager.train_model(gen_dataset=self.gen_dataset_sg_cora_link, steps=10, save_model_flag=False)
411+
412+
# ---------- Attack on structure ----------
413+
evasion_attack_config = ConfigPattern(
414+
_class_name="FGSM",
415+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
416+
_config_class="EvasionAttackConfig",
417+
_config_kwargs={
418+
"is_feature_attack": False,
419+
"element_idx": (1, 2),
420+
"epsilon": 0.5,
421+
}
422+
)
423+
424+
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
425+
426+
# Attack
427+
_ = gnn_model_manager.evaluate_model(gen_dataset=self.gen_dataset_sg_cora_link,
428+
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
429+
# ---------- ------------------- ----------
430+
431+
# ---------- Attack on feature ----------
432+
evasion_attack_config = ConfigPattern(
433+
_class_name="FGSM",
434+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
435+
_config_class="EvasionAttackConfig",
436+
_config_kwargs={
437+
"is_feature_attack": True,
438+
"element_idx": (1, 2),
439+
"epsilon": 0.5,
440+
}
441+
)
442+
443+
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
444+
445+
# Attack
446+
_ = gnn_model_manager.evaluate_model(gen_dataset=self.gen_dataset_sg_cora_link,
447+
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
448+
# ---------- ----------------- ----------
449+
371450
def test_fgsm_MG(self):
372451
gcn_gcn = model_configs_zoo(dataset=self.gen_dataset_mg_small, model_name='gin_gin_gin_lin_lin_con')
373452

tests/explainers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def setUp(self) -> None:
226226
gnn=self.sage_cossim,
227227
dataset_path=self.gen_dataset_sg_cora_link.prepared_dir,
228228
manager_config=manager_config,
229-
modification=ModelModificationConfig(model_ver_ind=0, epochs=10)
229+
modification=ModelModificationConfig(model_ver_ind=0, epochs=0)
230230
)
231231
self.sage_cossim_mm.train_model(
232232
gen_dataset=self.gen_dataset_sg_cora_link, steps=10,

0 commit comments

Comments
 (0)