1- from typing import Any , Optional , List , Callable , Union
1+ from typing import Any , Optional , List , Callable , Union , Tuple
22from typing import Callable
33
44import sklearn .metrics
88from torch .utils .hooks import RemovableHandle
99from torch_geometric .nn import MessagePassing
1010
11+ from gnn_aid .datasets import GeneralDataset
12+
1113
1214def apply_message_gradient_capture (
1315 layer : Any ,
@@ -694,4 +696,52 @@ def _predict_full_enumeration(model, data, h, existing_set, k, is_directed, remo
694696
695697 print (f"Top-{ final_k } edges found with scores from { top_scores [- 1 ]:.4f} to { top_scores [0 ]:.4f} " )
696698
697- return top_edges .cpu (), top_scores .cpu ()
699+ return top_edges .cpu (), top_scores .cpu ()
700+
701+
702+ def mask_to_tensor (
703+ gen_dataset : GeneralDataset ,
704+ mask : Union [str , int , Tuple [int ], list , torch .Tensor ] = 'test'
705+ ) -> torch .Tensor :
706+ """
707+ Convert a mask over nodes/edges/graphs to tensor of specific size.
708+ Mask can be 'train', 'val', 'test', 'all', or id, or a list of ids, or a tensor.
709+
710+ :param gen_dataset: dataset
711+ :param mask: part of the dataset on which the output will be obtained.
712+ Can be a node id, graph id, or edge as a tuple (i,j).
713+ Can be string: 'train', 'val', 'test', 'all'.
714+ Can be Tensor of specific nodes/edges/graphs.
715+ :return: tensor of nodes/edges/graphs
716+ """
717+ task = gen_dataset .dataset_var_config .task
718+
719+ if isinstance (mask , str ):
720+ mask_tensor = {
721+ 'train' : gen_dataset .train_mask ,
722+ 'val' : gen_dataset .val_mask ,
723+ 'test' : gen_dataset .test_mask ,
724+ 'all' : tensor ([True ] * len (gen_dataset .labels )),
725+ }[mask ]
726+
727+ elif isinstance (mask , torch .Tensor ):
728+ mask_tensor = mask
729+
730+ elif task .is_node_level (): # Node id
731+ assert not gen_dataset .is_multi ()
732+ mask_tensor = tensor ([False ] * gen_dataset .info .nodes [0 ])
733+ mask_tensor [mask ] = True # for int or list of ints
734+
735+ elif task .is_graph_level (): # Graph id
736+ assert gen_dataset .is_multi ()
737+ mask_tensor = tensor ([False ] * len (gen_dataset .info .nodes ))
738+ mask_tensor [mask ] = True # for int or list of ints
739+
740+ elif task .is_edge_level (): # Edge
741+ # isinstance(mask, tuple)
742+ raise NotImplementedError
743+
744+ else :
745+ raise RuntimeError (f"Cannot infer mask tensor for given mask { mask } ." )
746+
747+ return mask_tensor
0 commit comments