|
1 | 1 | import torch |
2 | 2 | import numpy as np |
| 3 | +from typing import Optional |
3 | 4 |
|
4 | 5 | from gnn_aid.datasets.gen_dataset import GeneralDataset |
5 | 6 | from gnn_aid.data_structures.graph_modification_artifacts import GraphModificationArtifact |
6 | 7 | from gnn_aid.defenses.poison_defense import PoisonDefender |
| 8 | +from gnn_aid.data_structures.configs import Task |
7 | 9 |
|
8 | 10 |
|
9 | | -class JaccardDefender( |
10 | | - PoisonDefender |
11 | | -): |
| 11 | +def _is_binary_tensor(X: torch.Tensor) -> bool: |
| 12 | + return torch.all((X == 0) | (X == 1)).item() |
| 13 | + |
| 14 | + |
| 15 | +class JaccardDefender(PoisonDefender): |
12 | 16 | """ |
13 | | - Poison defense based on removing edges between dissimilar nodes |
| 17 | + Poison defense based on removing edges between dissimilar nodes. |
14 | 18 | """ |
15 | 19 | name = 'JaccardDefender' |
16 | 20 |
|
17 | | - def __init__(self, threshold): |
| 21 | + def __init__(self, threshold: float, binarize_threshold: Optional[float] = None): |
| 22 | + """ |
| 23 | + :param threshold: Jaccard similarity threshold (edges with similarity <= threshold are removed) |
| 24 | + :param binarize_threshold: Optional threshold to binarize non-binary features |
| 25 | + """ |
18 | 26 | super().__init__() |
19 | | - self.thrsh = threshold |
20 | | - self.remove_edge_index = None |
| 27 | + self.threshold = threshold |
| 28 | + self.binarize_threshold = binarize_threshold |
| 29 | + self.removed_edges_train = None |
| 30 | + self.original_num_edges = None |
21 | 31 |
|
22 | 32 | def defense( |
23 | 33 | self, |
24 | 34 | gen_dataset: GeneralDataset, |
25 | 35 | **kwargs |
26 | 36 | ) -> GeneralDataset: |
27 | | - """ |
28 | | - Modify input graph by removing edges between dissimilar nodes |
29 | | - :param gen_dataset: input graph dataset |
30 | | - :return: modified graph (only adjacency matrix modified) |
31 | | - """ |
| 37 | + task = gen_dataset.dataset_var_config.task |
| 38 | + |
| 39 | + if task in [Task.EDGE_PREDICTION, Task.EDGE_REGRESSION]: |
| 40 | + if not hasattr(gen_dataset, 'train_mask') or gen_dataset.train_mask is None: |
| 41 | + raise RuntimeError("JaccardDefender for link tasks requires train_test_split() to be called first") |
| 42 | + |
| 43 | + self.original_num_edges = gen_dataset.data.edge_index.size(1) |
| 44 | + |
| 45 | + x = self._prepare_features(gen_dataset.data.x) |
| 46 | + |
| 47 | + if task in [Task.EDGE_PREDICTION, Task.EDGE_REGRESSION]: |
| 48 | + gen_dataset = self._defense_link_task(gen_dataset, x) |
| 49 | + else: |
| 50 | + gen_dataset = self._defense_standard_task(gen_dataset, x) |
32 | 51 |
|
33 | | - def is_binary_tensor(X: torch.Tensor) -> bool: |
34 | | - return torch.all((X == 0) | (X == 1)).item() |
35 | | - |
36 | | - assert is_binary_tensor(gen_dataset.data.x), "The features should be presented in binary form" |
37 | | - |
38 | | - # TODO need to check whether features binary or not. Consistency required - Cora has 'unknown' features e.g. |
39 | | - # self.drop_edges(batch) |
40 | | - edge_index = gen_dataset.data.edge_index.tolist() |
41 | | - #new_edge_mask = torch.zeros_like(gen_dataset.data.edge_index).bool() |
42 | | - new_edge_index = [[],[]] |
43 | | - self.remove_edge_index = [[], []] |
44 | | - for i in range(len(edge_index[0])): |
45 | | - if self.jaccard_index(gen_dataset.data.x, edge_index[0][i], edge_index[1][i]) > self.thrsh: |
46 | | - # new_edge_mask[0,i] = True |
47 | | - # new_edge_mask[1,i] = True |
48 | | - new_edge_index[0].append(edge_index[0][i]) |
49 | | - new_edge_index[1].append(edge_index[1][i]) |
50 | | - else: |
51 | | - self.remove_edge_index[0].append(edge_index[0][i]) |
52 | | - self.remove_edge_index[1].append(edge_index[1][i]) |
53 | | - # gen_dataset.data.edge_index *= new_edge_mask.float() |
54 | | - gen_dataset.data.edge_index = torch.tensor(new_edge_index).long() |
55 | 52 | return gen_dataset |
56 | 53 |
|
57 | | - def jaccard_index( |
| 54 | + def _prepare_features(self, x: torch.Tensor) -> torch.Tensor: |
| 55 | + if self.binarize_threshold is not None: |
| 56 | + x = (x > self.binarize_threshold).float() |
| 57 | + elif not _is_binary_tensor(x): |
| 58 | + raise ValueError( |
| 59 | + "JaccardDefender requires binary features" |
| 60 | + ) |
| 61 | + return x |
58 | 62 |
|
| 63 | + def _defense_link_task( |
59 | 64 | self, |
60 | | - x, |
61 | | - u, |
62 | | - v |
63 | | - ) -> float: |
64 | | - """ |
65 | | - Computes jaccard index of 'u' and 'v' objects based on their features |
66 | | - :param x: feature matrix |
67 | | - :param u: index of object from dataset |
68 | | - :param v: index of object from dataset |
69 | | - :return: |
70 | | - """ |
71 | | - im1 = x[u,:].detach().cpu().numpy().astype(bool) |
72 | | - im2 = x[v,:].detach().cpu().numpy().astype(bool) |
73 | | - intersection = np.logical_and(im1, im2) |
74 | | - union = np.logical_or(im1, im2) |
75 | | - return intersection.sum() / float(union.sum()) |
76 | | - |
77 | | - def dataset_diff( |
78 | | - self |
79 | | - ) -> GraphModificationArtifact: |
80 | | - diff = GraphModificationArtifact() |
| 65 | + gen_dataset: GeneralDataset, |
| 66 | + x: torch.Tensor |
| 67 | + ) -> GeneralDataset: |
| 68 | + train_edge_label_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask] |
81 | 69 |
|
82 | | - try: |
83 | | - src_nodes = self.remove_edge_index[0] |
84 | | - dst_nodes = self.remove_edge_index[1] |
| 70 | + filtered_train_edges, removed_edges = self._filter_edges_jaccard(train_edge_label_index, x) |
| 71 | + self.removed_edges_train = removed_edges |
85 | 72 |
|
86 | | - assert len(src_nodes) == len(dst_nodes), ( |
87 | | - "Mismatch in source and target edge lengths: " |
88 | | - f"{len(src_nodes)} vs {len(dst_nodes)}" |
89 | | - ) |
| 73 | + gen_dataset.data.edge_index = filtered_train_edges |
90 | 74 |
|
91 | | - edges_to_remove = [ |
92 | | - [src, dst] for src, dst in zip(src_nodes, dst_nodes) |
93 | | - ] |
| 75 | + num_removed = removed_edges.size(1) if removed_edges is not None else 0 |
| 76 | + print(f"JaccardDefender: Removed {num_removed}/{train_edge_label_index.size(1)} " |
| 77 | + f"training edges (threshold={self.threshold})") |
94 | 78 |
|
| 79 | + return gen_dataset |
| 80 | + |
| 81 | + def _defense_standard_task( |
| 82 | + self, |
| 83 | + gen_dataset: GeneralDataset, |
| 84 | + x: torch.Tensor |
| 85 | + ) -> GeneralDataset: |
| 86 | + filtered_edges, removed_edges = self._filter_edges_jaccard( |
| 87 | + gen_dataset.data.edge_index, x |
| 88 | + ) |
| 89 | + self.removed_edges_train = removed_edges # Reusing field for simplicity |
| 90 | + |
| 91 | + gen_dataset.data.edge_index = filtered_edges |
| 92 | + |
| 93 | + num_removed = removed_edges.size(1) if removed_edges is not None else 0 |
| 94 | + print(f"JaccardDefender: Removed {num_removed}/{self.original_num_edges} edges " |
| 95 | + f"(threshold={self.threshold})") |
| 96 | + |
| 97 | + return gen_dataset |
| 98 | + |
| 99 | + def _filter_edges_jaccard( |
| 100 | + self, |
| 101 | + edge_index: torch.Tensor, |
| 102 | + x: torch.Tensor |
| 103 | + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| 104 | + if edge_index.size(1) == 0: |
| 105 | + return edge_index, None |
| 106 | + |
| 107 | + src_feats = x[edge_index[0]] |
| 108 | + dst_feats = x[edge_index[1]] |
| 109 | + |
| 110 | + intersection = (src_feats * dst_feats).sum(dim=1) # AND |
| 111 | + union = ((src_feats + dst_feats) > 0).sum(dim=1).float() # OR |
| 112 | + |
| 113 | + union = torch.where(union == 0, torch.ones_like(union), union) |
| 114 | + |
| 115 | + jaccard_scores = intersection / union |
| 116 | + |
| 117 | + keep_mask = jaccard_scores > self.threshold |
| 118 | + filtered_edges = edge_index[:, keep_mask] |
| 119 | + removed_edges = edge_index[:, ~keep_mask] if (~keep_mask).any() else None |
| 120 | + |
| 121 | + return filtered_edges, removed_edges |
| 122 | + |
| 123 | + def dataset_diff(self) -> GraphModificationArtifact: |
| 124 | + diff = GraphModificationArtifact() |
| 125 | + |
| 126 | + if self.removed_edges_train is not None and self.removed_edges_train.size(1) > 0: |
| 127 | + edges_to_remove = self.removed_edges_train.t().tolist() |
95 | 128 | diff.remove_edges(edges_to_remove) |
96 | 129 | self.defense_diff = diff |
97 | | - |
98 | | - except Exception as e: |
99 | | - raise RuntimeError( |
100 | | - f"Failed to build dataset diff from remove_edge_index: {e}" |
101 | | - ) from e |
| 130 | + else: |
| 131 | + # No edges removed |
| 132 | + self.defense_diff = diff |
102 | 133 |
|
103 | 134 | return self.defense_diff |
0 commit comments