Skip to content

Commit 5b01aa4

Browse files
committed
training edge pred ok
1 parent 7a21730 commit 5b01aa4

File tree

8 files changed

+148
-53
lines changed

8 files changed

+148
-53
lines changed

gnn_aid/attacks/evasion_attacks.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# ReWatt imports
2626
from .evasion_attacks_collection.rewatt.utils import GraphEnvironment, ReWattPolicyNet, \
2727
GraphState, ReWattAgent
28+
from ..data_structures import Task
2829

2930

3031
class EvasionAttacker(
@@ -94,7 +95,7 @@ def attack(
9495
model = model_manager.gnn
9596
model.eval()
9697

97-
if task.is_edge_level():
98+
if task == Task.EDGE_PREDICTION:
9899
data = gen_dataset.data
99100
x = data.x
100101
x.requires_grad = True
@@ -119,7 +120,7 @@ def attack(
119120
perturbed_data_x = torch.clamp(perturbed_data_x, 0, 1)
120121
gen_dataset.data.x = perturbed_data_x.detach()
121122

122-
elif task.is_graph_level() and task.is_classification(): # graph_classification
123+
elif task == Task.GRAPH_CLASSIFICATION:
123124
graph_idx = self.element_idx
124125
x = gen_dataset.dataset[graph_idx].x
125126
x.requires_grad = True
@@ -135,7 +136,7 @@ def attack(
135136
perturbed_data_x = torch.clamp(perturbed_data_x, 0, 1)
136137
gen_dataset.dataset[graph_idx].x = perturbed_data_x.detach()
137138

138-
elif task.is_node_level() and task.is_classification(): # node_classification
139+
elif task == Task.NODE_CLASSIFICATION:
139140
node_idx = self.element_idx
140141
x = gen_dataset.data.x
141142
x.requires_grad = True
@@ -152,7 +153,7 @@ def attack(
152153
gen_dataset.data.x = perturbed_data_x.detach()
153154

154155
else:
155-
pass
156+
raise NotImplementedError
156157

157158
# if task_type:
158159
# gni = GlobalNodeIndexer(gen_dataset.dataset)
@@ -542,6 +543,7 @@ def attack(
542543
x.copy_(torch.clamp(x, -self.epsilon, self.epsilon))
543544

544545
self.attack_res = Data(x=x, edge_index=edge_index, y=y)
546+
gen_dataset.data.x = x # FIXME tmp
545547

546548
else: # structure attack
547549
if task_type:

gnn_aid/datasets/gen_dataset.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,11 @@ def node_features(
204204
if self.dataset is None:
205205
raise RuntimeError(f"Cannot get node features: dataset {self} is not built."
206206
f" Define {DatasetVarConfig.__name__} and call build() method")
207-
if self.is_multi():
208-
return [data.x for data in self.dataset]
209-
else:
210-
return self.dataset[0].x
207+
# if self.is_multi():
208+
# return [data.x for data in self.dataset]
209+
# else:
210+
# return self.dataset[0].x
211+
return self.data.x
211212

212213
@property
213214
def edge_features(
@@ -219,11 +220,12 @@ def edge_features(
219220
raise RuntimeError(f"Cannot get edge features: dataset {self} is not built."
220221
f" Define {DatasetVarConfig.__name__} and call build() method")
221222

222-
if self.is_multi():
223-
return [data.edge_attr for data in self.dataset]
224-
else:
225-
return self.dataset[0].edge_attr
223+
# if self.is_multi():
224+
# return [data.edge_attr for data in self.dataset]
225+
# else:
226+
# return self.dataset[0].edge_attr
226227
# TODO misha implement
228+
return self.data.edge_attr
227229

228230
def __len__(
229231
self

gnn_aid/models_builder/model_managers/framework_mm.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Union, List, cast, Iterable
55

66
import torch
7-
from torch import tensor
87
from torch.cuda import is_available
98
from torch.nn.utils import clip_grad_norm
109
from torch_geometric.loader import NeighborLoader, DataLoader, LinkNeighborLoader
@@ -16,7 +15,7 @@
1615
from gnn_aid.data_structures import Task, GraphModificationArtifact
1716
from gnn_aid.data_structures.configs import ConfigPattern, CONFIG_OBJ
1817
from gnn_aid.datasets import GeneralDataset
19-
from gnn_aid.models_builder.models_utils import Metric, predict_top_k_edges
18+
from gnn_aid.models_builder.models_utils import Metric, predict_top_k_edges, mask_to_tensor
2019
from . import GNNModelManager
2120

2221

@@ -751,28 +750,3 @@ def load_train_test_split(
751750
return gen_dataset
752751

753752

754-
def mask_to_tensor(
755-
gen_dataset: GeneralDataset,
756-
mask: Union[str, List[bool], torch.Tensor] = 'test'
757-
) -> torch.Tensor:
758-
"""
759-
Convert mask over nodes/edges/graphs to tensor.
760-
Mask can be 'train', 'val', 'test', 'all', or Tensor of specific nodes/edges/graphs.
761-
762-
:param gen_dataset: dataset
763-
:param mask: part of the dataset on which the output will be obtained.
764-
'train', 'val', 'test', 'all', or Tensor of specific nodes/edges/graphs
765-
:return: tensor of nodes/edges/graphs
766-
"""
767-
try:
768-
mask_tensor = {
769-
'train': gen_dataset.train_mask,
770-
'val': gen_dataset.val_mask,
771-
'test': gen_dataset.test_mask,
772-
'all': tensor([True] * len(gen_dataset.labels)),
773-
}[mask]
774-
except KeyError:
775-
assert isinstance(mask, torch.Tensor)
776-
mask_tensor = mask
777-
778-
return mask_tensor

gnn_aid/models_builder/models_utils.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, List, Callable, Union
1+
from typing import Any, Optional, List, Callable, Union, Tuple
22
from typing import Callable
33

44
import sklearn.metrics
@@ -8,6 +8,8 @@
88
from torch.utils.hooks import RemovableHandle
99
from torch_geometric.nn import MessagePassing
1010

11+
from gnn_aid.datasets import GeneralDataset
12+
1113

1214
def apply_message_gradient_capture(
1315
layer: Any,
@@ -694,4 +696,52 @@ def _predict_full_enumeration(model, data, h, existing_set, k, is_directed, remo
694696

695697
print(f"Top-{final_k} edges found with scores from {top_scores[-1]:.4f} to {top_scores[0]:.4f}")
696698

697-
return top_edges.cpu(), top_scores.cpu()
699+
return top_edges.cpu(), top_scores.cpu()
700+
701+
702+
def mask_to_tensor(
703+
gen_dataset: GeneralDataset,
704+
mask: Union[str, int, Tuple[int], list, torch.Tensor] = 'test'
705+
) -> torch.Tensor:
706+
"""
707+
Convert a mask over nodes/edges/graphs to tensor of specific size.
708+
Mask can be 'train', 'val', 'test', 'all', or id, or a list of ids, or a tensor.
709+
710+
:param gen_dataset: dataset
711+
:param mask: part of the dataset on which the output will be obtained.
712+
Can be a node id, graph id, or edge as a tuple (i,j).
713+
Can be string: 'train', 'val', 'test', 'all'.
714+
Can be Tensor of specific nodes/edges/graphs.
715+
:return: tensor of nodes/edges/graphs
716+
"""
717+
task = gen_dataset.dataset_var_config.task
718+
719+
if isinstance(mask, str):
720+
mask_tensor = {
721+
'train': gen_dataset.train_mask,
722+
'val': gen_dataset.val_mask,
723+
'test': gen_dataset.test_mask,
724+
'all': tensor([True] * len(gen_dataset.labels)),
725+
}[mask]
726+
727+
elif isinstance(mask, torch.Tensor):
728+
mask_tensor = mask
729+
730+
elif task.is_node_level(): # Node id
731+
assert not gen_dataset.is_multi()
732+
mask_tensor = tensor([False] * gen_dataset.info.nodes[0])
733+
mask_tensor[mask] = True # for int or list of ints
734+
735+
elif task.is_graph_level(): # Graph id
736+
assert gen_dataset.is_multi()
737+
mask_tensor = tensor([False] * len(gen_dataset.info.nodes))
738+
mask_tensor[mask] = True # for int or list of ints
739+
740+
elif task.is_edge_level(): # Edge
741+
# isinstance(mask, tuple)
742+
raise NotImplementedError
743+
744+
else:
745+
raise RuntimeError(f"Cannot infer mask tensor for given mask {mask}.")
746+
747+
return mask_tensor

web_interface/back_front/attack_defense_blocks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55

66
from gnn_aid.attacks import MIAttacker
77
from gnn_aid.aux import DataInfo
8-
from gnn_aid.data_structures.configs import ConfigPattern, PoisonAttackConfig, PoisonDefenseConfig, EvasionAttackConfig, \
9-
EvasionDefenseConfig, MIAttackConfig, MIDefenseConfig
8+
from gnn_aid.data_structures.configs import ConfigPattern, PoisonAttackConfig, PoisonDefenseConfig, \
9+
EvasionAttackConfig, \
10+
EvasionDefenseConfig, MIAttackConfig, MIDefenseConfig, CONFIG_OBJ
1011
from gnn_aid.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, \
1112
EVASION_ATTACK_PARAMETERS_PATH, EVASION_DEFENSE_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, \
1213
MI_DEFENSE_PARAMETERS_PATH
1314
from gnn_aid.datasets.gen_dataset import GeneralDataset
1415
from gnn_aid.models_builder import Metric
1516
from gnn_aid.models_builder.attack_defense_manager import FrameworkAttackDefenseManager
1617
from gnn_aid.models_builder.model_managers import GNNModelManager
18+
from gnn_aid.models_builder.models_utils import mask_to_tensor
1719
from . import VisiblePart
1820
from .block import Block
1921
from .utils import WebInterfaceError, send_epoch_results, compute_stats_data
@@ -212,11 +214,12 @@ def do(
212214

213215
# Apply evasion attack
214216
if self.ad_configs["AD-ea"] is not None:
215-
# attack_mask =
217+
element_idx = getattr(self.ad_configs["AD-ea"], CONFIG_OBJ).element_idx
218+
attack_mask = mask_to_tensor(self.gen_dataset, element_idx)
216219

217220
self.model_manager.call_evasion_attack(
218221
gen_dataset=self.gen_dataset,
219-
mask=torch.empty(1),
222+
mask=attack_mask,
220223
)
221224

222225
# Evaluate metrics without attacks
@@ -253,6 +256,7 @@ def do(
253256

254257
# Update dataset features
255258
dvd = self.visible_part.get_dataset_var_data()
259+
print('dvd after attack', dvd.node['features'])
256260

257261
dvd = add_into_dvd(self.gen_dataset, stats_data, dvd)
258262

web_interface/main_aiohttp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from web_interface.back_front.utils import WebInterfaceError, json_loads, json_dumps, SocketConnect
1717

1818
# Socket.IO server (ASGI not used here)
19-
sio = socketio.AsyncServer(async_mode="aiohttp")
19+
sio = socketio.AsyncServer(
20+
async_mode="aiohttp",
21+
ping_timeout=600, # wait pong from client
22+
ping_interval=25,
23+
cors_allowed_origins='*'
24+
)
2025
app = web.Application()
2126
sio.attach(app)
2227

web_interface/static/js/controllers/controller.js

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@ class Controller {
1212
reconnectionAttempts: Infinity,
1313
reconnectionDelay: 1000,
1414
query: {mode: mode},
15+
16+
// Longer timeout for backend debug - 10 mins
17+
timeout: 600*1000,
18+
pingTimeout: 600*1000,
19+
pingInterval: 600*1000,
1520
})
1621
this.socket.on('connect', () => {
1722
console.log('socket connected, sid=', this.socket.id)
1823
this.sid = this.socket.id
1924
if (this.isActive) {
25+
console.log('This is a reconnection')
2026
// FIXME seems sometimes it happens when backend is busy - we don't need to reload?
21-
// Means re-connection to server. Need to reload the page
27+
// // Means re-connection to server. Need to reload the page
2228
alert("Web-socket connection is lost. Press OK to reload the page.")
2329
this.isActive = false
2430
window.location.reload(true)
@@ -49,14 +55,20 @@ class Controller {
4955
console.log('socket error', err)
5056
})
5157

52-
this.socket.on('disconnect', () => {
58+
this.socket.on('disconnect', (reason) => {
5359
// this.isActive = false
54-
console.log('Disconnected from server, sid=', this.socket.id);
60+
console.log('Disconnected from server, reason=', reason, ', sid=', this.socket.id);
5561
if (this.isActive) {
56-
// Show a notification to the user
57-
alert("Web-socket connection is lost. Press OK to reload the page.")
58-
this.isActive = false
59-
window.location.reload(true)
62+
const softDisconnectReasons = ['ping timeout', 'transport error'];
63+
if (softDisconnectReasons.includes(reason)) {
64+
console.log('Soft disconnect, will try to reconnect...');
65+
// this.reconnectStartTime = Date.now()
66+
} else {
67+
// Show a notification to the user
68+
alert("Web-socket connection is lost. Press OK to reload the page.")
69+
this.isActive = false
70+
window.location.reload(true)
71+
}
6072
}
6173
})
6274

web_interface/static/js/paramsBuilder.js

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,52 @@ class ParamsBuilder {
125125
$input.change(() => this.kwArgs[name] = $input.is(":checked"))
126126
}
127127

128+
else if (type === "int_or_tuple") {
129+
// fixme this is a copy. todo tuple for edge task
130+
$input.attr("type", "number")
131+
$input.val(Number.isFinite(def) ? def : possible["min"])
132+
let checkClass = id + "-radio"
133+
if ("special" in possible) {
134+
// Add special values as separate checkboxes
135+
let checkId = id + "-check"
136+
for (const variant of possible.special) {
137+
let $checkBox = $("<input>").attr("id", checkId)
138+
.attr("type", "checkbox").prop('checked', variant === def)
139+
$checkBox.addClass(checkClass)
140+
$cb.append($checkBox)
141+
let $label = $("<label></label>").text(variant == null ? "None" : variant)
142+
.attr("for", checkId)
143+
$cb.append($label)
144+
$checkBox.change((e) => { // Uncheck all but this
145+
let wasChecked = $checkBox.is(":checked")
146+
$("." + checkClass).prop("checked", false)
147+
$checkBox.prop("checked", true)
148+
this.kwArgs[name] = variant
149+
})
150+
}
151+
$input.focus(() => {
152+
$("." + checkClass).prop("checked", false)
153+
$input.trigger("change")
154+
})
155+
$input.css("min-width", "60px")
156+
delete possible.special
157+
}
158+
159+
if (type === "int") {
160+
$input.attr("step", 1)
161+
$input.attr("pattern", "\d+")
162+
$input.change(() => this.kwArgs[name] = parseInt($input.val()))
163+
}
164+
else {
165+
// fixme
166+
}
167+
for (const [key, value] of Object.entries(possible))
168+
$input.attr(key, value)
169+
170+
// Check input value when user unfocus it or change it
171+
addValueChecker($input, type, def, possible["min"], possible["max"], "change")
172+
}
173+
128174
else if (type === "int" || type === "float") {
129175
$input.attr("type", "number")
130176
$input.val(Number.isFinite(def) ? def : possible["min"])

0 commit comments

Comments
 (0)