Skip to content

Commit f4a5c29

Browse files
committed
2,4,7 fixes
1 parent cb4de1f commit f4a5c29

File tree

12 files changed

+211
-31
lines changed

12 files changed

+211
-31
lines changed

src/attacks/mi_attacks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from typing import Union, List
2+
from typing import Union, List, Dict
33

44
import torch
55
from sklearn.metrics import accuracy_score
@@ -30,7 +30,7 @@ def compute_single_attack_accuracy(
3030
inferred_labels: torch.Tensor,
3131
mask_true: torch.Tensor,
3232
train_class_label: bool = True
33-
) -> float:
33+
) -> Union[float, Dict]:
3434
"""
3535
Computes accuracy for a single attack result (mask + inferred labels pair).
3636
@@ -40,7 +40,7 @@ def compute_single_attack_accuracy(
4040
mask_true: Tensor of true labels for all nodes in the graph
4141
4242
Returns:
43-
float: Accuracy (0.0 to 1.0) of correct predictions among attacked samples
43+
float: Dict with metrics for predictions among attacked samples.
4444
Returns 0.0 if no samples were attacked
4545
"""
4646
metrics = {
@@ -83,6 +83,7 @@ def compute_single_attack_accuracy(
8383

8484
return metrics
8585

86+
8687
class EmptyMIAttacker(
8788
MIAttacker
8889
):

src/aux/data_info.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,9 @@ def values_list_and_technical_files_by_path_and_prefix(
188188
for file_info_dict in val["files_info"]:
189189
if file_info_dict["file_name"] == "origin":
190190
file_name = path[parts_parse].strip()
191-
else:
192-
file_name = file_info_dict["file_name"]
193-
file_name += file_info_dict["format"]
194-
description_info.update(
195-
{key: {parts_val[-1]: os.path.join(*path[:parts_parse], file_name)}})
191+
file_name += file_info_dict["format"]
192+
description_info.update(
193+
{key: {parts_val[-1]: os.path.join(*path[:parts_parse], file_name)}})
196194
parts_parse += 1
197195
return parts_val, description_info
198196

src/models_builder/gnn_models.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,8 @@ def train_model(
11291129
mode: Union[str, None] = None,
11301130
steps=None,
11311131
metrics: List[Metric] = None,
1132-
socket: SocketConnect = None
1132+
socket: SocketConnect = None,
1133+
apply_posisoning_ad: bool = True
11331134
) -> Union[str, Path]:
11341135
"""
11351136
Convenient train method.
@@ -1140,14 +1141,16 @@ def train_model(
11401141
:param steps: train specific number of epochs, if None - all of them
11411142
:param metrics: list of metrics to measure at each step or at the end of training
11421143
:param socket: socket to use for sending data to frontend
1144+
:param apply_posisoning_ad: if True, apply posisoning attack and defense before the training
11431145
"""
11441146
from explainers.explainer import ProgressBar
1145-
gen_dataset = self.load_or_execute_poisoning_attack(
1146-
gen_dataset=gen_dataset
1147-
)
1148-
gen_dataset = self.load_or_execute_poisoning_defense(
1149-
gen_dataset=gen_dataset
1150-
)
1147+
if apply_posisoning_ad:
1148+
gen_dataset = self.load_or_execute_poisoning_attack(
1149+
gen_dataset=gen_dataset
1150+
)
1151+
gen_dataset = self.load_or_execute_poisoning_defense(
1152+
gen_dataset=gen_dataset
1153+
)
11511154

11521155
self.socket = socket
11531156
pbar = ProgressBar(self.socket, "mt")

web_interface/back_front/attack_defense_blocks.py

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import json
2+
from copy import deepcopy
23

4+
from attacks.mi_attacks import MIAttacker
5+
from aux.data_info import DataInfo
36
from data_structures.configs import ConfigPattern, PoisonAttackConfig, PoisonDefenseConfig, EvasionAttackConfig, \
47
EvasionDefenseConfig, MIAttackConfig, MIDefenseConfig
58
from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, \
69
EVASION_ATTACK_PARAMETERS_PATH, EVASION_DEFENSE_PARAMETERS_PATH, MI_ATTACK_PARAMETERS_PATH, \
710
MI_DEFENSE_PARAMETERS_PATH
811
from datasets.gen_dataset import GeneralDataset
912
from models_builder.attack_defense_manager import FrameworkAttackDefenseManager
10-
from models_builder.gnn_models import GNNModelManager
13+
from models_builder.gnn_models import GNNModelManager, Metric
1114
from web_interface.back_front.block import Block
1215
from web_interface.back_front.utils import WebInterfaceError
1316

@@ -130,7 +133,7 @@ def __init__(
130133
):
131134
super().__init__(*args, **kwargs)
132135

133-
self.gen_dataset = None
136+
self.gen_dataset: GeneralDataset = None
134137
self.model_manager: GNNModelManager = None
135138
self.metrics: list = None
136139

@@ -139,6 +142,9 @@ def __init__(
139142
"AD-ma": None,
140143
}
141144

145+
# Copy of the dataset before attacks applied.
146+
self._gen_dataset_backup: GeneralDataset = None
147+
142148
def _init(
143149
self,
144150
gen_dataset: GeneralDataset,
@@ -150,27 +156,75 @@ def _init(
150156
return FrameworkAttackDefenseManager.available_ad_methods(
151157
self.gen_dataset, self.model_manager)
152158

159+
def _finalize(
160+
self
161+
) -> bool:
162+
# Make a dataset backup
163+
if self._gen_dataset_backup is None:
164+
# Make a dataset backup
165+
# FIXME This is a bad way - for large datasets very bad. It is a temporary solution
166+
self._gen_dataset_backup = deepcopy(self.gen_dataset)
167+
else:
168+
# Restore dataset
169+
self._restore_dataset()
170+
return True
171+
172+
def _clear_configs(
173+
self
174+
) -> None:
175+
self.ad_configs = {
176+
"AD-ea": None,
177+
"AD-ma": None,
178+
}
179+
153180
def do(
154181
self,
155182
do,
156183
params
157184
) -> str:
158185
if do == "run with attacks":
186+
# Effect of pressing 'accept'
187+
self._finalize()
188+
self._is_set = True # to make diagram call unlock() when we break this block
189+
190+
self._clear_configs()
159191
for name, config in json.loads(params.get('configs')).items():
160192
# FIXME check config
161193
self.ad_configs[name] = ConfigPattern(
162194
**config,
163195
_import_path=NAME_TO_PATH[name],
164196
_config_class=NAME_TO_CLASS[name])
165197

166-
if self.ad_configs["AD-ea"]:
167-
self.model_manager.set_evasion_attacker(self.ad_configs["AD-ea"])
168-
if self.ad_configs["AD-ma"]:
169-
self.model_manager.set_mi_attacker(self.ad_configs["AD-ma"])
198+
self.model_manager.set_evasion_attacker(self.ad_configs["AD-ea"])
199+
self.model_manager.set_mi_attacker(self.ad_configs["AD-ma"])
200+
201+
# Apply attacks
202+
metrics_values = {}
203+
metrics_values['After attacks'] = self.model_manager.evaluate_model(
204+
gen_dataset=self.gen_dataset, metrics=self.metrics)
170205

171-
metrics_values = self.model_manager.evaluate_model(
172-
self.gen_dataset, metrics=self.metrics)
173-
self.model_manager.compute_stats_data(self.gen_dataset, predictions=True, logits=True)
206+
# Get MI metrics
207+
import numpy as np
208+
assert not self.gen_dataset.is_multi()
209+
target_list = np.random.choice(
210+
self.gen_dataset.info.nodes[0], size=100, replace=False)
211+
mask_loc = Metric.create_mask_by_target_list(
212+
y_true=self.gen_dataset.labels, target_list=target_list)
213+
# self.model_manager.evaluate_model(
214+
# gen_dataset=self.gen_dataset,
215+
# metrics=[Metric("F1", mask=mask_loc, average='macro')])
216+
# Apply MI attack on a special mask
217+
self.model_manager.mi_attacker.attack(
218+
gen_dataset=self.gen_dataset, model=self.model_manager.gnn,
219+
mask_tensor=mask_loc)
220+
res = self.model_manager.mi_attacker.results.get(mask_loc)
221+
if res is not None:
222+
metrics_values['MI attack results'] = MIAttacker.compute_single_attack_accuracy(
223+
mask_loc, res, self.gen_dataset.train_mask)
224+
225+
# Update model logits and predictions
226+
self.model_manager.compute_stats_data(
227+
gen_dataset=self.gen_dataset, predictions=True, logits=True)
174228

175229
# print("metrics_values after attacks", metrics_values)
176230
stats_data = {k: self.gen_dataset.visible_part.filter(v)
@@ -179,6 +233,41 @@ def do(
179233
metrics_values=metrics_values, stats_data=stats_data, socket=self.socket)
180234
return ''
181235

236+
elif do == "save attack configs":
237+
# We want to save the given config
238+
self._clear_configs()
239+
for name, config in json.loads(params.get('configs')).items():
240+
# FIXME check config
241+
self.ad_configs[name] = ConfigPattern(
242+
**config,
243+
_import_path=NAME_TO_PATH[name],
244+
_config_class=NAME_TO_CLASS[name])
245+
return self._save_attack_confgis()
246+
182247
else:
183248
raise WebInterfaceError(f"Unknown 'do' command '{do}' for model")
184249

250+
def _save_attack_confgis(
251+
self
252+
) -> str:
253+
# FIXME discuss scenario with Kirill
254+
# no sense to save model, only configs
255+
path = self.model_manager.save_model_executor()
256+
self.gen_dataset.save_train_test_mask(path)
257+
DataInfo.refresh_models_dir_structure()
258+
return str(path)
259+
260+
def _unlock(
261+
self
262+
) -> None:
263+
# Retract changes - reset dataset as before evasion attacks
264+
# and remove attacks from model manager
265+
self._restore_dataset()
266+
self.model_manager.set_evasion_attacker(None)
267+
self.model_manager.set_mi_attacker(None)
268+
269+
def _restore_dataset(
270+
self
271+
) -> None:
272+
# FIXME This is a bad way - for large datasets very bad. It is a temporary solution
273+
self.gen_dataset = deepcopy(self._gen_dataset_backup)

web_interface/back_front/block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def finalize(
135135

136136
def _finalize(
137137
self
138-
) -> None:
138+
) -> bool:
139139
""" Checks whether the config is correct to create the object.
140140
Returns True if OK or False.
141141
# TODO can we send to front errors to be fixed?

web_interface/back_front/model_blocks.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
from copy import deepcopy
34
from pathlib import Path
45
from typing import Union
56

@@ -341,10 +342,13 @@ def __init__(
341342
):
342343
super().__init__(*args, **kwargs)
343344

344-
self.gen_dataset = None
345+
self.gen_dataset: GeneralDataset = None
345346
self.model_manager = None
346347
self.metrics = None
347348

349+
# Copy of the dataset before attacks applied.
350+
self._gen_dataset_backup: GeneralDataset = None
351+
348352
def _init(
349353
self,
350354
gen_dataset: GeneralDataset,
@@ -367,8 +371,21 @@ def _submit(
367371
self
368372
) -> None:
369373
self.metrics = [Metric(**m) for m in self._config.get('metrics')]
370-
371374
self._object = [self.model_manager, self.metrics]
375+
self._save_model()
376+
377+
# Make a dataset backup
378+
if self._gen_dataset_backup is None:
379+
# Make a dataset backup
380+
# FIXME This is a bad way - for large datasets very bad. It is a temporary solution
381+
self._gen_dataset_backup = deepcopy(self.gen_dataset)
382+
383+
def _unlock(
384+
self
385+
) -> None:
386+
# Retract changes - reset dataset as before evasion attacks
387+
# FIXME This is a bad way - for large datasets very bad. It is a temporary solution
388+
self.gen_dataset = deepcopy(self._gen_dataset_backup)
372389

373390
def do(
374391
self,
@@ -443,9 +460,12 @@ def _train_model(
443460
mode: Union[str, None],
444461
steps: Union[int, None]
445462
) -> None:
463+
apply_posisoning_ad = True if self.model_manager.modification.epochs == 0 else False
446464
self.model_manager.train_model(
447465
gen_dataset=self.gen_dataset, save_model_flag=False,
448-
mode=mode, steps=steps, metrics=self.metrics, socket=self.socket)
466+
mode=mode, steps=steps, metrics=self.metrics, socket=self.socket,
467+
apply_posisoning_ad=apply_posisoning_ad)
468+
self.socket.send(block='mt', msg={"info": "training-finished"})
449469

450470
def _save_model(
451471
self

web_interface/main_aiohttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def worker_process(
259259
result = json_dumps(client.mcustomBlock.get_index())
260260
elif do in ['train', 'reset', 'run', 'save']:
261261
result = client.mtBlock.do(do, args)
262-
elif do in ['run with attacks']:
262+
elif do in ['run with attacks', 'save attack configs']:
263263
result = client.atBlock.do(do, args)
264264
else:
265265
raise WebInterfaceError(f"Unknown do command: '{do}'")

web_interface/static/css/controls.css

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
border-bottom: 2px solid;
7373
padding: 4px;
7474
font-size: 14pt;
75-
display: flex;
75+
display: grid;
7676
align-items: center;
7777
}
7878

web_interface/static/css/view.css

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,23 @@ div.disabled * {
127127
-ms-user-select: none;
128128
user-select: none;
129129
}
130+
131+
details.notAllowed {
132+
cursor: not-allowed;
133+
}
134+
details.wait {
135+
cursor: wait;
136+
}
137+
details.disabled {
138+
opacity: 0.6;
139+
}
140+
details.disabled *:focus {
141+
outline: 0;
142+
}
143+
details.disabled * {
144+
pointer-events: none;
145+
-webkit-user-select: none;
146+
-moz-user-select: none;
147+
-ms-user-select: none;
148+
user-select: none;
149+
}

web_interface/static/js/presentation/left_menu/model/menuAfterTrainView.js

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class MenuAfterTrainView extends MenuView {
1515
this.$methodSelects = Object.fromEntries(MenuAfterTrainView.names.map(x => [x, null]))
1616

1717
// Buttons
18-
this.$run = null
18+
this.$run = null // Run with attacks
19+
this.$save = null
1920
// this.$reset = null
2021
}
2122

@@ -51,6 +52,16 @@ class MenuAfterTrainView extends MenuView {
5152
this.$acceptDiv.find('button').prop("disabled", false)
5253
}
5354

55+
async onsave() {
56+
let configs = this._getConfigs()
57+
this.$acceptDiv.find('button').prop("disabled", true)
58+
this.$save.prop("disabled", true)
59+
await controller.ajaxRequest('/model',
60+
{do: "save attack configs", configs: JSON_stringify(configs)})
61+
this.$save.prop("disabled", false)
62+
this.$acceptDiv.find('button').prop("disabled", false)
63+
}
64+
5465
// Build buttons for model training process in model menu
5566
async addConfigMenu() {
5667
console.log('addConfigMenu')
@@ -126,5 +137,17 @@ class MenuAfterTrainView extends MenuView {
126137
this.onrun() // No await
127138
})
128139

140+
$cb = $("<div></div>").attr("class", "control-block")
141+
this.$mainDiv.append($cb)
142+
this.$save = $("<button></button>")
143+
.attr("id", "model-button-save-after").text("Save")
144+
.css("margin-right", "5px")
145+
.attr("title", "Save the chosen attack configurations")
146+
$cb.append(this.$save)
147+
148+
this.$save.click(async () => {
149+
this.onsave() // No await
150+
})
151+
129152
}
130153
}

0 commit comments

Comments
 (0)