|
11 | 11 | from .evasion_attacks_collection.nettack.utils import NettackSurrogate, NettackAttack |
12 | 12 |
|
13 | 13 | # PGD imports |
| 14 | +from torch_geometric.data import Batch |
14 | 15 | from torch_geometric.data import Data |
15 | 16 | from torch_geometric.utils import k_hop_subgraph |
16 | 17 | from gnn_aid.models_builder.models_utils import EdgeMaskingWrapper |
@@ -436,29 +437,7 @@ def attack( |
436 | 437 | model.eval() |
437 | 438 | attack_loss = self.get_attack_loss(model_manager) |
438 | 439 |
|
439 | | - if task_type: |
440 | | - graph_idx = self.element_idx |
441 | | - x = gen_dataset.dataset[graph_idx].x.clone() |
442 | | - edge_index = gen_dataset.dataset[graph_idx].edge_index.clone() |
443 | | - y = gen_dataset.dataset[graph_idx].y.clone() |
444 | | - else: |
445 | | - node_idx = self.element_idx |
446 | | - x = gen_dataset.data.x.clone() |
447 | | - edge_index = gen_dataset.data.edge_index.clone() |
448 | | - y = gen_dataset.data.y.clone() |
449 | | - |
450 | | - num_hops = model.n_layers |
451 | | - subset, edge_index_subset, inv, edge_mask_k_hop = k_hop_subgraph( |
452 | | - node_idx=node_idx, |
453 | | - num_hops=num_hops, |
454 | | - edge_index=edge_index, |
455 | | - relabel_nodes=True, |
456 | | - directed=False |
457 | | - ) |
458 | | - node_idx_remap = torch.where(subset == node_idx)[0].item() |
459 | | - x = x[subset] |
460 | | - y = y[subset] |
461 | | - |
| 440 | + """ |
462 | 441 | if self.is_feature_attack: |
463 | 442 | orig_x = x.clone() |
464 | 443 | x.requires_grad = True |
@@ -496,8 +475,98 @@ def attack( |
496 | 475 | for remap_idx, node_idx in enumerate(subset.detach().cpu()): |
497 | 476 | for feature_idx in range(x.size(1)): |
498 | 477 | self.attack_diff.change_node_feature(node_idx, feature_idx, |
499 | | - x[remap_idx][feature_idx].detach().cpu().item()) |
| 478 | + x[remap_idx][feature_idx].detach().cpu().item())""" |
| 479 | + |
| 480 | + if self.is_feature_attack: |
| 481 | + device = gen_dataset.data.x.device |
| 482 | + |
| 483 | + if task_type: |
| 484 | + graph_idxs = mask_tensor.nonzero(as_tuple=True)[0] # LongTensor индексов графов |
| 485 | + |
| 486 | + selected_graphs = [gen_dataset.dataset[i].clone() for i in graph_idxs.tolist()] |
| 487 | + |
| 488 | + batch = Batch.from_data_list(selected_graphs).to(device) |
| 489 | + |
| 490 | + x = batch.x.clone() |
| 491 | + edge_index = batch.edge_index |
| 492 | + y = batch.y |
| 493 | + batch_vec = batch.batch |
| 494 | + |
| 495 | + orig_x = x.clone() |
| 496 | + x.requires_grad_(True) |
| 497 | + |
| 498 | + for _ in tqdm(range(self.num_iterations)): |
| 499 | + if x.grad is not None: |
| 500 | + x.grad.zero_() |
| 501 | + |
| 502 | + out = model(x, edge_index, batch_vec) |
| 503 | + loss = attack_loss(out, y) |
| 504 | + loss.backward() |
| 505 | + |
| 506 | + with torch.no_grad(): |
| 507 | + x -= self.learning_rate * x.grad.sign() |
| 508 | + x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon)) |
| 509 | + x.copy_(torch.clamp(x, -self.epsilon, self.epsilon)) |
| 510 | + |
| 511 | + ptr = batch.ptr |
| 512 | + attacked = [] |
| 513 | + for j, gi in enumerate(graph_idxs.tolist()): |
| 514 | + start = int(ptr[j]) |
| 515 | + end = int(ptr[j + 1]) |
| 516 | + |
| 517 | + attacked.append({ |
| 518 | + "graph_idx": gi, |
| 519 | + "x": x[start:end].detach().clone(), |
| 520 | + "edge_index": selected_graphs[j].edge_index |
| 521 | + }) |
| 522 | + self.attack_res = attacked |
| 523 | + |
| 524 | + else: |
| 525 | + x = gen_dataset.data.x.clone() |
| 526 | + edge_index = gen_dataset.data.edge_index.clone() |
| 527 | + y = gen_dataset.data.y.clone() |
| 528 | + # node_idxs = mask_tensor.nonzero(as_tuple=True)[0].tolist() |
| 529 | + |
| 530 | + orig_x = x.clone() |
| 531 | + x.requires_grad = True |
| 532 | + |
| 533 | + for _ in tqdm(range(self.num_iterations)): |
| 534 | + if x.grad is not None: |
| 535 | + x.grad.zero_() |
| 536 | + out = model(x, edge_index) |
| 537 | + loss = attack_loss(out[mask_tensor], y[mask_tensor]) |
| 538 | + loss.backward() |
| 539 | + with torch.no_grad(): |
| 540 | + x -= self.learning_rate * x.grad.sign() |
| 541 | + x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon)) |
| 542 | + x.copy_(torch.clamp(x, -self.epsilon, self.epsilon)) |
| 543 | + |
| 544 | + self.attack_res = Data(x=x, edge_index=edge_index, y=y) |
| 545 | + |
500 | 546 | else: # structure attack |
| 547 | + if task_type: |
| 548 | + graph_idx = self.element_idx |
| 549 | + x = gen_dataset.dataset[graph_idx].x.clone() |
| 550 | + edge_index = gen_dataset.dataset[graph_idx].edge_index.clone() |
| 551 | + y = gen_dataset.dataset[graph_idx].y.clone() |
| 552 | + else: |
| 553 | + node_idx = self.element_idx |
| 554 | + x = gen_dataset.data.x.clone() |
| 555 | + edge_index = gen_dataset.data.edge_index.clone() |
| 556 | + y = gen_dataset.data.y.clone() |
| 557 | + |
| 558 | + num_hops = model.n_layers |
| 559 | + subset, edge_index_subset, inv, edge_mask_k_hop = k_hop_subgraph( |
| 560 | + node_idx=node_idx, |
| 561 | + num_hops=num_hops, |
| 562 | + edge_index=edge_index, |
| 563 | + relabel_nodes=True, |
| 564 | + directed=False |
| 565 | + ) |
| 566 | + node_idx_remap = torch.where(subset == node_idx)[0].item() |
| 567 | + x = x[subset] |
| 568 | + y = y[subset] |
| 569 | + |
501 | 570 | if task_type: |
502 | 571 | num_edges = edge_index.size(1) |
503 | 572 | else: |
|
0 commit comments