11import warnings
2+
23from torch import device
34from torch .cuda import is_available
45
6+ from data_structures .configs import DatasetConfig , DatasetVarConfig , ExplainerInitConfig , \
7+ ExplainerRunConfig , FeatureConfig , Task
58from datasets .datasets_manager import DatasetManager
6- from data_structures .configs import DatasetConfig , DatasetVarConfig , ExplainerInitConfig , ExplainerRunConfig
79from explainers .explainers_manager import FrameworkExplainersManager
810from models_builder .gnn_models import FrameworkGNNModelManager , Metric
911from models_builder .models_zoo import model_configs_zoo
@@ -13,47 +15,31 @@ def explainers_test():
1315 my_device = device ('cuda' if is_available () else 'cpu' )
1416
1517 # Init datasets
16- dataset_mg_small , _ , results_dataset_path_mg_small = DatasetManager .get_by_full_name (
17- full_name = ( "example" , "example3" ,),
18- features = FeatureConfig ( node_attr = [ 'a' ]),
19- labeling = 'binary' ,
20- dataset_ver_ind = 0
18+ dataset_mg_small = DatasetManager .get_by_config (
19+ DatasetConfig (( "example" , "example3" ,) ),
20+ DatasetVarConfig (
21+ task = Task . GRAPH_CLASSIFICATION , features = FeatureConfig ( node_attr = [ 'a' ]) ,
22+ labeling = 'binary' , dataset_ver_ind = 0 )
2123 )
2224
23- dataset_sg_example , _ , results_dataset_path_sg_example = DatasetManager .get_by_full_name (
24- full_name = ( "example" , "single-graph" , " example" ,),
25- features = FeatureConfig ( node_attr = [ 'a' ]),
26- labeling = 'binary' ,
27- dataset_ver_ind = 0
25+ dataset_sg_example = DatasetManager .get_by_config (
26+ DatasetConfig (( "example" , "example" ,) ),
27+ DatasetVarConfig (
28+ task = Task . GRAPH_CLASSIFICATION , features = FeatureConfig ( node_attr = [ 'a' ]) ,
29+ labeling = 'binary' , dataset_ver_ind = 0 )
2830 )
31+ dataset_sg_example .train_test_split (percent_train_class = 0.6 , percent_test_class = 0.4 )
2932
3033 gen_dataset_mg_small = DatasetManager .get_by_config (
31- DatasetConfig (
32- domain = "Homogeneous" ,
33- group = "custom" ,
34- graph = "small" ),
35- DatasetVarConfig (features = FeatureConfig (node_attr = ['a' ]),
36- labeling = 'binary' ,
37- dataset_ver_ind = 0 )
38- )
39- gen_dataset_sg_example = DatasetManager .get_by_config (
40- DatasetConfig (
41- domain = "single-graph" ,
42- group = "custom" ,
43- graph = "example" ),
44- DatasetVarConfig (features = FeatureConfig (node_attr = ['a' ]),
45- labeling = 'binary' ,
46- dataset_ver_ind = 0 )
34+ DatasetConfig (("example" ,"example8" )),
35+ DatasetVarConfig (
36+ task = Task .GRAPH_CLASSIFICATION , features = FeatureConfig (node_attr = ['a' ]),
37+ labeling = 'binary' , dataset_ver_ind = 0 )
4738 )
4839 gen_dataset_mg_small .train_test_split (percent_train_class = 0.6 , percent_test_class = 0.4 )
49- gen_dataset_sg_example .train_test_split (percent_train_class = 0.6 , percent_test_class = 0.4 )
5040
5141
5242 dataset_mg_small = gen_dataset_mg_small
53- results_dataset_path_mg_small = gen_dataset_mg_small .prepared_dir
54-
55- dataset_sg_example = gen_dataset_sg_example
56- results_dataset_path_sg_example = gen_dataset_sg_example .prepared_dir
5743
5844 # Init gnns and gnn_model_managers
5945 # gat2_cora = model_configs_zoo(dataset=dataset_cora, model_name='gat_gat')
@@ -64,15 +50,15 @@ def explainers_test():
6450
6551 gnn_model_manager_mg_small = FrameworkGNNModelManager (
6652 gnn = gin3_lin2_mg_small ,
67- dataset_path = results_dataset_path_mg_small ,
53+ dataset_path = dataset_mg_small . prepared_dir ,
6854 )
6955
7056 # gin3_lin2_sg_example = model_configs_zoo(dataset=dataset_sg_example, model_name='gin_gin_gin_lin_lin')
7157 gin3_lin2_sg_example = model_configs_zoo (dataset = dataset_sg_example , model_name = 'gcn_gcn_lin' )
7258
7359 gnn_model_manager_sg_example = FrameworkGNNModelManager (
7460 gnn = gin3_lin2_sg_example ,
75- dataset_path = results_dataset_path_sg_example ,
61+ dataset_path = dataset_sg_example . prepared_dir ,
7662 )
7763
7864 # Train models
0 commit comments