Skip to content

Commit 73f5e70

Browse files
committed
+
1 parent 534e04e commit 73f5e70

File tree

4 files changed

+94
-12
lines changed

4 files changed

+94
-12
lines changed

experiments/zorro_exp_example.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,27 @@
44
from torch import device
55
from torch.cuda import is_available
66

7-
from data_structures.configs import ConfigPattern, ModelManagerConfig, ModelModificationConfig, \
8-
Task, DatasetConfig
9-
from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH
10-
from datasets.ptg_datasets import LibPTGDataset
11-
from explainers.explainers_manager import FrameworkExplainersManager
12-
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
13-
from datasets.datasets_manager import DatasetManager
14-
from models_builder.models_zoo import model_configs_zoo
7+
# from data_structures.configs import ConfigPattern, ModelManagerConfig, ModelModificationConfig, \
8+
# Task, DatasetConfig
9+
# from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH
10+
# from datasets.ptg_datasets import LibPTGDataset
11+
# from explainers.explainers_manager import FrameworkExplainersManager
12+
# from models_builder.gnn_models import FrameworkGNNModelManager, Metric
13+
# from datasets.datasets_manager import DatasetManager
14+
# from models_builder.models_zoo import model_configs_zoo
1515

1616

1717
# from pytorch_model_summary import summary
1818

1919

2020
# from visualization.plotutils import draw_vk, draw_cora, draw
21+
from gnn_aid.aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH
22+
from gnn_aid.data_structures import DatasetConfig, DatasetVarConfig, Task
23+
from gnn_aid.data_structures.configs import ConfigPattern, FeatureConfig
24+
from gnn_aid.datasets import DatasetManager
25+
from gnn_aid.explainers import FrameworkExplainersManager
26+
from gnn_aid.models_builder import model_configs_zoo
27+
from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager
2128

2229

2330
def test_Zorro(save_nan=True):
@@ -186,9 +193,68 @@ def test_Zorro(save_nan=True):
186193
# print("INCORRECT PREDICTION")
187194

188195

196+
def test_zorro():
197+
# Single-Graph - Example
198+
gen_dataset_sg_example = DatasetManager.get_by_config(
199+
DatasetConfig(("example", "example")),
200+
DatasetVarConfig(task=Task.NODE_CLASSIFICATION,
201+
features=FeatureConfig(node_attr=['a']),
202+
labeling='binary',
203+
dataset_ver_ind=0)
204+
)
205+
gen_dataset_sg_example.train_test_split(percent_train_class=0.6, percent_test_class=0.4)
206+
dataset_sg_example = gen_dataset_sg_example
207+
results_dataset_path_sg_example = gen_dataset_sg_example.prepared_dir
208+
209+
gcn2_sg_example = model_configs_zoo(dataset=gen_dataset_sg_example, model_name='gcn_gcn')
210+
211+
gnn_model_manager_sg_example_manager_config = ConfigPattern(
212+
_config_class="ModelManagerConfig",
213+
_config_kwargs={
214+
"batch": 10000,
215+
"mask_features": []
216+
}
217+
)
218+
gnn_model_manager_sg_example = FrameworkGNNModelManager(
219+
gnn=gcn2_sg_example,
220+
dataset_path=results_dataset_path_sg_example,
221+
manager_config=gnn_model_manager_sg_example_manager_config
222+
)
223+
224+
225+
explainer_init_config = ConfigPattern(
226+
_class_name="Zorro",
227+
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
228+
_config_class="ExplainerInitConfig",
229+
_config_kwargs={
230+
"_test_": (5, 6)
231+
}
232+
)
233+
explainer_run_config = ConfigPattern(
234+
_config_class="ExplainerRunConfig",
235+
_config_kwargs={
236+
"mode": "local",
237+
"kwargs": {
238+
"_class_name": "Zorro",
239+
"_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
240+
"_config_class": "Config",
241+
"_config_kwargs": {
242+
},
243+
}
244+
}
245+
)
246+
explainer_Zorro = FrameworkExplainersManager(
247+
init_config=explainer_init_config,
248+
dataset=dataset_sg_example, gnn_manager=gnn_model_manager_sg_example,
249+
explainer_name='Zorro',
250+
)
251+
explainer_Zorro.conduct_experiment(explainer_run_config)
252+
253+
189254
if __name__ == '__main__':
190255
# t0 = time.clock()
191-
test_Zorro()
256+
# test_Zorro()
257+
test_zorro()
192258

193259
# print(time.clock() - t0)
194260
# test_Vk()

gnn_aid/aux/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,33 @@ def setting_class_default_parameters(
117117
raise Exception(f"{class_name} is not currently supported")
118118
class_kwargs_default = class_kwargs_default[class_name]
119119
for key, val in class_kwargs.items():
120+
key_1 = class_kwargs_default[key][1]
120121
if key == TECHNICAL_PARAMETER_KEY or key not in class_kwargs_default.keys():
121122
# raise Exception(
122123
# f"Parameter {key} cannot be set for {class_name}")
123124
warnings.warn(f"WARNING: Parameter {key} cannot be set for {class_name} "
124125
f"in def setting_class_default_parameters")
125126
continue
126-
elif val is None or class_kwargs_default[key][1] == 'string' or (class_kwargs_default[key][1] == 'dynamic' and isinstance(val, str)) or np.isinf(val):
127+
elif key_1 == 'int_or_tuple':
128+
try:
129+
class_kwargs[key] = int(val)
130+
except TypeError:
131+
class_kwargs[key] = tuple(val)
132+
elif val is None or key_1 == 'string'\
133+
or (key_1 == 'dynamic' and isinstance(val, str))\
134+
or np.isinf(val):
127135
class_kwargs[key] = val
128136
else:
129-
class_kwargs[key] = locate(class_kwargs_default[key][1])(val)
137+
class_kwargs[key] = locate(key_1)(val)
130138
for key, val in class_kwargs_default.items():
131139
if key != TECHNICAL_PARAMETER_KEY and key not in class_kwargs.keys():
132140
if val[2] is None or val[1] == 'string' or val[2] == np.inf:
133141
class_kwargs[key] = val[2]
142+
elif val[1] == 'int_or_tuple':
143+
try:
144+
class_kwargs[key] = int(val[2])
145+
except TypeError:
146+
class_kwargs[key] = tuple(val[2])
134147
else:
135148
class_kwargs[key] = locate(val[1])(val[2])
136149

metainfo/explainers_init_parameters.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
"Zorro": {
33
"greedy": ["Greedy", "bool", true, null, "?"],
44
"add_noise": ["Noise", "bool", false, null, "?"],
5-
"samples": ["Samples", "int", 100, {"min": 1}, "Number of random noise samples to control fidelity"]
5+
"samples": ["Samples", "int", 100, {"min": 1}, "Number of random noise samples to control fidelity"],
6+
"_test_": ["Test", "int_or_tuple", [0, 1], {"min": 1}, "_test_"]
67
},
78

89
"SubgraphX": {

web_interface/back_front/dataset_blocks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from gnn_aid.aux.data_info import DataInfo
44
from gnn_aid.aux.utils import TORCH_GEOM_GRAPHS_PATH
5+
from gnn_aid.data_structures import Task
56
from gnn_aid.data_structures.configs import DatasetConfig, DatasetVarConfig, FeatureConfig
67
from gnn_aid.datasets.datasets_manager import DatasetManager
78
from gnn_aid.datasets.gen_dataset import GeneralDataset
@@ -112,6 +113,7 @@ def _finalize(
112113

113114
kwargs = self._config.copy()
114115
features = FeatureConfig(**kwargs.pop('features'))
116+
kwargs['task'] = Task(kwargs.pop('task'))
115117
kwargs['features'] = features
116118
self.dataset_var_config = DatasetVarConfig(**kwargs)
117119
# print(self.dataset_var_config.to_dict())

0 commit comments

Comments
 (0)