Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions data/example/example/raw/edge_attributes/weight
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"11,13": 0.2,
"11,15": 0.3,
"12,13": 0.4,
"12,14": 0.3,
"12,17": 0.5,
"15,14": 0.6,
"15,16": 0.7
"14,15": 0.6,
"15,16": 0.7,
"16,17": 0.3
}

2 changes: 2 additions & 0 deletions data/example/example/raw/edges.ij
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
11 13
11 15
12 13
12 14
12 17
15 14
15 16
16 17
8 changes: 5 additions & 3 deletions data/example/example/raw/labels/edge-regression/regression
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"11,13": 0.13,
"11,15": 0.1,
"12,13": 0.3,
"12,14": 0.39,
"12,17": 0.22,
"15,14": 0.35,
"15,16": 0.4
}
"14,15": 0.35,
"15,16": 0.4,
"16,17": 0.3
}
3 changes: 3 additions & 0 deletions docs/source/api/aux.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ aux
.. automodule:: aux.declaration
:members:

.. automodule:: aux.prefix_storage
:members:

5 changes: 0 additions & 5 deletions docs/source/api/data_structures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ data_structures
:toctree: generated


Prefix storage
==============
.. automodule:: data_structures.prefix_storage
:members:

Configs
=======
.. automodule:: data_structures.configs
Expand Down
4 changes: 0 additions & 4 deletions docs/source/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ Short overview
datasets.ptg_datasets.PTGDataset
datasets.known_format_datasets.KnownFormatDataset
datasets.datasets_manager.DatasetManager
datasets.visible_part.DatasetData
datasets.visible_part.DatasetVarData


Base classes
Expand Down Expand Up @@ -47,5 +45,3 @@ Additional dataset related modules
:members:
.. automodule:: datasets.dataset_stats
:members:
.. automodule:: datasets.visible_part
:members:
2 changes: 0 additions & 2 deletions docs/source/api/models_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ Base classes
:members:
.. automodule:: models_builder.gnn_constructor
:members:
.. automodule:: models_builder.gnn_models
:members:
.. automodule:: models_builder.models_utils
:members:
.. automodule:: models_builder.models_zoo
Expand Down
39 changes: 22 additions & 17 deletions experiments/various_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,20 @@ def link_prediction():
manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"batch": 64,
"batch": 128,
"mask_features": [],
"optimizer": {
"_class_name": "Adam",
"_config_kwargs": {},
},
"loss_function": {
"_config_class": "Config",
"_class_name": "CrossEntropyLoss",
"_class_name": "BCEWithLogitsLoss",
"_import_path": FUNCTIONS_PARAMETERS_PATH,
"_class_import_info": ["torch.nn"],
"_config_kwargs": {},
},
"neg_samples_ratio": 2
}
)

Expand All @@ -208,10 +209,14 @@ def link_prediction():
)
print("Training was successful")

# res = gnn_model_manager.run_model(
# gen_dataset=gen_dataset,
# mask='all'
# )
res = gnn_model_manager.run_model(
gen_dataset=gen_dataset,
mask='all',
out='predictions'
)
print(json.dumps(res.tolist(), indent=2))
return

from gnn_aid.aux.utils import EVASION_ATTACK_PARAMETERS_PATH
evasion_attack_config = ConfigPattern(
_class_name="FGSM",
Expand All @@ -227,17 +232,17 @@ def link_prediction():
# атака
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
#
# res = gnn_model_manager.evaluate_model(
# gen_dataset=gen_dataset,
# metrics=[
# Metric("Accuracy", mask='all'),
# Metric("AUC", mask='test'),
# # Metric("Precision@k", mask='test', k=50),
# # Metric("Precision@k", mask='test', k=500000),
# # Metric("Recall@k", mask='test', k=500000),
# ]
# )
# print(json.dumps(res, indent=2))
res = gnn_model_manager.evaluate_model(
gen_dataset=gen_dataset,
metrics=[
Metric("Accuracy", mask='all'),
Metric("AUC", mask='test'),
# Metric("Precision@k", mask='test', k=50),
# Metric("Precision@k", mask='test', k=500000),
# Metric("Recall@k", mask='test', k=500000),
]
)
print(json.dumps(res, indent=2))

# explainer
from gnn_aid.aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH
Expand Down
1 change: 0 additions & 1 deletion gnn_aid/attacks/evasion_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,4 +884,3 @@ def dataset_diff(
self
) -> GraphModificationArtifact:
return self.attack_diff

23 changes: 16 additions & 7 deletions gnn_aid/attacks/mi_attacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -33,7 +33,7 @@ def compute_single_attack_accuracy(
inferred_labels: torch.Tensor,
mask_true: torch.Tensor,
train_class_label: bool = True
) -> float:
) -> Union[float, Dict]:
"""
Computes accuracy for a single attack result (mask + inferred labels pair).

Expand All @@ -43,7 +43,7 @@ def compute_single_attack_accuracy(
mask_true: Tensor of true labels for all nodes in the graph

Returns:
float: Accuracy (0.0 to 1.0) of correct predictions among attacked samples
float: Dict with metrics for predictions among attacked samples.
Returns 0.0 if no samples were attacked
"""
metrics = {
Expand Down Expand Up @@ -86,6 +86,7 @@ def compute_single_attack_accuracy(

return metrics


class EmptyMIAttacker(
MIAttacker
):
Expand Down Expand Up @@ -119,6 +120,8 @@ def attack(
mask_tensor: Union[str, List[bool], torch.Tensor],
):
assert not isinstance(mask_tensor, str), "Input of original mask seems senseless"
if isinstance(mask_tensor, list):
mask_tensor = torch.tensor(mask_tensor)

model.eval()

Expand All @@ -131,9 +134,13 @@ def attack(

return self.results

class ShadowModelMIAttacker(MIAttacker):

class ShadowModelMIAttacker(
MIAttacker
):
"""
Shadow model-based membership inference attack for Node/Graph Classification.
The surrogate model is trained on a part of the initial dataset.
The classifier learns from its responses to determine whether the input is from train or test
"""
name = "ShadowModelMIAttacker"

Expand Down Expand Up @@ -255,12 +262,14 @@ def attack(
self,
model: torch.nn.Module,
gen_dataset: GeneralDataset,
mask_tensor: Union[torch.Tensor, list],
mask_tensor: Union[List[bool], torch.Tensor],
**kwargs
):
"""
Perform membership inference attack using shadow model technique.
"""
if isinstance(mask_tensor, list):
mask_tensor = torch.tensor(mask_tensor)
task_type = gen_dataset.is_multi()
self.model_name = 'gcn_gcn_lin_no_softmax' if task_type else 'gcn_gcn_no_softmax'

Expand Down Expand Up @@ -494,4 +503,4 @@ def attack(

self.results.add(mask_tensor, inferred_membership_full)
members_count = inferred_membership_full.sum().item()
return self.results
return self.results
2 changes: 2 additions & 0 deletions gnn_aid/aux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .declaration import Declare
from .prefix_storage import TuplePrefixStorage, FixedKeysPrefixStorage
from .data_info import DataInfo, UserCodeInfo
from .utils import ProgressBar

__all__ = [
'timing_decorator',
Expand All @@ -12,4 +13,5 @@
'FixedKeysPrefixStorage',
'DataInfo',
'UserCodeInfo',
'ProgressBar',
]
8 changes: 3 additions & 5 deletions gnn_aid/aux/data_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,9 @@ def values_list_and_technical_files_by_path_and_prefix(
for file_info_dict in val["files_info"]:
if file_info_dict["file_name"] == "origin":
file_name = path[parts_parse].strip()
else:
file_name = file_info_dict["file_name"]
file_name += file_info_dict["format"]
description_info.update(
{key: {parts_val[-1]: os.path.join(*path[:parts_parse], file_name)}})
file_name += file_info_dict["format"]
description_info.update(
{key: {parts_val[-1]: os.path.join(*path[:parts_parse], file_name)}})
parts_parse += 1
return parts_val, description_info

Expand Down
56 changes: 28 additions & 28 deletions gnn_aid/aux/prefix_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,34 +542,34 @@ def remap(
Change keys order and combination.
"""
raise NotImplementedError
# Check consistency
ms = set()
for m in mapping:
if isinstance(m, (list, tuple)):
ms.update(m)
else:
ms.add(m)
assert ms == set(range(len(self.keys))),\
f"mapping should contain key indices from 0 to {len(self.keys)-1}"

keys = [",".join(self.keys[i] for i in m) if isinstance(m, (list, tuple)) else self.keys[m]
for m in mapping]
ps = FixedKeysPrefixStorage(keys)

for item in self:
values = []
for m in mapping:
if isinstance(m, (list, tuple)):
if only_values:
v = ",".join(str(item[i]) for i in m)
else:
v = ",".join(f"{self.keys[i]}={item[i]}" for i in m)
else:
v = item[m]
values.append(v)
ps.add(values)

return ps
# # Check consistency
# ms = set()
# for m in mapping:
# if isinstance(m, (list, tuple)):
# ms.update(m)
# else:
# ms.add(m)
# assert ms == set(range(len(self.keys))),\
# f"mapping should contain key indices from 0 to {len(self.keys)-1}"
#
# keys = [",".join(self.keys[i] for i in m) if isinstance(m, (list, tuple)) else self.keys[m]
# for m in mapping]
# ps = FixedKeysPrefixStorage(keys)
#
# for item in self:
# values = []
# for m in mapping:
# if isinstance(m, (list, tuple)):
# if only_values:
# v = ",".join(str(item[i]) for i in m)
# else:
# v = ",".join(f"{self.keys[i]}={item[i]}" for i in m)
# else:
# v = item[m]
# values.append(v)
# ps.add(values)
#
# return ps

def __str__(
self
Expand Down
Loading