Skip to content

Commit 6b0d987

Browse files
Merge pull request #98 from ispras/extend_models_stage3
Extend models stage3
2 parents 257450d + 59af08e commit 6b0d987

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2660
-2135
lines changed

experiments/attack_defense_exps.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import torch
44
from torch import device
55

6-
from data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
7-
from datasets.datasets_manager import DatasetManager
8-
from datasets.ptg_datasets import LibPTGDataset
9-
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
10-
from models_builder.models_zoo import model_configs_zoo
11-
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, \
6+
from gnn_aid.data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
7+
from gnn_aid.datasets.datasets_manager import DatasetManager
8+
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
9+
from gnn_aid.models_builder import Metric
10+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
11+
from gnn_aid.models_builder.models_zoo import model_configs_zoo
12+
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, \
1213
EVASION_ATTACK_PARAMETERS_PATH, \
1314
EVASION_DEFENSE_PARAMETERS_PATH
1415

@@ -758,8 +759,8 @@ def test_adv_training():
758759
# "num_nodes": dataset.dataset.x.shape[0]
759760
}
760761
)
761-
from defenses.evasion_defense import EvasionDefender
762-
from aux.utils import all_subclasses
762+
from gnn_aid.defenses.evasion_defense import EvasionDefender
763+
from gnn_aid.aux.utils import all_subclasses
763764
print([e.name for e in all_subclasses(EvasionDefender)])
764765
gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)
765766

@@ -1375,6 +1376,7 @@ def test_rewatt():
13751376
DatasetConfig(full_name),
13761377
LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.NODE_CLASSIFICATION})
13771378
)
1379+
data = dataset.data
13781380
data.to(my_device)
13791381

13801382
gcn_gcn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')

experiments/attack_defense_metric_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import torch
55
from torch import device
66

7-
from datasets.datasets_manager import DatasetManager
8-
from datasets.ptg_datasets import LibPTGDataset
9-
from models_builder.attack_defense_manager import FrameworkAttackDefenseManager
10-
from models_builder.attack_defense_metric import AttackMetric, DefenseMetric
11-
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
7+
from gnn_aid.datasets.datasets_manager import DatasetManager
8+
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
9+
from gnn_aid.models_builder.attack_defense_manager import FrameworkAttackDefenseManager
10+
from gnn_aid.models_builder.attack_defense_metric import AttackMetric, DefenseMetric
11+
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
1212
EVASION_DEFENSE_PARAMETERS_PATH
13-
from models_builder.gnn_models import FrameworkGNNModelManager
14-
from data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
15-
from models_builder.models_zoo import model_configs_zoo
13+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
14+
from gnn_aid.data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
15+
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1616

1717

1818
def attack_defense_metrics():

experiments/attack_defense_test.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@
44

55
from torch import device
66

7-
from datasets.ptg_datasets import LibPTGDataset
8-
from models_builder.models_utils import apply_decorator_to_graph_layers
9-
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
7+
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
8+
from gnn_aid.models_builder.models_utils import apply_decorator_to_graph_layers
9+
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
1010
EVASION_DEFENSE_PARAMETERS_PATH
11-
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
12-
from data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
13-
from datasets.datasets_manager import DatasetManager
14-
from models_builder.models_zoo import model_configs_zoo
15-
from attacks.qattack import qattack
11+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
12+
from gnn_aid.models_builder.models_utils import Metric
13+
from gnn_aid.data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
14+
from gnn_aid.datasets.datasets_manager import DatasetManager
15+
from gnn_aid.models_builder.models_zoo import model_configs_zoo
16+
from gnn_aid.attacks.qattack import qattack
1617
# from attacks.RL_S2V.rl_s2v import RLS2VAttacker
17-
from defenses.jaccard_defense import jaccard_def
18-
from attacks.metattack import meta_gradient_attack
19-
from defenses.gnn_guard import gnnguard
20-
from defenses.pro_gnn.prognn import ProGNNDefender
18+
from gnn_aid.defenses.jaccard_defense import jaccard_def
19+
from gnn_aid.attacks.metattack import meta_gradient_attack
20+
from gnn_aid.defenses.gnn_guard import gnnguard
21+
from gnn_aid.defenses.pro_gnn.prognn import ProGNNDefender
2122

2223

2324
def test_attack_defense_small():
@@ -206,7 +207,7 @@ def test_attack_defense_small():
206207

207208

208209
def test_attack_defense():
209-
from attacks.clga import CLGA
210+
from gnn_aid.attacks.clga import CLGA
210211

211212
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')
212213

@@ -980,8 +981,8 @@ def test_adv_training():
980981
# "num_nodes": dataset.dataset.x.shape[0]
981982
}
982983
)
983-
from defenses.evasion_defense import EvasionDefender
984-
from aux.utils import all_subclasses
984+
from gnn_aid.defenses.evasion_defense import EvasionDefender
985+
from gnn_aid.aux.utils import all_subclasses
985986
print([e.name for e in all_subclasses(EvasionDefender)])
986987
gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)
987988

experiments/backend_demo_ase.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
from torch import device
77

8-
from aux.utils import EVASION_ATTACK_PARAMETERS_PATH, EVASION_DEFENSE_PARAMETERS_PATH
9-
from datasets.ptg_datasets import LibPTGDataset
10-
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
11-
from data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
12-
from datasets.datasets_manager import DatasetManager
13-
from models_builder.models_zoo import model_configs_zoo
8+
from gnn_aid.aux.utils import EVASION_ATTACK_PARAMETERS_PATH, EVASION_DEFENSE_PARAMETERS_PATH
9+
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
10+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
11+
from gnn_aid.models_builder.models_utils import Metric
12+
from gnn_aid.data_structures.configs import ModelModificationConfig, ConfigPattern, DatasetConfig, Task
13+
from gnn_aid.datasets.datasets_manager import DatasetManager
14+
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1415

1516

1617
def test_attack_defense_small():

experiments/user_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import tensor
33
from torch_geometric.data import InMemoryDataset, Data, Dataset
44

5-
from datasets.datasets_manager import DatasetManager
5+
from gnn_aid.datasets.datasets_manager import DatasetManager
66

77

88
# Example of local user PTG dataset

experiments/various_tasks.py

Lines changed: 152 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import json
2+
13
import torch
24
from torch import device
35

46
from gnn_aid.attacks import Attacker
57
from gnn_aid.aux.utils import FUNCTIONS_PARAMETERS_PATH, all_subclasses
8+
from gnn_aid.data_structures import ModelStructureConfig, ModelConfig
69
from gnn_aid.data_structures.configs import DatasetConfig, DatasetVarConfig, FeatureConfig, Task, \
710
ConfigPattern, ModelModificationConfig
811
from gnn_aid.datasets.datasets_manager import DatasetManager
912
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
10-
from gnn_aid.models_builder.gnn_models import FrameworkGNNModelManager, Metric
13+
from gnn_aid.models_builder import FrameworkGNNConstructor
14+
from gnn_aid.models_builder.models_utils import Metric
15+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
1116
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1217

1318

@@ -103,9 +108,32 @@ def link_prediction():
103108

104109
gen_dataset = DatasetManager.get_by_config(dc, dvc)
105110
print(gen_dataset.data)
106-
gen_dataset.train_test_split()
111+
gen_dataset.train_test_split(percent_train_class=0.85, percent_test_class=0.1)
112+
113+
gnn = FrameworkGNNConstructor(
114+
model_config=ModelConfig(
115+
structure=ModelStructureConfig(
116+
[
117+
{
118+
'label': 'n',
119+
'layer': {
120+
'layer_name': 'SAGEConv',
121+
'layer_kwargs': {
122+
'in_channels': gen_dataset.num_node_features,
123+
'out_channels': 16,
124+
},
125+
},
126+
},
127+
{
128+
'label': 'd',
129+
'function': {
130+
'function_name': 'CosineSimilarity',
131+
'function_kwargs': None
132+
}
133+
}
134+
]
135+
)))
107136

108-
gnn = model_configs_zoo(dataset=gen_dataset, model_name='gcn_gcn')
109137
manager_config = ConfigPattern(
110138
_config_class="ModelManagerConfig",
111139
_config_kwargs={
@@ -125,7 +153,7 @@ def link_prediction():
125153
}
126154
)
127155

128-
steps_epochs = 30
156+
steps_epochs = 3
129157
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')
130158
gnn_model_manager = FrameworkGNNModelManager(
131159
gnn=gnn,
@@ -145,13 +173,130 @@ def link_prediction():
145173
)
146174
print("Training was successful")
147175

176+
res = gnn_model_manager.run_model(
177+
gen_dataset=gen_dataset,
178+
mask='all'
179+
)
180+
res = gnn_model_manager.evaluate_model(
181+
gen_dataset=gen_dataset,
182+
metrics=[Metric("Accuracy", mask="test")]
183+
)
184+
print(json.dumps(res, indent=2))
185+
186+
# gnn.eval()
187+
# z = gnn.encode(data.x, data.edge_index)
188+
# out = gnn.decode(z, data.edge_label_index).view(-1).sigmoid()
189+
# return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
190+
191+
192+
def ptg_example():
193+
import os.path as osp
194+
195+
import torch
196+
from sklearn.metrics import roc_auc_score
197+
198+
import torch_geometric.transforms as T
199+
from torch_geometric.datasets import Planetoid
200+
from torch_geometric.nn import GCNConv
201+
from torch_geometric.utils import negative_sampling
202+
203+
if torch.cuda.is_available():
204+
device = torch.device('cuda')
205+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
206+
device = torch.device('mps')
207+
else:
208+
device = torch.device('cpu')
209+
210+
transform = T.Compose([
211+
T.NormalizeFeatures(),
212+
T.ToDevice(device),
213+
T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
214+
add_negative_train_samples=False),
215+
])
216+
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
217+
dataset = Planetoid(path, name='Cora', transform=transform)
218+
# After applying the `RandomLinkSplit` transform, the data is transformed from
219+
# a data object to a list of tuples (train_data, val_data, test_data), with
220+
# each element representing the corresponding split.
221+
train_data, val_data, test_data = dataset[0]
222+
223+
class Net(torch.nn.Module):
224+
def __init__(self, in_channels, hidden_channels, out_channels):
225+
super().__init__()
226+
self.conv1 = GCNConv(in_channels, hidden_channels)
227+
self.conv2 = GCNConv(hidden_channels, out_channels)
228+
229+
def encode(self, x, edge_index):
230+
x = self.conv1(x, edge_index).relu()
231+
return self.conv2(x, edge_index)
232+
233+
def decode(self, z, edge_label_index):
234+
return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
235+
236+
def decode_all(self, z):
237+
prob_adj = z @ z.t()
238+
return (prob_adj > 0).nonzero(as_tuple=False).t()
239+
240+
model = Net(dataset.num_features, 128, 64).to(device)
241+
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
242+
criterion = torch.nn.BCEWithLogitsLoss()
243+
244+
def train():
245+
model.train()
246+
optimizer.zero_grad()
247+
z = model.encode(train_data.x, train_data.edge_index)
248+
249+
# We perform a new round of negative sampling for every training epoch:
250+
neg_edge_index = negative_sampling(
251+
edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
252+
num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
253+
254+
edge_label_index = torch.cat(
255+
[train_data.edge_label_index, neg_edge_index],
256+
dim=-1,
257+
)
258+
edge_label = torch.cat([
259+
train_data.edge_label,
260+
train_data.edge_label.new_zeros(neg_edge_index.size(1))
261+
], dim=0)
262+
263+
out = model.decode(z, edge_label_index).view(-1)
264+
loss = criterion(out, edge_label)
265+
loss.backward()
266+
optimizer.step()
267+
return loss
268+
269+
@torch.no_grad()
270+
def test(data):
271+
model.eval()
272+
z = model.encode(data.x, data.edge_index)
273+
out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
274+
return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
275+
276+
best_val_auc = final_test_auc = 0
277+
for epoch in range(1, 101):
278+
loss = train()
279+
val_auc = test(val_data)
280+
test_auc = test(test_data)
281+
if val_auc > best_val_auc:
282+
best_val_auc = val_auc
283+
final_test_auc = test_auc
284+
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
285+
f'Test: {test_auc:.4f}')
286+
287+
print(f'Final Test: {final_test_auc:.4f}')
288+
289+
z = model.encode(test_data.x, test_data.edge_index)
290+
final_edge_index = model.decode_all(z)
291+
148292

149293
if __name__ == '__main__':
150294
# node_regression()
151295
# graph_regression()
152296

153297
# edge_regression()
154-
# link_prediction()
298+
link_prediction()
155299

156-
for c in all_subclasses(Attacker):
157-
print(c)
300+
# ptg_example()
301+
# for c in all_subclasses(Attacker):
302+
# print(c)

gnn_aid/attacks/attack_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from gnn_aid.datasets.gen_dataset import GeneralDataset
22
from gnn_aid.data_structures.graph_modification_artifacts import GraphModificationArtifact
3-
from gnn_aid.models_builder.gnn_models import GNNModelManager
3+
from gnn_aid.models_builder.model_managers import GNNModelManager
44

55

66
class Attacker:

gnn_aid/attacks/evasion_attacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from gnn_aid.attacks.attack_base import Attacker
66
from gnn_aid.aux.utils import move_to_same_device
77
from gnn_aid.datasets.gen_dataset import GeneralDataset
8-
from gnn_aid.models_builder.gnn_models import GNNModelManager
8+
from gnn_aid.models_builder.model_managers import GNNModelManager
99

1010
# Nettack imports
1111
from .evasion_attacks_collection.nettack.utils import NettackSurrogate, NettackAttack

gnn_aid/attacks/metattack/meta_gradient_attack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from gnn_aid.aux.utils import OPTIMIZERS_PARAMETERS_PATH, move_to_same_device
1414
from gnn_aid.data_structures.configs import ModelModificationConfig, ConfigPattern
1515
from gnn_aid.datasets.gen_dataset import GeneralDataset
16-
from gnn_aid.models_builder.gnn_models import FrameworkGNNModelManager, GNNModelManager
16+
from gnn_aid.models_builder.model_managers import GNNModelManager, FrameworkGNNModelManager
1717
from gnn_aid.models_builder.models_zoo import model_configs_zoo
1818
from . import utils
1919

gnn_aid/data_structures/configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def __init__(
616616
@property
617617
def task(
618618
self
619-
) -> Union[str, dict]:
619+
) -> Task:
620620
return self["task"]
621621

622622
@property

0 commit comments

Comments
 (0)