Skip to content

Commit de86fc5

Browse files
Merge pull request #101 from ispras/sazonov_final
Sazonov final
2 parents 534e04e + c4eb98b commit de86fc5

File tree

9 files changed

+1244
-150
lines changed

9 files changed

+1244
-150
lines changed

gnn_aid/attacks/mi_attacks.py

Lines changed: 306 additions & 72 deletions
Large diffs are not rendered by default.
Lines changed: 102 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,134 @@
11
import torch
22
import numpy as np
3+
from typing import Optional
34

45
from gnn_aid.datasets.gen_dataset import GeneralDataset
56
from gnn_aid.data_structures.graph_modification_artifacts import GraphModificationArtifact
67
from gnn_aid.defenses.poison_defense import PoisonDefender
8+
from gnn_aid.data_structures.configs import Task
79

810

9-
class JaccardDefender(
10-
PoisonDefender
11-
):
11+
def _is_binary_tensor(X: torch.Tensor) -> bool:
12+
return torch.all((X == 0) | (X == 1)).item()
13+
14+
15+
class JaccardDefender(PoisonDefender):
    """
    Poison defense based on removing edges between dissimilar nodes.

    The similarity of an edge is the Jaccard index of the binary feature vectors of its
    two endpoints. Edges whose score does not exceed ``threshold`` are dropped.
    """
    name = 'JaccardDefender'

    def __init__(self, threshold: float, binarize_threshold: Optional[float] = None):
        """
        :param threshold: Jaccard similarity threshold (edges with similarity <= threshold are removed)
        :param binarize_threshold: Optional threshold to binarize non-binary features
        """
        super().__init__()
        self.threshold = threshold
        self.binarize_threshold = binarize_threshold
        # Edges dropped by the most recent defense() call (one column per edge),
        # or None when nothing was dropped / defense() has not run yet.
        self.removed_edges_train = None
        # Number of columns in data.edge_index before the most recent defense() call.
        self.original_num_edges = None

    def defense(
            self,
            gen_dataset: GeneralDataset,
            **kwargs
    ) -> GeneralDataset:
        """
        Modify input graph by removing edges between dissimilar nodes.

        :param gen_dataset: input graph dataset; its ``data.edge_index`` is overwritten
        :param kwargs: accepted for interface compatibility, not used here
        :return: the same dataset object with the filtered ``data.edge_index``
        :raises RuntimeError: for link tasks when ``train_mask`` is missing or None
        :raises ValueError: when features are missing, or are non-binary and no
            ``binarize_threshold`` was given
        """
        task = gen_dataset.dataset_var_config.task
        is_link_task = task in [Task.EDGE_PREDICTION, Task.EDGE_REGRESSION]

        if is_link_task and getattr(gen_dataset, 'train_mask', None) is None:
            raise RuntimeError("JaccardDefender for link tasks requires train_test_split() to be called first")

        self.original_num_edges = gen_dataset.data.edge_index.size(1)

        x = self._prepare_features(gen_dataset.data.x)

        if is_link_task:
            return self._defense_link_task(gen_dataset, x)
        return self._defense_standard_task(gen_dataset, x)

    def _prepare_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return a binary feature matrix suitable for Jaccard scoring.

        :param x: feature matrix of the dataset
        :return: ``x`` thresholded to 0/1 floats when ``binarize_threshold`` is set,
            otherwise ``x`` itself (after checking that it is already binary)
        :raises ValueError: when ``x`` is None, or is non-binary and no
            ``binarize_threshold`` was given
        """
        if x is None:
            # Without this guard the comparison below fails with an opaque TypeError.
            raise ValueError("JaccardDefender requires node features, but the dataset has none")
        if self.binarize_threshold is not None:
            return (x > self.binarize_threshold).float()
        if not _is_binary_tensor(x):
            raise ValueError(
                "JaccardDefender requires binary features"
            )
        return x

    def _defense_link_task(
            self,
            gen_dataset: GeneralDataset,
            x: torch.Tensor
    ) -> GeneralDataset:
        """
        Filter the training label edges and use the survivors as ``data.edge_index``.

        :param gen_dataset: dataset providing ``edge_label_index`` and ``train_mask``
        :param x: binary feature matrix
        :return: the same dataset object with ``data.edge_index`` replaced
        """
        # NOTE(review): if edge_label_index also contains negative samples, those that
        # pass the filter end up in data.edge_index as well - confirm against the split code.
        train_edge_label_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask]

        filtered_train_edges, removed_edges = self._filter_edges_jaccard(train_edge_label_index, x)
        self.removed_edges_train = removed_edges

        gen_dataset.data.edge_index = filtered_train_edges

        num_removed = removed_edges.size(1) if removed_edges is not None else 0
        print(f"JaccardDefender: Removed {num_removed}/{train_edge_label_index.size(1)} "
              f"training edges (threshold={self.threshold})")

        return gen_dataset

    def _defense_standard_task(
            self,
            gen_dataset: GeneralDataset,
            x: torch.Tensor
    ) -> GeneralDataset:
        """
        Filter ``data.edge_index`` itself (every task that is not a link task).

        :param gen_dataset: dataset whose ``data.edge_index`` is filtered
        :param x: binary feature matrix
        :return: the same dataset object with ``data.edge_index`` replaced
        """
        filtered_edges, removed_edges = self._filter_edges_jaccard(
            gen_dataset.data.edge_index, x
        )
        self.removed_edges_train = removed_edges  # Reusing field for simplicity

        gen_dataset.data.edge_index = filtered_edges

        num_removed = removed_edges.size(1) if removed_edges is not None else 0
        print(f"JaccardDefender: Removed {num_removed}/{self.original_num_edges} edges "
              f"(threshold={self.threshold})")

        return gen_dataset

    def _filter_edges_jaccard(
            self,
            edge_index: torch.Tensor,
            x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Split edges into kept and removed ones by the Jaccard index of their endpoints.

        :param edge_index: edges to score; row 0 holds source ids, row 1 target ids
        :param x: binary feature matrix indexed by node id
        :return: ``(kept_edges, removed_edges)``; ``removed_edges`` is None when no edge
            is removed (including the case of an empty ``edge_index``)
        """
        if edge_index.size(1) == 0:
            return edge_index, None

        # Index on the feature device so a device mismatch between x and edge_index
        # does not break the lookup; this is a no-op when both already share a device.
        endpoints = edge_index.to(x.device)
        src_feats = x[endpoints[0]]
        dst_feats = x[endpoints[1]]

        intersection = (src_feats * dst_feats).sum(dim=1)  # AND
        union = ((src_feats + dst_feats) > 0).sum(dim=1).float()  # OR

        # Two all-zero feature vectors have an empty union; dividing by 1 instead
        # gives them a score of 0 rather than NaN.
        union = torch.where(union == 0, torch.ones_like(union), union)

        jaccard_scores = intersection / union

        # Bring the mask back to the device of the tensor it is going to index.
        keep_mask = (jaccard_scores > self.threshold).to(edge_index.device)
        drop_mask = ~keep_mask

        filtered_edges = edge_index[:, keep_mask]
        removed_edges = edge_index[:, drop_mask] if drop_mask.any() else None

        return filtered_edges, removed_edges

    def dataset_diff(self) -> GraphModificationArtifact:
        """
        Build the artifact describing the edges removed by the last ``defense()`` call.

        :return: artifact with the removed edges registered; an empty artifact when
            nothing was removed or ``defense()`` has not been called
        """
        diff = GraphModificationArtifact()

        removed = self.removed_edges_train
        if removed is not None and removed.size(1) > 0:
            # Transpose so that each list entry is one [src, dst] pair.
            diff.remove_edges(removed.t().tolist())

        self.defense_diff = diff
        return self.defense_diff

gnn_aid/defenses/mi_defense.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,87 @@ def post_batch(
147147
"outputs": modified_logits,
148148
"loss": modified_loss
149149
}
150+
151+
152+
class NoiseMILinkDefender(MIDefender):
    """
    MI defense for Link Prediction tasks via edge logit perturbation.

    Depending on ``noise_type`` the edge logits are perturbed with a reverse-sigmoid
    term, with Gaussian noise, or left untouched.
    """
    name = "NoiseMILinkDefender"

    def __init__(
            self,
            noise_type: Literal["reverse_sigmoid", "random", "none"] = "reverse_sigmoid",
            beta: float = 0.3,
            gamma: float = 0.8,
            noise_scale: float = 0.2,
            temperature: float = 1.0,
            **kwargs
    ):
        """
        :param noise_type: perturbation to apply: "reverse_sigmoid", "random" or "none"
        :param beta: magnitude of the reverse-sigmoid perturbation term
        :param gamma: factor applied to the logits inside the reverse-sigmoid term
        :param noise_scale: standard deviation multiplier of the Gaussian noise
        :param temperature: divisor applied to the logits before the sigmoid; must be positive
        :param kwargs: forwarded to the parent constructor
        :raises ValueError: on an unknown ``noise_type`` or a non-positive ``temperature``
        """
        # Validate before any state is built so a bad config fails fast.
        if noise_type not in ["reverse_sigmoid", "random", "none"]:
            raise ValueError(f"Invalid noise_type: {noise_type}")
        if temperature <= 0:
            # temperature is used as a divisor; zero would silently produce inf/NaN logits.
            raise ValueError(f"temperature must be positive, got {temperature}")

        super().__init__(**kwargs)
        self.noise_type = noise_type
        self.beta = beta
        self.gamma = gamma
        self.noise_scale = noise_scale
        self.temperature = temperature

    def _apply_reverse_sigmoid_binary(
            self,
            edge_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Reverse sigmoid perturbation for binary classification (link prediction).

        :param edge_logits: raw edge logits
        :return: perturbed logits, same shape as ``edge_logits``
        """
        probs = torch.sigmoid(edge_logits / self.temperature)

        # NOTE(review): the scaled logits below are not divided by temperature, unlike
        # `probs`; the two only coincide for temperature == 1. Confirm this is intended.
        perturbed_temp_logits = self.gamma * edge_logits
        perturbed_temp_probs = torch.sigmoid(perturbed_temp_logits)
        # Perturbation is zero at probability 0.5 and bounded by +/- beta / 2.
        r = self.beta * (perturbed_temp_probs - 0.5)

        perturbed_probs = probs - r
        # Keep probabilities strictly inside (0, 1) so the logit stays finite.
        perturbed_probs = torch.clamp(perturbed_probs, min=1e-7, max=1.0 - 1e-7)

        perturbed_logits = torch.logit(perturbed_probs, eps=1e-7) * self.temperature
        return perturbed_logits

    def _apply_random_noise(
            self,
            edge_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Add Gaussian noise to edge logits.

        :param edge_logits: raw edge logits
        :return: ``edge_logits`` plus standard normal noise scaled by ``noise_scale``
        """
        noise = torch.randn_like(edge_logits) * self.noise_scale
        return edge_logits + noise

    def post_batch(
            self,
            model_manager: Any,
            batch: Any,
            **kwargs
    ) -> dict:
        """
        Recompute edge logits for the batch, perturb them and recompute the loss.

        :param model_manager: object exposing ``gnn`` (the model) and ``loss_function``
        :param batch: batch with ``x``, ``edge_index``, ``edge_label_index`` and ``edge_label``
        :param kwargs: accepted for interface compatibility, not used here
        :return: dict with "outputs" (perturbed logits), "loss" (loss on the perturbed
            logits) and "original_logits" (detached unperturbed logits)
        """
        node_emb = model_manager.gnn(batch.x, batch.edge_index)
        src_emb = node_emb[batch.edge_label_index[0]]
        dst_emb = node_emb[batch.edge_label_index[1]]

        # Use the model's own decoder when it has one, otherwise a dot product.
        if hasattr(model_manager.gnn, 'decode'):
            edge_logits = model_manager.gnn.decode(src_emb, dst_emb).squeeze(-1)
        else:
            edge_logits = (src_emb * dst_emb).sum(dim=-1)

        if self.noise_type == "reverse_sigmoid":
            modified_logits = self._apply_reverse_sigmoid_binary(edge_logits)
        elif self.noise_type == "random":
            modified_logits = self._apply_random_noise(edge_logits)
        else:
            modified_logits = edge_logits

        edge_labels = batch.edge_label.float()
        modified_loss = model_manager.loss_function(modified_logits, edge_labels)

        return {
            "outputs": modified_logits,
            "loss": modified_loss,
            "original_logits": edge_logits.detach()
        }

0 commit comments

Comments
 (0)