|
11 | 11 | ConfigPattern, FeatureConfig, Task |
12 | 12 | from gnn_aid.models_builder.models_zoo import model_configs_zoo |
13 | 13 | from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ |
14 | | - OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH |
| 14 | + OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH |
15 | 15 | from .utils import monkey_patch_dirs, cleanup_patches |
16 | 16 |
|
17 | 17 |
|
@@ -75,6 +75,15 @@ def setUp(self): |
75 | 75 | } |
76 | 76 | ) |
77 | 77 |
|
| 78 | + # Cora for link pred |
| 79 | + dc = DatasetConfig((LibPTGDataset.data_folder, 'Homogeneous', 'Planetoid', 'Cora')) |
| 80 | + dvc = LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.EDGE_PREDICTION}) |
| 81 | + |
| 82 | + self.gen_dataset_sg_cora_link = DatasetManager.get_by_config(dc, dvc) |
| 83 | + self.gen_dataset_sg_cora_link.train_test_split(percent_train_class=0.85, percent_test_class=0.1) |
| 84 | + self.results_dataset_path_sg_cora_link = self.gen_dataset_sg_cora_link.prepared_dir |
| 85 | + self.gen_dataset_sg_cora_link.data.to(self.my_device) |
| 86 | + |
78 | 87 | monkey_patch_dirs() |
79 | 88 |
|
80 | 89 | def tearDown(self): |
@@ -368,6 +377,76 @@ def test_fgsm_SG(self): |
368 | 377 | metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] |
369 | 378 | # ---------- ----------------- ---------- |
370 | 379 |
|
| 380 | + def test_fgsm_LINK(self): |
| 381 | + sage_cossim = model_configs_zoo(dataset=self.gen_dataset_sg_cora_link, model_name="sage_cossim") |
| 382 | + |
| 383 | + manager_config = ConfigPattern( |
| 384 | + _config_class="ModelManagerConfig", |
| 385 | + _config_kwargs={ |
| 386 | + "batch": 64, |
| 387 | + "mask_features": [], |
| 388 | + "optimizer": { |
| 389 | + "_class_name": "Adam", |
| 390 | + "_config_kwargs": {}, |
| 391 | + }, |
| 392 | + "loss_function": { |
| 393 | + "_config_class": "Config", |
| 394 | + "_class_name": "CrossEntropyLoss", |
| 395 | + "_import_path": FUNCTIONS_PARAMETERS_PATH, |
| 396 | + "_class_import_info": ["torch.nn"], |
| 397 | + "_config_kwargs": {}, |
| 398 | + }, |
| 399 | + } |
| 400 | + ) |
| 401 | + |
| 402 | + gnn_model_manager = FrameworkGNNModelManager( |
| 403 | + gnn=sage_cossim, |
| 404 | + dataset_path=self.gen_dataset_sg_cora_link.prepared_dir, |
| 405 | + manager_config=manager_config, |
| 406 | + modification=ModelModificationConfig(model_ver_ind=0, epochs=0) |
| 407 | + ) |
| 408 | + |
| 409 | + gnn_model_manager.gnn.to(self.my_device) |
| 410 | + gnn_model_manager.train_model(gen_dataset=self.gen_dataset_sg_cora_link, steps=10, save_model_flag=False) |
| 411 | + |
| 412 | + # ---------- Attack on structure ---------- |
| 413 | + evasion_attack_config = ConfigPattern( |
| 414 | + _class_name="FGSM", |
| 415 | + _import_path=EVASION_ATTACK_PARAMETERS_PATH, |
| 416 | + _config_class="EvasionAttackConfig", |
| 417 | + _config_kwargs={ |
| 418 | + "is_feature_attack": False, |
| 419 | + "element_idx": (1, 2), |
| 420 | + "epsilon": 0.5, |
| 421 | + } |
| 422 | + ) |
| 423 | + |
| 424 | + gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) |
| 425 | + |
| 426 | + # Attack |
| 427 | + _ = gnn_model_manager.evaluate_model(gen_dataset=self.gen_dataset_sg_cora_link, |
| 428 | + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] |
| 429 | + # ---------- ------------------- ---------- |
| 430 | + |
| 431 | + # ---------- Attack on feature ---------- |
| 432 | + evasion_attack_config = ConfigPattern( |
| 433 | + _class_name="FGSM", |
| 434 | + _import_path=EVASION_ATTACK_PARAMETERS_PATH, |
| 435 | + _config_class="EvasionAttackConfig", |
| 436 | + _config_kwargs={ |
| 437 | + "is_feature_attack": True, |
| 438 | + "element_idx": (1, 2), |
| 439 | + "epsilon": 0.5, |
| 440 | + } |
| 441 | + ) |
| 442 | + |
| 443 | + gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) |
| 444 | + |
| 445 | + # Attack |
| 446 | + _ = gnn_model_manager.evaluate_model(gen_dataset=self.gen_dataset_sg_cora_link, |
| 447 | + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] |
| 448 | + # ---------- ----------------- ---------- |
| 449 | + |
371 | 450 | def test_fgsm_MG(self): |
372 | 451 | gcn_gcn = model_configs_zoo(dataset=self.gen_dataset_mg_small, model_name='gin_gin_gin_lin_lin_con') |
373 | 452 |
|
|
0 commit comments