|
4 | 4 | from torch import device |
5 | 5 | from torch.cuda import is_available |
6 | 6 |
|
7 | | -from data_structures.configs import ConfigPattern, ModelManagerConfig, ModelModificationConfig, \ |
8 | | - Task, DatasetConfig |
9 | | -from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH |
10 | | -from datasets.ptg_datasets import LibPTGDataset |
11 | | -from explainers.explainers_manager import FrameworkExplainersManager |
12 | | -from models_builder.gnn_models import FrameworkGNNModelManager, Metric |
13 | | -from datasets.datasets_manager import DatasetManager |
14 | | -from models_builder.models_zoo import model_configs_zoo |
| 7 | +# from data_structures.configs import ConfigPattern, ModelManagerConfig, ModelModificationConfig, \ |
| 8 | +# Task, DatasetConfig |
| 9 | +# from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH |
| 10 | +# from datasets.ptg_datasets import LibPTGDataset |
| 11 | +# from explainers.explainers_manager import FrameworkExplainersManager |
| 12 | +# from models_builder.gnn_models import FrameworkGNNModelManager, Metric |
| 13 | +# from datasets.datasets_manager import DatasetManager |
| 14 | +# from models_builder.models_zoo import model_configs_zoo |
15 | 15 |
|
16 | 16 |
|
17 | 17 | # from pytorch_model_summary import summary |
18 | 18 |
|
19 | 19 |
|
20 | 20 | # from visualization.plotutils import draw_vk, draw_cora, draw |
| 21 | +from gnn_aid.aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH |
| 22 | +from gnn_aid.data_structures import DatasetConfig, DatasetVarConfig, Task |
| 23 | +from gnn_aid.data_structures.configs import ConfigPattern, FeatureConfig |
| 24 | +from gnn_aid.datasets import DatasetManager |
| 25 | +from gnn_aid.explainers import FrameworkExplainersManager |
| 26 | +from gnn_aid.models_builder import model_configs_zoo |
| 27 | +from gnn_aid.models_builder.model_managers import FrameworkGNNModelManager |
21 | 28 |
|
22 | 29 |
|
23 | 30 | def test_Zorro(save_nan=True): |
@@ -186,9 +193,68 @@ def test_Zorro(save_nan=True): |
186 | 193 | # print("INCORRECT PREDICTION") |
187 | 194 |
|
188 | 195 |
|
| 196 | +def test_zorro(): |
| 197 | + # Single-Graph - Example |
| 198 | + gen_dataset_sg_example = DatasetManager.get_by_config( |
| 199 | + DatasetConfig(("example", "example")), |
| 200 | + DatasetVarConfig(task=Task.NODE_CLASSIFICATION, |
| 201 | + features=FeatureConfig(node_attr=['a']), |
| 202 | + labeling='binary', |
| 203 | + dataset_ver_ind=0) |
| 204 | + ) |
| 205 | + gen_dataset_sg_example.train_test_split(percent_train_class=0.6, percent_test_class=0.4) |
| 206 | + dataset_sg_example = gen_dataset_sg_example |
| 207 | + results_dataset_path_sg_example = gen_dataset_sg_example.prepared_dir |
| 208 | + |
| 209 | + gcn2_sg_example = model_configs_zoo(dataset=gen_dataset_sg_example, model_name='gcn_gcn') |
| 210 | + |
| 211 | + gnn_model_manager_sg_example_manager_config = ConfigPattern( |
| 212 | + _config_class="ModelManagerConfig", |
| 213 | + _config_kwargs={ |
| 214 | + "batch": 10000, |
| 215 | + "mask_features": [] |
| 216 | + } |
| 217 | + ) |
| 218 | + gnn_model_manager_sg_example = FrameworkGNNModelManager( |
| 219 | + gnn=gcn2_sg_example, |
| 220 | + dataset_path=results_dataset_path_sg_example, |
| 221 | + manager_config=gnn_model_manager_sg_example_manager_config |
| 222 | + ) |
| 223 | + |
| 224 | + |
| 225 | + explainer_init_config = ConfigPattern( |
| 226 | + _class_name="Zorro", |
| 227 | + _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, |
| 228 | + _config_class="ExplainerInitConfig", |
| 229 | + _config_kwargs={ |
| 230 | + "_test_": (5, 6) |
| 231 | + } |
| 232 | + ) |
| 233 | + explainer_run_config = ConfigPattern( |
| 234 | + _config_class="ExplainerRunConfig", |
| 235 | + _config_kwargs={ |
| 236 | + "mode": "local", |
| 237 | + "kwargs": { |
| 238 | + "_class_name": "Zorro", |
| 239 | + "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, |
| 240 | + "_config_class": "Config", |
| 241 | + "_config_kwargs": { |
| 242 | + }, |
| 243 | + } |
| 244 | + } |
| 245 | + ) |
| 246 | + explainer_Zorro = FrameworkExplainersManager( |
| 247 | + init_config=explainer_init_config, |
| 248 | + dataset=dataset_sg_example, gnn_manager=gnn_model_manager_sg_example, |
| 249 | + explainer_name='Zorro', |
| 250 | + ) |
| 251 | + explainer_Zorro.conduct_experiment(explainer_run_config) |
| 252 | + |
| 253 | + |
189 | 254 | if __name__ == '__main__': |
190 | 255 | # t0 = time.clock() |
191 | | - test_Zorro() |
| 256 | + # test_Zorro() |
| 257 | + test_zorro() |
192 | 258 |
|
193 | 259 | # print(time.clock() - t0) |
194 | 260 | # test_Vk() |
|
0 commit comments