Skip to content

Commit d398e98

Browse files
committed
+
1 parent 75a0525 commit d398e98

File tree

6 files changed

+250
-190
lines changed

6 files changed

+250
-190
lines changed

experiments/attack_defense_test.py

Lines changed: 36 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
1010
EVASION_DEFENSE_PARAMETERS_PATH
1111
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
12-
from data_structures.configs import ModelModificationConfig, ConfigPattern
12+
from data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig
1313
from datasets.datasets_manager import DatasetManager
1414
from models_builder.models_zoo import model_configs_zoo
1515
from attacks.qattack import qattack
@@ -1054,14 +1054,16 @@ def test_adv_training():
10541054

10551055
def test_pgd():
10561056
# ______________________ Attack on node ______________________
1057-
my_device = device('cuda')
1057+
my_device = device('cpu')
10581058

10591059
# Load dataset
1060-
full_name = (LibPTGDataset.data_folder, "single-graph", "Planetoid", "Cora")
1061-
dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
1062-
full_name=full_name,
1063-
dataset_ver_ind=0
1060+
full_name = (LibPTGDataset.data_folder, "Homogeneous", "Planetoid", "Cora")
1061+
dataset = DatasetManager.get_by_config(
1062+
DatasetConfig(full_name),
1063+
# LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.NODE_CLASSIFICATION})
10641064
)
1065+
dataset.train_test_split(percent_train_class=0.6, percent_test_class=0.4)
1066+
data = dataset.data
10651067
data.to(my_device)
10661068

10671069
gcn_gcn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
@@ -1079,14 +1081,14 @@ def test_pgd():
10791081

10801082
gnn_model_manager = FrameworkGNNModelManager(
10811083
gnn=gcn_gcn,
1082-
dataset_path=results_dataset_path,
1084+
dataset_path=dataset.prepared_dir,
10831085
manager_config=manager_config,
10841086
modification=ModelModificationConfig(model_ver_ind=0, epochs=0)
10851087
)
10861088

10871089
gnn_model_manager.gnn.to(my_device)
10881090

1089-
num_steps = 200
1091+
num_steps = 120
10901092
gnn_model_manager.train_model(gen_dataset=dataset,
10911093
steps=num_steps,
10921094
save_model_flag=False)
@@ -1099,18 +1101,18 @@ def test_pgd():
10991101
node_idx = 650
11001102

11011103
# Model prediction on a node before PGD attack on it
1102-
# gnn_model_manager.gnn.eval()
1103-
# with torch.no_grad():
1104-
# probabilities = torch.exp(gnn_model_manager.gnn(dataset.data.x, dataset.data.edge_index))
1105-
#
1106-
# predicted_class = probabilities[node_idx].argmax().item()
1107-
# predicted_probability = probabilities[node_idx][predicted_class].item()
1108-
# real_class = dataset.data.y[node_idx].item()
1109-
#
1110-
# info_before_pgd_attack_on_node = {"node_idx": node_idx,
1111-
# "predicted_class": predicted_class,
1112-
# "predicted_probability": predicted_probability,
1113-
# "real_class": real_class}
1104+
gnn_model_manager.gnn.eval()
1105+
with torch.no_grad():
1106+
probabilities = torch.exp(gnn_model_manager.gnn(dataset.data.x, dataset.data.edge_index))
1107+
1108+
predicted_class = probabilities[node_idx].argmax().item()
1109+
predicted_probability = probabilities[node_idx][predicted_class].item()
1110+
real_class = dataset.data.y[node_idx].item()
1111+
1112+
info_before_pgd_attack_on_node = {"node_idx": node_idx,
1113+
"predicted_class": predicted_class,
1114+
"predicted_probability": predicted_probability,
1115+
"real_class": real_class}
11141116

11151117
# Attack config
11161118
evasion_attack_config = ConfigPattern(
@@ -1134,119 +1136,20 @@ def test_pgd():
11341136
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
11351137

11361138
# Model prediction on a node after PGD attack on it
1137-
# with torch.no_grad():
1138-
# probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.data.x,
1139-
# gnn_model_manager.evasion_attacker.attack_diff.data.edge_index))
1140-
#
1141-
# predicted_class = probabilities[node_idx].argmax().item()
1142-
# predicted_probability = probabilities[node_idx][predicted_class].item()
1143-
# real_class = dataset.data.y[node_idx].item()
1144-
#
1145-
# info_after_pgd_attack_on_node = {"node_idx": node_idx,
1146-
# "predicted_class": predicted_class,
1147-
# "predicted_probability": predicted_probability,
1148-
# "real_class": real_class}
1149-
# ____________________________________________________________
1150-
1151-
# ______________________ Attack on graph _____________________
1152-
# Load dataset
1153-
full_name = (LibPTGDataset.data_folder, "Homogeneous", "TUDataset", "MUTAG")
1154-
dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
1155-
full_name=full_name,
1156-
dataset_ver_ind=0
1157-
)
1158-
data.to(my_device)
1159-
1160-
model = model_configs_zoo(dataset=dataset, model_name='gin_gin_gin_lin_lin_con')
1161-
1162-
manager_config = ConfigPattern(
1163-
_config_class="ModelManagerConfig",
1164-
_config_kwargs={
1165-
"mask_features": [],
1166-
"optimizer": {
1167-
"_class_name": "Adam",
1168-
"_config_kwargs": {},
1169-
}
1170-
}
1171-
)
1172-
1173-
gnn_model_manager = FrameworkGNNModelManager(
1174-
gnn=model,
1175-
dataset_path=results_dataset_path,
1176-
manager_config=manager_config,
1177-
modification=ModelModificationConfig(model_ver_ind=0, epochs=0)
1178-
)
1179-
1180-
gnn_model_manager.gnn.to(my_device)
1181-
1182-
num_steps = 200
1183-
gnn_model_manager.train_model(gen_dataset=dataset,
1184-
steps=num_steps,
1185-
save_model_flag=False)
1186-
1187-
acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
1188-
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
1189-
print(f"Accuracy on test: {acc_test}")
1190-
1191-
# Graph for attack
1192-
graph_idx = 0
1193-
1194-
# Model prediction on a graph before PGD attack on it
1195-
# gnn_model_manager.gnn.eval()
1196-
# with torch.no_grad():
1197-
# probabilities = torch.exp(gnn_model_manager.gnn(dataset.dataset[graph_idx].x,
1198-
# dataset.dataset[graph_idx].edge_index))
1199-
#
1200-
# predicted_class = probabilities.argmax().item()
1201-
# predicted_probability = probabilities[0][predicted_class].item()
1202-
# real_class = dataset.dataset[graph_idx].y.item()
1203-
#
1204-
# info_before_pgd_attack_on_graph = {"graph_idx": graph_idx,
1205-
# "predicted_class": predicted_class,
1206-
# "predicted_probability": predicted_probability,
1207-
# "real_class": real_class}
1208-
1209-
# Attack config
1210-
evasion_attack_config = ConfigPattern(
1211-
_class_name="PGD",
1212-
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
1213-
_config_class="EvasionAttackConfig",
1214-
_config_kwargs={
1215-
"is_feature_attack": True,
1216-
"element_idx": graph_idx,
1217-
"epsilon": 0.1,
1218-
"learning_rate": 0.001,
1219-
"num_iterations": 500,
1220-
"random_sampling_num_trials": 100
1221-
}
1222-
)
1223-
1224-
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
1139+
with torch.no_grad():
1140+
probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_res.data.x,
1141+
gnn_model_manager.evasion_attacker.attack_res.data.edge_index))
12251142

1226-
# Attack
1227-
_ = gnn_model_manager.evaluate_model(gen_dataset=dataset,
1228-
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
1143+
predicted_class = probabilities[node_idx].argmax().item()
1144+
predicted_probability = probabilities[node_idx][predicted_class].item()
1145+
real_class = dataset.data.y[node_idx].item()
12291146

1230-
# Model prediction on a graph after PGD attack on it
1231-
# with torch.no_grad():
1232-
# probabilities = torch.exp(
1233-
# gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x,
1234-
# gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index))
1235-
#
1236-
# predicted_class = probabilities.argmax().item()
1237-
# predicted_probability = probabilities[0][predicted_class].item()
1238-
# real_class = dataset.dataset[graph_idx].y.item()
1239-
#
1240-
# info_after_pgd_attack_on_graph = {"graph_idx": graph_idx,
1241-
# "predicted_class": predicted_class,
1242-
# "predicted_probability": predicted_probability,
1243-
# "real_class": real_class}
1244-
#
1245-
# # ____________________________________________________________
1246-
# print(f"Before PGD attack on node (Cora dataset): {info_before_pgd_attack_on_node}")
1247-
# print(f"After PGD attack on node (Cora dataset): {info_after_pgd_attack_on_node}")
1248-
# print(f"Before PGD attack on graph (MUTAG dataset): {info_before_pgd_attack_on_graph}")
1249-
# print(f"After PGD attack on graph (MUTAG dataset): {info_after_pgd_attack_on_graph}")
1147+
info_after_pgd_attack_on_node = {"node_idx": node_idx,
1148+
"predicted_class": predicted_class,
1149+
"predicted_probability": predicted_probability,
1150+
"real_class": real_class}
1151+
print(f"Before PGD attack on node (Cora dataset): {info_before_pgd_attack_on_node}")
1152+
print(f"After PGD attack on node (Cora dataset): {info_after_pgd_attack_on_node}")
12501153

12511154

12521155
def test_pgd_structure():
@@ -1827,12 +1730,12 @@ def test_rewatt():
18271730

18281731
random.seed(10)
18291732
# test_attack_defense_small()
1830-
test_attack_defense()
1733+
# test_attack_defense()
18311734
# test_nettack_evasion()
18321735
# torch.manual_seed(5000)
18331736
# test_gnnguard()
18341737
# test_jaccard()
1835-
# test_pgd()
1738+
test_pgd()
18361739
# test_fgsm()
18371740
# test_pgd_structure()
18381741
# test_rewatt()

0 commit comments

Comments
 (0)