-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathdefense_against_poisoning.py
More file actions
97 lines (76 loc) · 3.21 KB
/
defense_against_poisoning.py
File metadata and controls
97 lines (76 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import warnings
import torch
from torch import device
from data_structures.configs import ModelModificationConfig, ConfigPattern
from datasets.datasets_manager import DatasetManager
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
from models_builder.models_zoo import model_configs_zoo
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH
# Select the GPU when one is available, otherwise fall back to the CPU.
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Cora dataset and the 'gin_gin' model from the models zoo the same
# way as in poisoning_attack.py.
full_name = ("Homogeneous", "Planetoid", 'Cora')
# Seed the torch RNG before the dataset and model are built so runs are repeatable.
torch.manual_seed(1234)
gen_dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
    full_name=full_name,
    dataset_ver_ind=0
)
gnn = model_configs_zoo(dataset=gen_dataset, model_name='gin_gin')

# Model-manager settings: empty "mask_features" and an Adam optimizer with
# no extra keyword arguments (empty _config_kwargs).
manager_config = ConfigPattern(
    _config_class="ModelManagerConfig",
    _config_kwargs={
        "mask_features": [],
        "optimizer": {
            "_class_name": "Adam",
            "_config_kwargs": {},
        }
    }
)
# Number of training steps; also recorded as `epochs` in the model modification config.
steps_epochs = 200
gnn_model_manager = FrameworkGNNModelManager(
    gnn=gnn,
    dataset_path=results_dataset_path,
    manager_config=manager_config,
    modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
)
# Passed to train_model(save_model_flag=...); False presumably skips saving the model.
save_model_flag = False

# Move the model and the graph data to the selected device.
gnn_model_manager.gnn.to(my_device)
data = data.to(my_device)
# NOTE(review): the return value is discarded here (unlike the line above), so
# this relies on `.to()` moving `gen_dataset.data` in place — confirm.
gen_dataset.data.to(my_device)
# Poisoning attack: "CLGAAttack" run for 300 epochs. Any parameters not given
# here are presumably taken from the defaults at POISON_ATTACK_PARAMETERS_PATH.
clga_attack_kwargs = dict(num_epochs=300)
poison_attack_config = ConfigPattern(
    _class_name="CLGAAttack",
    _import_path=POISON_ATTACK_PARAMETERS_PATH,
    _config_class="PoisonAttackConfig",
    _config_kwargs=clga_attack_kwargs,
)

# Poisoning defense: "JaccardDefender" with threshold 0.2.
# The available poison defense types and their default parameters are listed in
# ./metainfo/poison_defense_parameters.json
jaccard_defense_kwargs = dict(threshold=0.2)
jaccard_poison_defense_config = ConfigPattern(
    _class_name="JaccardDefender",
    _import_path=POISON_DEFENSE_PARAMETERS_PATH,
    _config_class="PoisonDefenseConfig",
    _config_kwargs=jaccard_defense_kwargs,
)

# Register the attacker first, then the defender, on the model manager.
gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
gnn_model_manager.set_poison_defender(poison_defense_config=jaccard_poison_defense_config)
# Train the model with the configured poisoning attacker and defender attached.
warnings.warn("Start training")
gen_dataset.train_test_split()

# Reset both epoch counters before training. The manager is reset first and
# then its modification config, the same left-to-right order as a chained assignment.
gnn_model_manager.epochs = 0
gnn_model_manager.modification.epochs = 0

# Per-class F1 on the training mask is tracked during training.
train_metrics = [Metric("F1", mask='train', average=None)]
train_test_split_path = gnn_model_manager.train_model(
    gen_dataset=gen_dataset,
    steps=steps_epochs,
    save_model_flag=save_model_flag,
    metrics=train_metrics,
)

# When train_model returns a path, save the split masks there, reload them,
# and assign them back onto the dataset (sizes go onto `data`).
if train_test_split_path is not None:
    gen_dataset.save_train_test_mask(train_test_split_path)
    split_file = train_test_split_path / 'train_test_split'
    loaded_split = torch.load(split_file)[:]
    train_mask, val_mask, test_mask, train_test_sizes = loaded_split
    gen_dataset.train_mask = train_mask
    gen_dataset.val_mask = val_mask
    gen_dataset.test_mask = test_mask
    data.percent_train_class, data.percent_test_class = train_test_sizes
warnings.warn("Training was successful")
# Evaluate on the test mask: macro-averaged F1 and accuracy.
test_metrics = [
    Metric("F1", mask='test', average='macro'),
    Metric("Accuracy", mask='test'),
]
metric_loc = gnn_model_manager.evaluate_model(gen_dataset=gen_dataset, metrics=test_metrics)
print(metric_loc)