|
| 1 | +import copy |
1 | 2 | import unittest |
2 | 3 | import numpy as np |
3 | 4 | import torch |
4 | 5 |
|
| 6 | +from gnn_aid.attacks.clga.CLGA import CLGAAttack |
5 | 7 | from gnn_aid.attacks.mi_attacks import MIAttacker |
6 | 8 | from gnn_aid.datasets.datasets_manager import DatasetManager |
7 | 9 | from gnn_aid.datasets.ptg_datasets import LibPTGDataset |
| 10 | +from gnn_aid.models_builder import FrameworkGNNConstructor |
8 | 11 | from gnn_aid.models_builder.models_utils import Metric |
9 | 12 | from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager |
10 | 13 | from gnn_aid.data_structures.configs import ModelModificationConfig, DatasetConfig, DatasetVarConfig, \ |
11 | | - ConfigPattern, FeatureConfig, Task |
| 14 | + ConfigPattern, FeatureConfig, Task, ModelConfig, ModelStructureConfig |
12 | 15 | from gnn_aid.models_builder.models_zoo import model_configs_zoo |
13 | 16 | from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ |
14 | | - OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH |
| 17 | + OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH |
15 | 18 | from .utils import monkey_patch_dirs, cleanup_patches |
16 | 19 |
|
17 | 20 |
|
@@ -695,6 +698,118 @@ def test_mi_shadow_cora(self): |
695 | 698 | print(f"MI Attack accuracy:" |
696 | 699 | f" {MIAttacker.compute_single_attack_accuracy(mask, res, self.gen_dataset_sg_cora.train_mask)}") |
697 | 700 |
|
| 701 | + def test_clga_link_prediction(self): |
| 702 | + gen_dataset = DatasetManager.get_by_config( |
| 703 | + DatasetConfig((LibPTGDataset.data_folder, "Homogeneous", "Planetoid", "Cora")), |
| 704 | + LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.EDGE_PREDICTION}) |
| 705 | + ) |
| 706 | + |
| 707 | + poison_attack_config = ConfigPattern( |
| 708 | + _class_name="CLGAAttack", |
| 709 | + _import_path=POISON_ATTACK_PARAMETERS_PATH, |
| 710 | + _config_class="PoisonAttackConfig", |
| 711 | + _config_kwargs={ |
| 712 | + "learning_rate": 0.01, |
| 713 | + "num_epochs": 50, |
| 714 | + } |
| 715 | + ) |
| 716 | + |
| 717 | + gnn = FrameworkGNNConstructor( |
| 718 | + model_config=ModelConfig( |
| 719 | + structure=ModelStructureConfig([ |
| 720 | + { |
| 721 | + 'label': 'n', |
| 722 | + 'layer': { |
| 723 | + 'layer_name': 'GCNConv', |
| 724 | + 'layer_kwargs': {'in_channels': gen_dataset.num_node_features, 'out_channels': 32} |
| 725 | + }, |
| 726 | + 'activation': { |
| 727 | + 'activation_name': 'ReLU', |
| 728 | + 'activation_kwargs': None |
| 729 | + } |
| 730 | + }, |
| 731 | + { |
| 732 | + 'label': 'n', |
| 733 | + 'layer': { |
| 734 | + 'layer_name': 'GCNConv', |
| 735 | + 'layer_kwargs': {'in_channels': 32, 'out_channels': 16} |
| 736 | + } |
| 737 | + }, |
| 738 | + { |
| 739 | + 'label': 'd', |
| 740 | + 'function': { |
| 741 | + 'function_name': 'Concat', |
| 742 | + 'function_kwargs': None |
| 743 | + } |
| 744 | + }, |
| 745 | + { |
| 746 | + 'label': 'd', |
| 747 | + 'layer': { |
| 748 | + 'layer_name': 'Linear', |
| 749 | + 'layer_kwargs': {'in_features': 32, 'out_features': 16} |
| 750 | + }, |
| 751 | + 'activation': { |
| 752 | + 'activation_name': 'ReLU', |
| 753 | + 'activation_kwargs': None |
| 754 | + } |
| 755 | + }, |
| 756 | + { |
| 757 | + 'label': 'd', |
| 758 | + 'layer': { |
| 759 | + 'layer_name': 'Linear', |
| 760 | + 'layer_kwargs': {'in_features': 16, 'out_features': 1} |
| 761 | + } |
| 762 | + } |
| 763 | + ]) |
| 764 | + ) |
| 765 | + ) |
| 766 | + |
| 767 | + manager_config_lp = ConfigPattern( |
| 768 | + _config_class="ModelManagerConfig", |
| 769 | + _config_kwargs={ |
| 770 | + "mask_features": [], |
| 771 | + "optimizer": { |
| 772 | + "_class_name": "Adam", |
| 773 | + "_config_kwargs": { |
| 774 | + "lr": 0.01, |
| 775 | + "weight_decay": 5e-4 |
| 776 | + }, |
| 777 | + }, |
| 778 | + "loss_function": { |
| 779 | + "_class_name": "BCEWithLogitsLoss", |
| 780 | + "_import_path": FUNCTIONS_PARAMETERS_PATH, |
| 781 | + "_class_import_info": ["torch.nn"], |
| 782 | + "_config_kwargs": {} |
| 783 | + }, |
| 784 | + "batch": 64 |
| 785 | + } |
| 786 | + ) |
| 787 | + |
| 788 | + gnn_model_manager = FrameworkGNNModelManager( |
| 789 | + gnn=gnn, |
| 790 | + dataset_path=gen_dataset.prepared_dir, |
| 791 | + modification=ModelModificationConfig(model_ver_ind=0, epochs=30), |
| 792 | + manager_config=manager_config_lp, |
| 793 | + ) |
| 794 | + |
| 795 | + gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) |
| 796 | + |
| 797 | + gen_dataset.train_test_split(percent_train_class=0.85, percent_test_class=0.15) |
| 798 | + |
| 799 | + gnn_model_manager.train_model( |
| 800 | + gen_dataset=gen_dataset, |
| 801 | + steps=30, |
| 802 | + metrics=[Metric("AUC", mask='train')] |
| 803 | + ) |
| 804 | + |
| 805 | + test_metrics = gnn_model_manager.evaluate_model( |
| 806 | + gen_dataset=gen_dataset, |
| 807 | + metrics=[Metric("AUC", mask='test'), Metric("Recall@k", mask='test', k=100)] |
| 808 | + ) |
| 809 | + print("CLGA Link Prediction AUC:", test_metrics['test']['AUC']) |
| 810 | + |
| 811 | + self.assertLess(test_metrics['test']['AUC'], 0.95) |
| 812 | + |
698 | 813 |
|
699 | 814 | if __name__ == '__main__': |
700 | 815 | unittest.main() |
0 commit comments