Skip to content

Commit 1b05bb6

Browse files
committed
fix frontend choose edge
1 parent 5b01aa4 commit 1b05bb6

File tree

19 files changed

+309
-115
lines changed

19 files changed

+309
-115
lines changed

gnn_aid/attacks/evasion_attacks.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from gnn_aid.aux.utils import move_to_same_device
77
from gnn_aid.datasets.gen_dataset import GeneralDataset
88
from gnn_aid.models_builder.model_managers import GNNModelManager
9+
from gnn_aid.data_structures.configs import Task
910

1011
# Nettack imports
1112
from .evasion_attacks_collection.nettack.utils import NettackSurrogate, NettackAttack
@@ -25,7 +26,6 @@
2526
# ReWatt imports
2627
from .evasion_attacks_collection.rewatt.utils import GraphEnvironment, ReWattPolicyNet, \
2728
GraphState, ReWattAgent
28-
from ..data_structures import Task
2929

3030

3131
class EvasionAttacker(
@@ -65,6 +65,9 @@ def check_availability(
6565
model_manager: GNNModelManager
6666
):
6767
""" Availability check for the given dataset and model manager. """
68+
rules = [
69+
gen_dataset.dataset_var_config.task
70+
]
6871
return True
6972

7073
def __init__(
@@ -86,7 +89,7 @@ def attack(
8689
model_manager: Type,
8790
gen_dataset: GeneralDataset,
8891
mask_tensor: torch.Tensor,
89-
task_type: str = None,
92+
task_type: str = None, # FIXME remove
9093
):
9194
task = gen_dataset.dataset_var_config.task
9295
device = gen_dataset.data.x.device
@@ -110,9 +113,9 @@ def attack(
110113
edge_out = model.decode(src, dst).unsqueeze(dim=0).to(device)
111114

112115
# TODO use model_manager.loss_function when BCE support
113-
# loss = model_manager.loss_function(edge_out, edge_label)
114-
criterion = torch.nn.BCEWithLogitsLoss()
115-
loss = criterion(edge_out, edge_label)
116+
loss = model_manager.loss_function(edge_out, edge_label)
117+
# criterion = torch.nn.BCEWithLogitsLoss()
118+
# loss = criterion(edge_out, edge_label)
116119
model.zero_grad()
117120
loss.backward()
118121
sign_data_grad = x.grad.sign()
@@ -425,7 +428,7 @@ def attack(
425428
model_manager: Type,
426429
gen_dataset: GeneralDataset,
427430
mask_tensor: torch.Tensor,
428-
task_type: str = None,
431+
task_type: str = None, # FIXME remove
429432
) -> None:
430433
if task_type is None:
431434
task_type = gen_dataset.is_multi()
@@ -662,7 +665,7 @@ def check_availability(
662665

663666
def __init__(
664667
self,
665-
node_idx: int = 0,
668+
element_idx: int = 0,
666669
budget: Union[int, None] = None,
667670
perturb_features: bool = True,
668671
perturb_structure: bool = True,
@@ -674,7 +677,7 @@ def __init__(
674677
):
675678
super().__init__()
676679
self.attack_diff = None
677-
self.node_idx = node_idx
680+
self.element_idx = element_idx
678681
self.budget = budget
679682
self.perturb_features = perturb_features
680683
self.perturb_structure = perturb_structure
@@ -700,13 +703,13 @@ def attack(
700703
# surrogate.evaluate(x, edge_index, y)
701704

702705
attacker = NettackAttack(
703-
real_class=data.y[self.node_idx].item(),
706+
real_class=data.y[self.element_idx].item(),
704707
gnn_model=model_manager.gnn,
705708
model=surrogate,
706709
x=x,
707710
edge_index=edge_index,
708711
num_classes=num_classes,
709-
target_node=self.node_idx,
712+
target_node=self.element_idx,
710713
attack_diff=self.attack_diff,
711714
direct=self.direct,
712715
depth=self.depth,
@@ -719,14 +722,14 @@ def attack(
719722
elif self.perturb_structure and not self.perturb_features:
720723
mode = "structure"
721724

722-
# logits_before = surrogate.forward(edge_index, x)[self.node_idx]
725+
# logits_before = surrogate.forward(edge_index, x)[self.element_idx]
723726
# pred_before = logits_before.argmax().item()
724727
# prob_before = torch.softmax(logits_before, dim=0)[pred_before].item()
725728
# print(f"Surrogate prediction before attack: {pred_before} (confidence: {prob_before:.4f})")
726729

727730
attacker.attack(budget=self.budget, mode=mode)
728731

729-
# logits_after = surrogate.forward(attacker.edge_index, attacker.x)[self.node_idx]
732+
# logits_after = surrogate.forward(attacker.edge_index, attacker.x)[self.element_idx]
730733
# pred_after = logits_after.argmax().item()
731734
# prob_after = torch.softmax(logits_after, dim=0)[pred_after].item()
732735
# print(f"Surrogate prediction after attack: {pred_after} (confidence: {prob_after:.4f})")

gnn_aid/attacks/rl_s2v/rl_s2v.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
mlp_hidden: int = 64,
2323
max_lv: bool = True,
2424
gm: str = 'mean_field',
25-
node_idx: int = 0
25+
element_idx: int = 0
2626
):
2727
"""
2828
:param num_steps: rl training steps
@@ -35,7 +35,7 @@ def __init__(
3535
:param mlp_hidden: mlp hidden layer size
3636
:param max_lv: max rounds of message passing
3737
:param gm: mean_field/loopy_bp/gcn
38-
:param node_idx: index of node to be attacked
38+
:param element_idx: index of node to be attacked
3939
4040
"""
4141
super().__init__()
@@ -49,7 +49,7 @@ def __init__(
4949
self.mlp_hidden = mlp_hidden
5050
self.max_lv = max_lv
5151
self.gm = gm
52-
self.node_idx = node_idx
52+
self.element_idx = element_idx
5353

5454
self.env = None
5555
self.agent = None
@@ -105,7 +105,7 @@ def setup(
105105

106106
self.env = NodeAttackEnv(gen_dataset=gen_dataset, all_targets=total, list_action_space=dict_of_lists,
107107
classifier=gnn, num_mod=self.num_mod, reward_type=self.reward_type, gm=self.gm)
108-
self.agent = RLS2VAgent(self.env, gen_dataset, node_idx=self.node_idx, idx_test=attack_list, num_wrong=1,
108+
self.agent = RLS2VAgent(self.env, gen_dataset, node_idx=self.element_idx, idx_test=attack_list, num_wrong=1,
109109
list_action_space=dict_of_lists, num_mod=self.num_mod, reward_type=self.reward_type,
110110
batch_size=self.batch_size,
111111
bilin_q=self.bilin_q, embed_dim=self.embed_dim, mlp_hidden=self.mlp_hidden,

gnn_aid/aux/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,13 @@ def setting_class_default_parameters(
125125
f"in def setting_class_default_parameters")
126126
continue
127127
key_1 = class_kwargs_default[key][1]
128-
if key_1 == 'int_or_tuple':
128+
if key_1 == 'int_or_tuple' or key_1 == 'int_or_tuple_or_mask':
129129
try:
130130
class_kwargs[key] = int(val)
131-
except TypeError:
131+
except TypeError: # tuple
132132
class_kwargs[key] = tuple(val)
133+
except ValueError: # string
134+
class_kwargs[key] = val
133135
elif val is None or key_1 == 'string'\
134136
or (key_1 == 'dynamic' and isinstance(val, str))\
135137
or np.isinf(val):
@@ -140,11 +142,13 @@ def setting_class_default_parameters(
140142
if key != TECHNICAL_PARAMETER_KEY and key not in class_kwargs.keys():
141143
if val[2] is None or val[1] == 'string' or val[2] == np.inf:
142144
class_kwargs[key] = val[2]
143-
elif val[1] == 'int_or_tuple':
145+
elif val[1] == 'int_or_tuple' or val[1] == 'int_or_tuple_or_mask':
144146
try:
145147
class_kwargs[key] = int(val[2])
146-
except TypeError:
148+
except TypeError: # tuple
147149
class_kwargs[key] = tuple(val[2])
150+
except ValueError: # string
151+
class_kwargs[key] = val[2]
148152
else:
149153
class_kwargs[key] = locate(val[1])(val[2])
150154

gnn_aid/defenses/evasion_defense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def pre_batch(
240240
self,
241241
model_manager: Type,
242242
batch,
243-
task_type: str = None,
243+
task_type: str = None, # FIXME remove
244244
**kwargs,
245245
):
246246
super().pre_batch(model_manager=model_manager, batch=batch)
@@ -261,7 +261,7 @@ def is_multi():
261261
model_manager=model_manager,
262262
gen_dataset=self.perturbed_gen_dataset,
263263
mask_tensor=torch.arange(0, batch.batch_size),
264-
task_type=task_type,
264+
# task_type=task_type,
265265
)
266266

267267
def post_batch(

gnn_aid/explainers/gnnexplainer/dig_our/out.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,14 @@ class GNNExplainer(Explainer, ExplainerBase):
4242
def check_availability(gen_dataset, model_manager):
4343
""" Availability check for the given dataset and model manager. """
4444
# Should have at least 1 MessagePassing module
45-
return ({'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)) and
46-
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()))
45+
rules = [
46+
gen_dataset.dataset_var_config.task.is_node_level(), # FIXME check
47+
{'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)),
48+
49+
# Should have at least 1 MessagePassing module
50+
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()),
51+
]
52+
return all(rules)
4753

4854
def __init__(self,
4955
gen_dataset,

gnn_aid/explainers/gnnexplainer/torch_geom_our/out.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ class GNNExplainer(Explainer):
2121
name = 'GNNExplainer(torch-geom)'
2222
availability_profile = ({'single', 'multi'}, {'modules', 'get_num_hops', 'forward'})
2323

24-
@staticmethod
25-
def check_availability(gen_dataset, model_manager):
26-
""" Availability check for the given dataset and model manager. """
27-
# Should have at least 1 MessagePassing module
28-
return\
29-
{'modules', 'get_num_hops', 'forward'}.issubset(dir(model_manager.gnn)) and\
30-
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules())
31-
3224
coeffs = {
3325
'edge_size': 0.005,
3426
'edge_reduction': 'sum',
@@ -39,6 +31,17 @@ def check_availability(gen_dataset, model_manager):
3931
'EPS': 1e-15,
4032
}
4133

34+
@staticmethod
35+
def check_availability(gen_dataset, model_manager):
36+
""" Availability check for the given dataset and model manager. """
37+
rules = [
38+
{'modules', 'get_num_hops', 'forward'}.issubset(dir(model_manager.gnn)),
39+
40+
# Should have at least 1 MessagePassing module
41+
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()),
42+
]
43+
return all(rules)
44+
4245
def __init__(self,
4346
gen_dataset,
4447
model,

gnn_aid/explainers/graphmask/out.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ class GraphMaskExplainer(Explainer, ExplainerBase):
2525
@staticmethod
2626
def check_availability(gen_dataset, model_manager):
2727
""" Availability check for the given dataset and model manager. """
28-
# Should have at least 1 MessagePassing module
29-
return\
30-
not gen_dataset.is_multi() and\
31-
{'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)) and\
32-
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules())
28+
rules = [
29+
not gen_dataset.is_multi(),
30+
gen_dataset.dataset_var_config.task.is_node_level(),
31+
{'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)),
32+
33+
# Should have at least 1 MessagePassing module
34+
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()),
35+
]
36+
return all(rules)
3337

3438
def __init__(self,
3539
gen_dataset,

gnn_aid/explainers/pgeexplainer/dig/out.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ class PGExplainer(nn.Module, Explainer):
2929
@staticmethod
3030
def check_availability(gen_dataset, model_manager):
3131
""" Availability check for the given dataset and model manager. """
32-
return ({'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)) and
33-
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()))
32+
rules = [
33+
gen_dataset.dataset_var_config.task.is_node_level(), # FIXME check
34+
{'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)),
35+
36+
# Should have at least 1 MessagePassing module
37+
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()),
38+
]
39+
return all(rules)
3440

3541
def __init__(self,
3642
gen_dataset,

gnn_aid/explainers/pgmexplainer/out.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ class PGMExplainer(Explainer):
2626
@staticmethod
2727
def check_availability(gen_dataset, model_manager):
2828
""" Availability check for the given dataset and model manager. """
29-
return ({'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)) and
30-
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()))
29+
rules = [
30+
gen_dataset.dataset_var_config.task.is_node_level(), # FIXME check
31+
{'modules', 'flow', 'get_num_hops', 'parameters', 'forward'}.issubset(dir(model_manager.gnn)),
32+
33+
# Should have at least 1 MessagePassing module
34+
any(isinstance(m, MessagePassing) for m in model_manager.gnn.modules()),
35+
]
36+
return all(rules)
3137

3238
def __init__(self,
3339
gen_dataset,

gnn_aid/explainers/subgraphx/out.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,11 @@ class SubgraphXExplainer(Explainer):
101101
@staticmethod
102102
def check_availability(gen_dataset, model_manager):
103103
""" Availability check for the given dataset and model manager. """
104-
return\
105-
{'get_num_hops', 'get_predictions'}.issubset(dir(model_manager.gnn))
104+
rules = [
105+
gen_dataset.dataset_var_config.task.is_node_level(),
106+
{'get_num_hops', 'get_predictions'}.issubset(dir(model_manager.gnn)),
107+
]
108+
return all(rules)
106109

107110
def __init__(self, gen_dataset, model, device,
108111
verbose: bool = False,

0 commit comments

Comments
 (0)