1+ import json
2+
13import torch
24from torch import device
35
46from gnn_aid .attacks import Attacker
57from gnn_aid .aux .utils import FUNCTIONS_PARAMETERS_PATH , all_subclasses
8+ from gnn_aid .data_structures import ModelStructureConfig , ModelConfig
69from gnn_aid .data_structures .configs import DatasetConfig , DatasetVarConfig , FeatureConfig , Task , \
710 ConfigPattern , ModelModificationConfig
811from gnn_aid .datasets .datasets_manager import DatasetManager
912from gnn_aid .datasets .ptg_datasets import LibPTGDataset
10- from gnn_aid .models_builder .gnn_models import FrameworkGNNModelManager , Metric
13+ from gnn_aid .models_builder import FrameworkGNNConstructor
14+ from gnn_aid .models_builder .models_utils import Metric
15+ from gnn_aid .models_builder .model_managers import FrameworkGNNModelManager
1116from gnn_aid .models_builder .models_zoo import model_configs_zoo
1217
1318
@@ -103,9 +108,32 @@ def link_prediction():
103108
104109 gen_dataset = DatasetManager .get_by_config (dc , dvc )
105110 print (gen_dataset .data )
106- gen_dataset .train_test_split ()
111+ gen_dataset .train_test_split (percent_train_class = 0.85 , percent_test_class = 0.1 )
112+
113+ gnn = FrameworkGNNConstructor (
114+ model_config = ModelConfig (
115+ structure = ModelStructureConfig (
116+ [
117+ {
118+ 'label' : 'n' ,
119+ 'layer' : {
120+ 'layer_name' : 'SAGEConv' ,
121+ 'layer_kwargs' : {
122+ 'in_channels' : gen_dataset .num_node_features ,
123+ 'out_channels' : 16 ,
124+ },
125+ },
126+ },
127+ {
128+ 'label' : 'd' ,
129+ 'function' : {
130+ 'function_name' : 'CosineSimilarity' ,
131+ 'function_kwargs' : None
132+ }
133+ }
134+ ]
135+ )))
107136
108- gnn = model_configs_zoo (dataset = gen_dataset , model_name = 'gcn_gcn' )
109137 manager_config = ConfigPattern (
110138 _config_class = "ModelManagerConfig" ,
111139 _config_kwargs = {
@@ -125,7 +153,7 @@ def link_prediction():
125153 }
126154 )
127155
128- steps_epochs = 30
156+ steps_epochs = 3
129157 my_device = device ('cuda' if torch .cuda .is_available () else 'cpu' )
130158 gnn_model_manager = FrameworkGNNModelManager (
131159 gnn = gnn ,
@@ -145,13 +173,130 @@ def link_prediction():
145173 )
146174 print ("Training was successful" )
147175
176+ res = gnn_model_manager .run_model (
177+ gen_dataset = gen_dataset ,
178+ mask = 'all'
179+ )
180+ res = gnn_model_manager .evaluate_model (
181+ gen_dataset = gen_dataset ,
182+ metrics = [Metric ("Accuracy" , mask = "test" )]
183+ )
184+ print (json .dumps (res , indent = 2 ))
185+
186+ # gnn.eval()
187+ # z = gnn.encode(data.x, data.edge_index)
188+ # out = gnn.decode(z, data.edge_label_index).view(-1).sigmoid()
189+ # return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
190+
191+
def ptg_example():
    """Stand-alone PyTorch Geometric link-prediction demo on Cora.

    Mirrors the official PyG `link_pred` example: a two-layer GCN encoder
    with a dot-product edge decoder, trained for 100 epochs with fresh
    negative sampling each epoch, evaluated by ROC-AUC on the edge splits
    produced by `RandomLinkSplit`.
    """
    import os.path as osp

    import torch
    from sklearn.metrics import roc_auc_score

    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import negative_sampling

    # Pick the best available backend: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        dev = torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        dev = torch.device('mps')
    else:
        dev = torch.device('cpu')

    preprocess = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice(dev),
        # RandomLinkSplit replaces the single Data object with a
        # (train_data, val_data, test_data) tuple of edge-level splits.
        T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                          add_negative_train_samples=False),
    ])
    root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
    dataset = Planetoid(root, name='Cora', transform=preprocess)
    train_data, val_data, test_data = dataset[0]

    class LinkPredictor(torch.nn.Module):
        """Two-layer GCN encoder with a dot-product edge decoder."""

        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)

        def encode(self, x, edge_index):
            hidden = self.conv1(x, edge_index).relu()
            return self.conv2(hidden, edge_index)

        def decode(self, z, edge_label_index):
            # Edge score = inner product of its two endpoint embeddings.
            src, dst = edge_label_index[0], edge_label_index[1]
            return (z[src] * z[dst]).sum(dim=-1)

        def decode_all(self, z):
            # Dense N x N score matrix; keep every positively-scored pair.
            scores = z @ z.t()
            return (scores > 0).nonzero(as_tuple=False).t()

    model = LinkPredictor(dataset.num_features, 128, 64).to(dev)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    def run_train_epoch():
        """One optimization step over the full training split; returns the loss."""
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)

        # Resample negatives every epoch so the model sees fresh non-edges.
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index], dim=-1)
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1)),
        ], dim=0)

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        return loss

    @torch.no_grad()
    def evaluate(split):
        """ROC-AUC of sigmoid edge scores against the split's edge labels."""
        model.eval()
        z = model.encode(split.x, split.edge_index)
        probs = model.decode(z, split.edge_label_index).view(-1).sigmoid()
        return roc_auc_score(split.edge_label.cpu().numpy(), probs.cpu().numpy())

    best_val_auc = final_test_auc = 0
    for epoch in range(1, 101):
        loss = run_train_epoch()
        val_auc = evaluate(val_data)
        test_auc = evaluate(test_data)
        # Report the test score from the epoch with the best validation score.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

    print(f'Final Test: {final_test_auc:.4f}')

    # Kept for parity with the upstream example (result intentionally unused).
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
148292
if __name__ == '__main__':
    # Demo entry point: runs the link-prediction example. The commented
    # calls below are alternative demos kept for quick manual switching.
    # node_regression()
    # graph_regression()
    # edge_regression()
    link_prediction()

    # ptg_example()
    # for c in all_subclasses(Attacker):
    #     print(c)
0 commit comments