Skip to content

Commit 878efdf

Browse files
committed
Add LP-task for JaccardDefence
1 parent 534e04e commit 878efdf

File tree

3 files changed

+237
-74
lines changed

3 files changed

+237
-74
lines changed
Lines changed: 102 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,134 @@
11
import torch
22
import numpy as np
3+
from typing import Optional
34

45
from gnn_aid.datasets.gen_dataset import GeneralDataset
56
from gnn_aid.data_structures.graph_modification_artifacts import GraphModificationArtifact
67
from gnn_aid.defenses.poison_defense import PoisonDefender
8+
from gnn_aid.data_structures.configs import Task
79

810

9-
class JaccardDefender(
10-
PoisonDefender
11-
):
11+
def _is_binary_tensor(X: torch.Tensor) -> bool:
12+
return torch.all((X == 0) | (X == 1)).item()
13+
14+
15+
class JaccardDefender(PoisonDefender):
    """
    Poison defense based on removing edges between dissimilar nodes.

    Edges whose Jaccard similarity over binary node features is <= ``threshold``
    are dropped. For link-level tasks (edge prediction / edge regression) only
    the *training* supervision edges are filtered; for all other tasks the full
    ``edge_index`` is filtered.
    """
    name = 'JaccardDefender'

    def __init__(self, threshold: float, binarize_threshold: Optional[float] = None):
        """
        :param threshold: Jaccard similarity threshold (edges with similarity <= threshold are removed)
        :param binarize_threshold: Optional threshold to binarize non-binary features
        """
        super().__init__()
        self.threshold = threshold
        self.binarize_threshold = binarize_threshold
        # Edges removed by the last defense() call (2 x K tensor), or None.
        self.removed_edges_train = None
        # Edge count before filtering, kept for the report printed by defense().
        self.original_num_edges = None

    def defense(
            self,
            gen_dataset: GeneralDataset,
            **kwargs
    ) -> GeneralDataset:
        """
        Modify the input graph by removing edges between dissimilar nodes.

        :param gen_dataset: input graph dataset
        :return: modified dataset (only adjacency / edge indices changed)
        :raises RuntimeError: for link tasks when train_test_split() was not called
        :raises ValueError: when features are non-binary and no binarize_threshold is set
        """
        task = gen_dataset.dataset_var_config.task
        # Evaluate the task-kind membership test once (it was previously
        # computed twice with an identical list literal).
        is_link_task = task in (Task.EDGE_PREDICTION, Task.EDGE_REGRESSION)

        if is_link_task:
            if not hasattr(gen_dataset, 'train_mask') or gen_dataset.train_mask is None:
                raise RuntimeError("JaccardDefender for link tasks requires train_test_split() to be called first")

        self.original_num_edges = gen_dataset.data.edge_index.size(1)

        x = self._prepare_features(gen_dataset.data.x)

        if is_link_task:
            gen_dataset = self._defense_link_task(gen_dataset, x)
        else:
            gen_dataset = self._defense_standard_task(gen_dataset, x)

        return gen_dataset

    def _prepare_features(self, x: torch.Tensor) -> torch.Tensor:
        """Binarize features when requested; otherwise require binary input."""
        if self.binarize_threshold is not None:
            x = (x > self.binarize_threshold).float()
        elif not _is_binary_tensor(x):
            raise ValueError(
                "JaccardDefender requires binary features"
            )
        return x

    def _defense_link_task(
            self,
            gen_dataset: GeneralDataset,
            x: torch.Tensor
    ) -> GeneralDataset:
        """Filter only the training supervision edges of a link-level task."""
        train_edge_label_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask]

        filtered_train_edges, removed_edges = self._filter_edges_jaccard(train_edge_label_index, x)
        self.removed_edges_train = removed_edges

        # NOTE(review): this replaces the message-passing graph with the
        # filtered *training label* edges — confirm this is intended rather
        # than filtering gen_dataset.data.edge_index itself.
        gen_dataset.data.edge_index = filtered_train_edges

        num_removed = removed_edges.size(1) if removed_edges is not None else 0
        print(f"JaccardDefender: Removed {num_removed}/{train_edge_label_index.size(1)} "
              f"training edges (threshold={self.threshold})")

        return gen_dataset

    def _defense_standard_task(
            self,
            gen_dataset: GeneralDataset,
            x: torch.Tensor
    ) -> GeneralDataset:
        """Filter the full adjacency (node/graph-level tasks)."""
        filtered_edges, removed_edges = self._filter_edges_jaccard(
            gen_dataset.data.edge_index, x
        )
        self.removed_edges_train = removed_edges  # Reusing field for simplicity

        gen_dataset.data.edge_index = filtered_edges

        num_removed = removed_edges.size(1) if removed_edges is not None else 0
        print(f"JaccardDefender: Removed {num_removed}/{self.original_num_edges} edges "
              f"(threshold={self.threshold})")

        return gen_dataset

    def _filter_edges_jaccard(
            self,
            edge_index: torch.Tensor,
            x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Vectorized Jaccard filtering of an edge list.

        :param edge_index: 2 x E tensor of (src, dst) node indices
        :param x: binary node feature matrix
        :return: (kept edges, removed edges or None when nothing was removed)
        """
        if edge_index.size(1) == 0:
            return edge_index, None

        src_feats = x[edge_index[0]]
        dst_feats = x[edge_index[1]]

        intersection = (src_feats * dst_feats).sum(dim=1)  # AND
        union = ((src_feats + dst_feats) > 0).sum(dim=1).float()  # OR

        # Guard against 0/0 for pairs of all-zero feature rows: their score
        # becomes 0, so such edges are removed for any threshold >= 0.
        union = torch.where(union == 0, torch.ones_like(union), union)

        jaccard_scores = intersection / union

        keep_mask = jaccard_scores > self.threshold
        filtered_edges = edge_index[:, keep_mask]
        removed_edges = edge_index[:, ~keep_mask] if (~keep_mask).any() else None

        return filtered_edges, removed_edges

    def dataset_diff(self) -> GraphModificationArtifact:
        """Build the GraphModificationArtifact describing the removed edges."""
        diff = GraphModificationArtifact()

        if self.removed_edges_train is not None and self.removed_edges_train.size(1) > 0:
            diff.remove_edges(self.removed_edges_train.t().tolist())
        # Both branches previously assigned the same object; once suffices
        # (the artifact is empty when no edges were removed).
        self.defense_diff = diff

        return self.defense_diff

metainfo/functions_parameters.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,8 @@
2525
},
2626
"NLLLoss": {
2727
"reduction": ["Reduction", "string", "mean", ["none","mean","sum"], "Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction."]
28-
}
28+
},
29+
"BCEWithLogitsLoss": {
30+
"reduction": ["Reduction", "string", "mean", ["none","mean","sum"], "Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction."]
31+
}
2932
}

tests/defense_test.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
from gnn_aid.attacks.mi_attacks import MIAttacker
66
from gnn_aid.aux.utils import POISON_DEFENSE_PARAMETERS_PATH, \
7-
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, MI_DEFENSE_PARAMETERS_PATH
7+
OPTIMIZERS_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, MI_DEFENSE_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH
88
from gnn_aid.data_structures.configs import ModelModificationConfig, DatasetConfig, DatasetVarConfig, \
9-
ConfigPattern, FeatureConfig, Task
9+
ConfigPattern, FeatureConfig, Task, ModelConfig, ModelStructureConfig
1010
from gnn_aid.datasets.datasets_manager import DatasetManager
1111
from gnn_aid.datasets.ptg_datasets import LibPTGDataset
12+
from gnn_aid.models_builder import FrameworkGNNConstructor
1213
from gnn_aid.models_builder.models_utils import Metric
1314
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
1415
from gnn_aid.models_builder.models_zoo import model_configs_zoo
@@ -36,6 +37,13 @@ def setUp(self):
3637

3738
self.gen_dataset_sg_cora.train_test_split(percent_train_class=0.6, percent_test_class=0.4)
3839

40+
# Single-graph - Cora (Link Prediction)
41+
self.gen_dataset_lp_cora = DatasetManager.get_by_config(
42+
DatasetConfig((LibPTGDataset.data_folder, "Homogeneous", "Planetoid", "Cora")),
43+
LibPTGDataset.default_dataset_var_config.clone_with({"task": Task.EDGE_PREDICTION})
44+
)
45+
self.gen_dataset_lp_cora.train_test_split(percent_train_class=0.85, percent_test_class=0.15)
46+
3947
self.default_config = ModelModificationConfig(
4048
model_ver_ind=0,
4149
)
@@ -53,6 +61,27 @@ def setUp(self):
5361
}
5462
}
5563
)
64+
65+
self.manager_config_lp = ConfigPattern(
66+
_config_class="ModelManagerConfig",
67+
_config_kwargs={
68+
"mask_features": [],
69+
"optimizer": {
70+
"_config_class": "Config",
71+
"_class_name": "Adam",
72+
"_import_path": OPTIMIZERS_PARAMETERS_PATH,
73+
"_class_import_info": ["torch.optim"],
74+
"_config_kwargs": {"weight_decay": 5e-4},
75+
},
76+
"loss_function": {
77+
"_config_class": "Config",
78+
"_class_name": "BCEWithLogitsLoss",
79+
"_import_path": FUNCTIONS_PARAMETERS_PATH,
80+
"_class_import_info": ["torch.nn"],
81+
"_config_kwargs": {},
82+
},
83+
}
84+
)
5685
monkey_patch_dirs()
5786

5887
def tearDown(self):
@@ -136,6 +165,106 @@ def test_noise_mi_defender_cora(self):
136165
print(f"MI Attack accuracy:"
137166
f" {MIAttacker.compute_single_attack_accuracy(mask, res, self.gen_dataset_sg_cora.train_mask)}")
138167

168+
    def test_jaccard_defender_link_prediction(self):
        """
        Test JaccardDefender on Link Prediction task (Cora dataset).

        Builds a 2-layer GCN encoder with a CosineSimilarity decoder,
        attaches JaccardDefender as a poison defense, trains for 30 steps,
        then checks that a sane fraction of training edges was removed and
        that test AUC exceeds the 0.5 random baseline.
        """
        poison_defense_config = ConfigPattern(
            _class_name="JaccardDefender",
            _import_path=POISON_DEFENSE_PARAMETERS_PATH,
            _config_class="PoisonDefenseConfig",
            _config_kwargs={
                # Low similarity cutoff: only clearly dissimilar pairs are dropped.
                "threshold": 0.03,
            }
        )

        gnn = FrameworkGNNConstructor(
            model_config=ModelConfig(
                structure=ModelStructureConfig([
                    # Encoder: 2-layer GCN
                    {
                        'label': 'n',
                        'layer': {
                            'layer_name': 'GCNConv',
                            'layer_kwargs': {
                                'in_channels': self.gen_dataset_lp_cora.num_node_features,
                                'out_channels': 32,
                            },
                        },
                        'activation': {
                            'activation_name': 'ReLU',
                            'activation_kwargs': None,
                        },
                    },
                    {
                        'label': 'n',
                        'layer': {
                            'layer_name': 'GCNConv',
                            'layer_kwargs': {
                                'in_channels': 32,
                                'out_channels': 16,
                            },
                        },
                    },
                    # Decoder: score node pairs by embedding cosine similarity.
                    {
                        'label': 'd',
                        'function': {
                            'function_name': 'CosineSimilarity',
                            'function_kwargs': None
                        }
                    }
                ])
            )
        )

        gnn_model_manager = FrameworkGNNModelManager(
            gnn=gnn,
            dataset_path=self.gen_dataset_lp_cora.prepared_dir,
            modification=self.default_config,
            manager_config=self.manager_config_lp,
        )

        gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)

        # Count training supervision edges before the defense runs.
        original_train_edges = self.gen_dataset_lp_cora.edge_label_index[:,
                               self.gen_dataset_lp_cora.train_mask].size(1)

        gnn_model_manager.train_model(
            gen_dataset=self.gen_dataset_lp_cora,
            steps=30,
            save_model_flag=False,
            metrics=[Metric("AUC", mask='train')]
        )

        defense = gnn_model_manager.poison_defender

        removed_edges = defense.defense_diff.edges["remove"]
        num_removed = len(removed_edges)

        if num_removed > 0:
            print(f"JaccardDefender removed {num_removed} training edges "
                  f"({num_removed / original_train_edges * 100:.1f}%)")
            # Sanity checks
            self.assertGreater(num_removed, 0, "No edges were removed - threshold may be too low")
            self.assertLess(num_removed, original_train_edges * 0.5,
                            "Too many edges removed (>50%) - threshold may be too high")
        else:
            print("WARNING: No edges removed (threshold may be too low for this graph)")

        test_metrics = gnn_model_manager.evaluate_model(
            gen_dataset=self.gen_dataset_lp_cora,
            metrics=[
                Metric("AUC", mask='test'),
                Metric("Recall@k", mask='test', k=50),
                Metric("Recall@k", mask='test', k=100),
            ]
        )
        print("Link Prediction test metrics:", test_metrics)

        self.assertGreater(test_metrics['test']['AUC'], 0.5, "AUC should be >0.5 (random baseline) after training")

        self.assertGreater(test_metrics['test']['Recall@k{k=50}'], 0.0, "Recall@50 should be >0 after training")
267+
139268

140269
if __name__ == '__main__':
141270
unittest.main()

0 commit comments

Comments
 (0)