@@ -171,19 +171,20 @@ def link_prediction():
171171 manager_config = ConfigPattern (
172172 _config_class = "ModelManagerConfig" ,
173173 _config_kwargs = {
174- "batch" : 64 ,
174+ "batch" : 128 ,
175175 "mask_features" : [],
176176 "optimizer" : {
177177 "_class_name" : "Adam" ,
178178 "_config_kwargs" : {},
179179 },
180180 "loss_function" : {
181181 "_config_class" : "Config" ,
182- "_class_name" : "CrossEntropyLoss " ,
182+ "_class_name" : "BCEWithLogitsLoss " ,
183183 "_import_path" : FUNCTIONS_PARAMETERS_PATH ,
184184 "_class_import_info" : ["torch.nn" ],
185185 "_config_kwargs" : {},
186186 },
187+ "neg_samples_ratio" : 2
187188 }
188189 )
189190
@@ -208,10 +209,14 @@ def link_prediction():
208209 )
209210 print ("Training was successful" )
210211
211- # res = gnn_model_manager.run_model(
212- # gen_dataset=gen_dataset,
213- # mask='all'
214- # )
212+ res = gnn_model_manager .run_model (
213+ gen_dataset = gen_dataset ,
214+ mask = 'all' ,
215+ out = 'predictions'
216+ )
217+ print (json .dumps (res .tolist (), indent = 2 ))
218+ return
219+
215220 from gnn_aid .aux .utils import EVASION_ATTACK_PARAMETERS_PATH
216221 evasion_attack_config = ConfigPattern (
217222 _class_name = "FGSM" ,
@@ -227,17 +232,17 @@ def link_prediction():
227232 # атака
228233 # gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
229234 #
230- # res = gnn_model_manager.evaluate_model(
231- # gen_dataset=gen_dataset,
232- # metrics=[
233- # Metric("Accuracy", mask='all'),
234- # Metric("AUC", mask='test'),
235- # # Metric("Precision@k", mask='test', k=50),
236- # # Metric("Precision@k", mask='test', k=500000),
237- # # Metric("Recall@k", mask='test', k=500000),
238- # ]
239- # )
240- # print(json.dumps(res, indent=2))
235+ res = gnn_model_manager .evaluate_model (
236+ gen_dataset = gen_dataset ,
237+ metrics = [
238+ Metric ("Accuracy" , mask = 'all' ),
239+ Metric ("AUC" , mask = 'test' ),
240+ # Metric("Precision@k", mask='test', k=50),
241+ # Metric("Precision@k", mask='test', k=500000),
242+ # Metric("Recall@k", mask='test', k=500000),
243+ ]
244+ )
245+ print (json .dumps (res , indent = 2 ))
241246
242247 # explainer
243248 from gnn_aid .aux .utils import EXPLAINERS_INIT_PARAMETERS_PATH , EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH
0 commit comments