Skip to content

Commit 846706d

Browse files
committed
pgd on features for aaai
1 parent 504f0e4 commit 846706d

File tree

1 file changed

+93
-24
lines changed

1 file changed

+93
-24
lines changed

gnn_aid/attacks/evasion_attacks.py

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .evasion_attacks_collection.nettack.utils import NettackSurrogate, NettackAttack
1212

1313
# PGD imports
14+
from torch_geometric.data import Batch
1415
from torch_geometric.data import Data
1516
from torch_geometric.utils import k_hop_subgraph
1617
from gnn_aid.models_builder.models_utils import EdgeMaskingWrapper
@@ -436,29 +437,7 @@ def attack(
436437
model.eval()
437438
attack_loss = self.get_attack_loss(model_manager)
438439

439-
if task_type:
440-
graph_idx = self.element_idx
441-
x = gen_dataset.dataset[graph_idx].x.clone()
442-
edge_index = gen_dataset.dataset[graph_idx].edge_index.clone()
443-
y = gen_dataset.dataset[graph_idx].y.clone()
444-
else:
445-
node_idx = self.element_idx
446-
x = gen_dataset.data.x.clone()
447-
edge_index = gen_dataset.data.edge_index.clone()
448-
y = gen_dataset.data.y.clone()
449-
450-
num_hops = model.n_layers
451-
subset, edge_index_subset, inv, edge_mask_k_hop = k_hop_subgraph(
452-
node_idx=node_idx,
453-
num_hops=num_hops,
454-
edge_index=edge_index,
455-
relabel_nodes=True,
456-
directed=False
457-
)
458-
node_idx_remap = torch.where(subset == node_idx)[0].item()
459-
x = x[subset]
460-
y = y[subset]
461-
440+
"""
462441
if self.is_feature_attack:
463442
orig_x = x.clone()
464443
x.requires_grad = True
@@ -496,8 +475,98 @@ def attack(
496475
for remap_idx, node_idx in enumerate(subset.detach().cpu()):
497476
for feature_idx in range(x.size(1)):
498477
self.attack_diff.change_node_feature(node_idx, feature_idx,
499-
x[remap_idx][feature_idx].detach().cpu().item())
478+
x[remap_idx][feature_idx].detach().cpu().item())"""
479+
480+
if self.is_feature_attack:
481+
device = gen_dataset.data.x.device
482+
483+
if task_type:
484+
graph_idxs = mask_tensor.nonzero(as_tuple=True)[0] # LongTensor индексов графов
485+
486+
selected_graphs = [gen_dataset.dataset[i].clone() for i in graph_idxs.tolist()]
487+
488+
batch = Batch.from_data_list(selected_graphs).to(device)
489+
490+
x = batch.x.clone()
491+
edge_index = batch.edge_index
492+
y = batch.y
493+
batch_vec = batch.batch
494+
495+
orig_x = x.clone()
496+
x.requires_grad_(True)
497+
498+
for _ in tqdm(range(self.num_iterations)):
499+
if x.grad is not None:
500+
x.grad.zero_()
501+
502+
out = model(x, edge_index, batch_vec)
503+
loss = attack_loss(out, y)
504+
loss.backward()
505+
506+
with torch.no_grad():
507+
x -= self.learning_rate * x.grad.sign()
508+
x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon))
509+
x.copy_(torch.clamp(x, -self.epsilon, self.epsilon))
510+
511+
ptr = batch.ptr
512+
attacked = []
513+
for j, gi in enumerate(graph_idxs.tolist()):
514+
start = int(ptr[j])
515+
end = int(ptr[j + 1])
516+
517+
attacked.append({
518+
"graph_idx": gi,
519+
"x": x[start:end].detach().clone(),
520+
"edge_index": selected_graphs[j].edge_index
521+
})
522+
self.attack_res = attacked
523+
524+
else:
525+
x = gen_dataset.data.x.clone()
526+
edge_index = gen_dataset.data.edge_index.clone()
527+
y = gen_dataset.data.y.clone()
528+
# node_idxs = mask_tensor.nonzero(as_tuple=True)[0].tolist()
529+
530+
orig_x = x.clone()
531+
x.requires_grad = True
532+
533+
for _ in tqdm(range(self.num_iterations)):
534+
if x.grad is not None:
535+
x.grad.zero_()
536+
out = model(x, edge_index)
537+
loss = attack_loss(out[mask_tensor], y[mask_tensor])
538+
loss.backward()
539+
with torch.no_grad():
540+
x -= self.learning_rate * x.grad.sign()
541+
x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon))
542+
x.copy_(torch.clamp(x, -self.epsilon, self.epsilon))
543+
544+
self.attack_res = Data(x=x, edge_index=edge_index, y=y)
545+
500546
else: # structure attack
547+
if task_type:
548+
graph_idx = self.element_idx
549+
x = gen_dataset.dataset[graph_idx].x.clone()
550+
edge_index = gen_dataset.dataset[graph_idx].edge_index.clone()
551+
y = gen_dataset.dataset[graph_idx].y.clone()
552+
else:
553+
node_idx = self.element_idx
554+
x = gen_dataset.data.x.clone()
555+
edge_index = gen_dataset.data.edge_index.clone()
556+
y = gen_dataset.data.y.clone()
557+
558+
num_hops = model.n_layers
559+
subset, edge_index_subset, inv, edge_mask_k_hop = k_hop_subgraph(
560+
node_idx=node_idx,
561+
num_hops=num_hops,
562+
edge_index=edge_index,
563+
relabel_nodes=True,
564+
directed=False
565+
)
566+
node_idx_remap = torch.where(subset == node_idx)[0].item()
567+
x = x[subset]
568+
y = y[subset]
569+
501570
if task_type:
502571
num_edges = edge_index.size(1)
503572
else:

0 commit comments

Comments
 (0)