Skip to content

Commit 47a3df3

Browse files
committed
Add CLGA LP test
1 parent 878efdf commit 47a3df3

File tree

1 file changed

+117
-2
lines changed

1 file changed

+117
-2
lines changed

tests/attacks_test.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
import copy
12
import unittest
23
import numpy as np
34
import torch
45

6+
from gnn_aid.attacks.clga.CLGA import CLGAAttack
57
from gnn_aid.attacks.mi_attacks import MIAttacker
68
from gnn_aid.datasets.datasets_manager import DatasetManager
79
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
10+
from gnn_aid.models_builder import FrameworkGNNConstructor
811
from gnn_aid.models_builder.models_utils import Metric
912
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
1013
from gnn_aid.data_structures.configs import ModelModificationConfig, DatasetConfig, DatasetVarConfig, \
11-
ConfigPattern, FeatureConfig, Task
14+
ConfigPattern, FeatureConfig, Task, ModelConfig, ModelStructureConfig
1215
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1316
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
14-
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH
17+
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH
1518
from .utils import monkey_patch_dirs, cleanup_patches
1619

1720

@@ -695,6 +698,118 @@ def test_mi_shadow_cora(self):
695698
print(f"MI Attack accuracy:"
696699
f" {MIAttacker.compute_single_attack_accuracy(mask, res, self.gen_dataset_sg_cora.train_mask)}")
697700

701+
def test_clga_link_prediction(self):
702+
gen_dataset = DatasetManager.get_by_config(
703+
DatasetConfig((LibPTGDataset.data_folder, "Homogeneous", "Planetoid", "Cora")),
704+
LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.EDGE_PREDICTION})
705+
)
706+
707+
poison_attack_config = ConfigPattern(
708+
_class_name="CLGAAttack",
709+
_import_path=POISON_ATTACK_PARAMETERS_PATH,
710+
_config_class="PoisonAttackConfig",
711+
_config_kwargs={
712+
"learning_rate": 0.01,
713+
"num_epochs": 50,
714+
}
715+
)
716+
717+
gnn = FrameworkGNNConstructor(
718+
model_config=ModelConfig(
719+
structure=ModelStructureConfig([
720+
{
721+
'label': 'n',
722+
'layer': {
723+
'layer_name': 'GCNConv',
724+
'layer_kwargs': {'in_channels': gen_dataset.num_node_features, 'out_channels': 32}
725+
},
726+
'activation': {
727+
'activation_name': 'ReLU',
728+
'activation_kwargs': None
729+
}
730+
},
731+
{
732+
'label': 'n',
733+
'layer': {
734+
'layer_name': 'GCNConv',
735+
'layer_kwargs': {'in_channels': 32, 'out_channels': 16}
736+
}
737+
},
738+
{
739+
'label': 'd',
740+
'function': {
741+
'function_name': 'Concat',
742+
'function_kwargs': None
743+
}
744+
},
745+
{
746+
'label': 'd',
747+
'layer': {
748+
'layer_name': 'Linear',
749+
'layer_kwargs': {'in_features': 32, 'out_features': 16}
750+
},
751+
'activation': {
752+
'activation_name': 'ReLU',
753+
'activation_kwargs': None
754+
}
755+
},
756+
{
757+
'label': 'd',
758+
'layer': {
759+
'layer_name': 'Linear',
760+
'layer_kwargs': {'in_features': 16, 'out_features': 1}
761+
}
762+
}
763+
])
764+
)
765+
)
766+
767+
manager_config_lp = ConfigPattern(
768+
_config_class="ModelManagerConfig",
769+
_config_kwargs={
770+
"mask_features": [],
771+
"optimizer": {
772+
"_class_name": "Adam",
773+
"_config_kwargs": {
774+
"lr": 0.01,
775+
"weight_decay": 5e-4
776+
},
777+
},
778+
"loss_function": {
779+
"_class_name": "BCEWithLogitsLoss",
780+
"_import_path": FUNCTIONS_PARAMETERS_PATH,
781+
"_class_import_info": ["torch.nn"],
782+
"_config_kwargs": {}
783+
},
784+
"batch": 64
785+
}
786+
)
787+
788+
gnn_model_manager = FrameworkGNNModelManager(
789+
gnn=gnn,
790+
dataset_path=gen_dataset.prepared_dir,
791+
modification=ModelModificationConfig(model_ver_ind=0, epochs=30),
792+
manager_config=manager_config_lp,
793+
)
794+
795+
gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
796+
797+
gen_dataset.train_test_split(percent_train_class=0.85, percent_test_class=0.15)
798+
799+
gnn_model_manager.train_model(
800+
gen_dataset=gen_dataset,
801+
steps=30,
802+
metrics=[Metric("AUC", mask='train')]
803+
)
804+
805+
test_metrics = gnn_model_manager.evaluate_model(
806+
gen_dataset=gen_dataset,
807+
metrics=[Metric("AUC", mask='test'), Metric("Recall@k", mask='test', k=100)]
808+
)
809+
print("CLGA Link Prediction AUC:", test_metrics['test']['AUC'])
810+
811+
self.assertLess(test_metrics['test']['AUC'], 0.95)
812+
698813

699814
if __name__ == '__main__':
700815
unittest.main()

0 commit comments

Comments
 (0)