Skip to content

Commit e7173a4

Browse files
committed
Fix ShadowModel MI Attack
1 parent 47a3df3 commit e7173a4

File tree

2 files changed

+206 additions
-72 deletions

gnn_aid/attacks/mi_attacks.py

Lines changed: 116 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import copy
2-
from typing import Union, List
2+
from typing import Union, List, Tuple
33

4+
import numpy as np
45
import torch
56
from sklearn.metrics import accuracy_score
67
from sklearn.svm import SVC
@@ -128,41 +129,62 @@ def attack(
128129

129130
return self.results
130131

131-
132-
class ShadowModelMIAttacker(
133-
MIAttacker
134-
):
132+
class ShadowModelMIAttacker(MIAttacker):
135133
"""
136-
The surrogate model is trained on a part of the initial dataset.
137-
The classifier learns from its responses to determine whether the input is from train or test
134+
Shadow model-based membership inference attack for Node/Graph Classification.
138135
"""
139136
name = "ShadowModelMIAttacker"
140137

141138
def __init__(
142139
self,
143-
shadow_data_ratio: float = 0.25, # Fraction of data to use for shadow training
144-
shadow_epochs: int = 100, # Number of epochs to train shadow model
145-
classifier_type: str = 'svc', # Type of classifier to use ('svc' or 'mlp')
140+
shadow_data_ratio: float = 0.25,
141+
shadow_train_ratio: float = 0.75,
142+
shadow_epochs: int = 100,
143+
classifier_type: str = 'svc', # 'svc' only for now
144+
use_logits: bool = True, # Use logits (recommended) or softmax probs
146145
**kwargs
147146
):
148147
super().__init__(**kwargs)
149148
self.shadow_data_ratio = shadow_data_ratio
149+
self.shadow_train_ratio = shadow_train_ratio
150150
self.shadow_epochs = shadow_epochs
151151
self.classifier_type = classifier_type
152+
self.use_logits = use_logits
152153
self.classifier = None
154+
self.model_name = None
155+
156+
def _prepare_shadow_masks(
157+
self,
158+
num_nodes: int
159+
) -> Tuple[torch.Tensor, torch.Tensor]:
160+
"""
161+
Create shadow train/test masks over a random subset of nodes.
162+
"""
163+
all_indices = torch.arange(num_nodes)
164+
shadow_size = int(num_nodes * self.shadow_data_ratio)
165+
shadow_indices = all_indices[torch.randperm(num_nodes)[:shadow_size]]
166+
167+
n_train = int(shadow_size * self.shadow_train_ratio)
168+
shadow_train_indices = shadow_indices[:n_train]
169+
shadow_test_indices = shadow_indices[n_train:]
170+
171+
shadow_train_mask = torch.zeros(num_nodes, dtype=torch.bool)
172+
shadow_test_mask = torch.zeros(num_nodes, dtype=torch.bool)
173+
shadow_train_mask[shadow_train_indices] = True
174+
shadow_test_mask[shadow_test_indices] = True
153175

154-
# TODO customizable surrogate model
155-
# TODO customizable classifier
176+
return shadow_train_mask, shadow_test_mask
156177

157178
def _train_shadow_model(
158179
self,
159-
gen_dataset: GeneralDataset,
180+
shadow_dataset: GeneralDataset,
160181
shadow_train_mask: torch.Tensor
161-
):
182+
) -> torch.nn.Module:
162183
"""
163-
Train the shadow model on the shadow dataset
184+
Train shadow model on shadow training nodes
164185
"""
165-
shadow_model = model_configs_zoo(dataset=gen_dataset, model_name=self.model_name)
186+
shadow_model = model_configs_zoo(dataset=shadow_dataset, model_name=self.model_name)
187+
shadow_model = shadow_model.to(shadow_dataset.data.x.device)
166188

167189
optimizer = torch.optim.Adam(shadow_model.parameters(), lr=0.01)
168190
criterion = torch.nn.CrossEntropyLoss()
@@ -171,93 +193,115 @@ def _train_shadow_model(
171193
shadow_model.train()
172194
optimizer.zero_grad()
173195

174-
# Forward pass
175-
outputs = shadow_model(gen_dataset.data.x, gen_dataset.data.edge_index)
176-
177-
# Compute loss only on shadow training nodes
178-
loss = criterion(*move_to_same_device(outputs[shadow_train_mask], gen_dataset.data.y[shadow_train_mask]))
179-
print(f"Shadow loss: {loss}")
180-
181-
# Backward pass
196+
outputs = shadow_model(shadow_dataset.data.x, shadow_dataset.data.edge_index)
197+
loss = criterion(
198+
*move_to_same_device(
199+
outputs[shadow_train_mask],
200+
shadow_dataset.data.y[shadow_train_mask]
201+
)
202+
)
182203
loss.backward()
183204
optimizer.step()
205+
206+
if epoch % 20 == 0:
207+
print(f" Shadow epoch {epoch}/{self.shadow_epochs}, loss: {loss.item():.4f}")
208+
184209
return shadow_model
185210

211+
def _extract_features(
212+
self,
213+
model: torch.nn.Module,
214+
dataset: GeneralDataset,
215+
node_mask: torch.Tensor
216+
) -> np.ndarray:
217+
"""Extract features (logits or probabilities) from model outputs."""
218+
model.eval()
219+
with torch.no_grad():
220+
outputs = model(dataset.data.x, dataset.data.edge_index)
221+
if self.use_logits:
222+
features = outputs[node_mask]
223+
else:
224+
features = torch.softmax(outputs[node_mask], dim=1)
225+
return features.cpu().numpy()
226+
186227
def _train_attack_classifier(
187228
self,
188229
shadow_model: torch.nn.Module,
189-
shadow_data: GeneralDataset,
190-
shadow_train_mask: torch.tensor,
191-
original_train_mask: torch.Tensor
230+
shadow_dataset: GeneralDataset,
231+
shadow_train_mask: torch.Tensor,
232+
shadow_test_mask: torch.Tensor
192233
):
193-
"""
194-
Train the attack classifier using shadow model outputs
195-
"""
196-
shadow_model.eval()
197-
with torch.no_grad():
198-
outputs = shadow_model(shadow_data.data.x, shadow_data.data.edge_index)
199-
probs = torch.softmax(outputs, dim=1)
200-
# max_probs = torch.max(probs, dim=1).values.cpu().numpy()
201-
# Prepare features and labels for attack classifier
202-
# X = max_probs.reshape(-1, 1) # Using prediction confidence as feature
203-
X = probs[shadow_train_mask].detach().cpu().numpy()
204-
y = original_train_mask[shadow_train_mask].detach().cpu().numpy().astype(int) # Membership labels
234+
"""Train attack classifier using shadow model outputs."""
235+
# Features for shadow train nodes → label 1 (member)
236+
X_train = self._extract_features(shadow_model, shadow_dataset, shadow_train_mask)
237+
y_train = np.ones(X_train.shape[0])
205238

239+
# Features for shadow test nodes → label 0 (non-member)
240+
X_test = self._extract_features(shadow_model, shadow_dataset, shadow_test_mask)
241+
y_test = np.zeros(X_test.shape[0])
242+
243+
# Combine
244+
X = np.vstack([X_train, X_test])
245+
y = np.concatenate([y_train, y_test])
246+
247+
# Train classifier
206248
if self.classifier_type == 'svc':
207-
self.classifier = SVC(kernel='rbf', probability=False)
249+
self.classifier = SVC(kernel='rbf', probability=True)
208250
else:
209-
raise ValueError(f"Unsupported classifier type: {self.classifier_type}")
251+
raise ValueError(f"Unsupported classifier: {self.classifier_type}")
210252

211253
self.classifier.fit(X, y)
212254

213-
# Evaluate on shadow data (for debugging)
255+
# Debug: evaluate on shadow data
214256
y_pred = self.classifier.predict(X)
215-
shadow_accuracy = accuracy_score(y, y_pred)
216-
print(f"Shadow model attack classifier accuracy: {shadow_accuracy:.4f}")
257+
acc = accuracy_score(y, y_pred)
258+
print(f" ✓ Shadow attack classifier accuracy: {acc:.4f}")
217259

218260
def attack(
219261
self,
220262
model: torch.nn.Module,
221263
gen_dataset: GeneralDataset,
222-
mask_tensor: Union[List[bool], torch.Tensor],
264+
mask_tensor: Union[torch.Tensor, list],
223265
**kwargs
224266
):
267+
"""
268+
Perform membership inference attack using shadow model technique.
269+
"""
225270
task_type = gen_dataset.is_multi()
226-
if task_type:
227-
self.model_name = 'gcn_gcn_linear'
228-
else:
229-
self.model_name = 'gcn_gcn'
230-
231-
shadow_dataset = copy.deepcopy(gen_dataset)
271+
self.model_name = 'gcn_gcn_lin_no_softmax' if task_type else 'gcn_gcn_no_softmax'
232272

233273
num_nodes = gen_dataset.data.x.shape[0]
234274
shadow_indices = torch.randperm(num_nodes)[:int(num_nodes * self.shadow_data_ratio)]
235-
shadow_train_mask = torch.zeros_like(gen_dataset.train_mask, dtype=torch.bool)
236-
shadow_test_mask = torch.zeros_like(gen_dataset.train_mask, dtype=torch.bool)
237-
shadow_train_mask[shadow_indices[:int(len(shadow_indices) * 0.75)]] = True # 75-25 split
238-
shadow_test_mask[shadow_indices[int(len(shadow_indices) * 0.75):]] = True
275+
n_shadow = len(shadow_indices)
276+
n_train = int(n_shadow * 0.75)
277+
278+
shadow_train_indices = shadow_indices[:n_train]
279+
shadow_test_indices = shadow_indices[n_train:]
280+
281+
shadow_train_mask = torch.zeros(num_nodes, dtype=torch.bool)
282+
shadow_test_mask = torch.zeros(num_nodes, dtype=torch.bool)
283+
shadow_train_mask[shadow_train_indices] = True
284+
shadow_test_mask[shadow_test_indices] = True
285+
286+
shadow_dataset = copy.deepcopy(gen_dataset)
239287
shadow_dataset.train_mask = shadow_train_mask
240288
shadow_dataset.test_mask = shadow_test_mask
241289

242-
print("Training shadow model...")
243290
shadow_model = self._train_shadow_model(shadow_dataset, shadow_train_mask)
244291

245-
print("Training attack classifier...")
246-
self._train_attack_classifier(shadow_model, shadow_dataset, shadow_train_mask, gen_dataset.train_mask)
292+
self._train_attack_classifier(shadow_model, shadow_dataset, shadow_train_mask, shadow_test_mask)
247293

248-
print("Performing attack on target model...")
249294
model.eval()
250295
with torch.no_grad():
251-
outputs = shadow_model(gen_dataset.data.x, gen_dataset.data.edge_index)
252-
probs = torch.softmax(outputs, dim=1)
253-
max_probs = torch.max(probs, dim=1).values.detach().cpu().numpy()
254-
# Predict membership using attack classifier
255-
# X_target = max_probs.reshape(-1, 1)
256-
X_target = probs
257-
inferred_train_mask = torch.tensor(self.classifier.predict(X_target),
258-
dtype=torch.bool, device=mask_tensor.device)
259-
260-
# Store results
261-
self.results.add(mask_tensor, inferred_train_mask)
262-
263-
return self.results
296+
outputs = model(gen_dataset.data.x, gen_dataset.data.edge_index)
297+
if self.use_logits:
298+
features = outputs
299+
else:
300+
features = torch.softmax(outputs, dim=1)
301+
features = features.cpu().numpy()
302+
303+
all_predictions = self.classifier.predict(features)
304+
inferred_membership_full = torch.tensor(all_predictions, dtype=torch.bool)
305+
306+
self.results.add(mask_tensor, inferred_membership_full)
307+
return self.results

gnn_aid/models_builder/models_zoo.py

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -620,6 +620,40 @@ def model_configs_zoo(
620620
)
621621
)
622622

623+
gcn_gcn_no_softmax = FrameworkGNNConstructor(
624+
model_config=ModelConfig(
625+
structure=ModelStructureConfig(
626+
[
627+
{
628+
'label': 'n',
629+
'layer': {
630+
'layer_name': 'GCNConv',
631+
'layer_kwargs': {
632+
'in_channels': dataset.num_node_features,
633+
'out_channels': 16,
634+
},
635+
},
636+
'activation': {
637+
'activation_name': 'ReLU',
638+
'activation_kwargs': None,
639+
},
640+
},
641+
642+
{
643+
'label': 'n',
644+
'layer': {
645+
'layer_name': 'GCNConv',
646+
'layer_kwargs': {
647+
'in_channels': 16,
648+
'out_channels': dataset.num_classes,
649+
},
650+
},
651+
},
652+
]
653+
)
654+
)
655+
)
656+
623657
gcn_gcn_gcn = FrameworkGNNConstructor(
624658
model_config=ModelConfig(
625659
structure=ModelStructureConfig(
@@ -873,7 +907,63 @@ def model_configs_zoo(
873907
'activation_kwargs': None,
874908
},
875909
},
910+
]
911+
)
912+
)
913+
)
914+
915+
gcn_gcn_lin_no_softmax = FrameworkGNNConstructor(
916+
model_config=ModelConfig(
917+
structure=ModelStructureConfig(
918+
[
919+
{
920+
'label': 'n',
921+
'layer': {
922+
'layer_name': 'GCNConv',
923+
'layer_kwargs': {
924+
'in_channels': dataset.num_node_features,
925+
'out_channels': 16,
926+
},
927+
},
928+
'activation': {
929+
'activation_name': 'ReLU',
930+
'activation_kwargs': None,
931+
},
932+
'connections': [
933+
{
934+
'into_layer': 2,
935+
'connection_kwargs': {
936+
'aggregation_type': 'cat',
937+
},
938+
},
939+
],
940+
},
941+
942+
{
943+
'label': 'n',
944+
'layer': {
945+
'layer_name': 'GCNConv',
946+
'layer_kwargs': {
947+
'in_channels': 16,
948+
'out_channels': 16,
949+
},
950+
},
951+
'activation': {
952+
'activation_name': 'ReLU',
953+
'activation_kwargs': None,
954+
},
955+
},
876956

957+
{
958+
'label': 'n',
959+
'layer': {
960+
'layer_name': 'Linear',
961+
'layer_kwargs': {
962+
'in_features': 16 * 2,
963+
'out_features': dataset.num_classes,
964+
},
965+
},
966+
},
877967
]
878968
)
879969
)

0 commit comments

Comments
 (0)