@@ -255,6 +255,11 @@ def name(
         res += '{' + kwargs + '}'
         return res

+    def __str__(
+        self
+    ) -> str:
+        return self.name()
+
     def compute(
         self,
         y_true,
@@ -343,7 +348,8 @@ def forward(self, tensors):
 # ================================================


-def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss_k_per_node=200):
+def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True,
+                        faiss_k_per_node=100, is_directed=True, remove_loops=True):
     """
     Predict top-k new edges with FAISS support for large graphs

@@ -354,6 +360,8 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
         k: number of top edges to return (globally)
         use_faiss: use FAISS for fast search (recommended for large graphs)
         faiss_k_per_node: how many candidates to search per node via FAISS
+        is_directed: if False, normalize edges to (i, j) where i < j and remove duplicates
+        remove_loops: if True, exclude self-loops (i, i) from predictions

     Returns:
         top_edges: torch.Tensor shape (2, k) - indices of top-k node pairs
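For orientation, here is a minimal call sketch of the extended signature; `model` and `data` are whatever trained encoder and PyG-style data object the caller already has, and all values below are illustrative, not prescribed:

    # Hypothetical call (illustrative values): undirected graph, no self-loops
    top_edges, top_scores = predict_top_k_edges(
        model, data,
        exclude_edges=data.edge_index,   # do not re-predict known edges
        k=100,
        use_faiss=True,
        faiss_k_per_node=100,
        is_directed=False,
        remove_loops=True,
    )
    # top_edges: (2, k) LongTensor of node pairs; top_scores: (k,) sigmoid scores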
@@ -374,9 +382,15 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
     # Create set of existing edges for fast lookup
     existing_set = set()
     if exclude_edges is not None and exclude_edges.size(1) > 0:
+        # Normalize exclude_edges for undirected graphs
+        if not is_directed:
+            exclude_edges_normalized = _normalize_edges(exclude_edges)
+        else:
+            exclude_edges_normalized = exclude_edges
+
         existing_set = set(zip(
-            exclude_edges[0].cpu().tolist(),
-            exclude_edges[1].cpu().tolist()
+            exclude_edges_normalized[0].cpu().tolist(),
+            exclude_edges_normalized[1].cpu().tolist()
         ))
         print(f"Excluding {len(existing_set)} existing edges")
@@ -397,22 +411,86 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
     total_pairs = num_nodes * num_nodes
     if use_faiss and total_pairs > 100e6:
         print(f"Using FAISS for large graph: {num_nodes} x {num_nodes} = {total_pairs} pairs")
-        return _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node)
+        return _predict_with_faiss(model, data, h_norm, existing_set, k,
+                                   faiss_k_per_node, is_directed, remove_loops)
     else:
         print(
             f"Using full enumeration for small graph: {num_nodes} x {num_nodes} = {total_pairs} pairs")
-        return _predict_full_enumeration(model, data, h_norm, existing_set, k)
+        return _predict_full_enumeration(model, data, h_norm, existing_set, k,
+                                         is_directed, remove_loops)
+
+
+def _normalize_edges(edges):
+    """
+    Normalize edges for an undirected graph: ensure i < j for each edge (i, j)
+
+    Args:
+        edges: torch.Tensor shape (2, num_edges)
+
+    Returns:
+        normalized_edges: torch.Tensor shape (2, num_edges) with i < j
+    """
+    src, dst = edges[0], edges[1]
+
+    # Swap endpoints where src > dst
+    mask = src > dst
+    src_new = torch.where(mask, dst, src)
+    dst_new = torch.where(mask, src, dst)
+
+    return torch.stack([src_new, dst_new], dim=0)
+
+
+def _deduplicate_edges(edges, scores):
+    """
+    Remove duplicate edges, keeping each unique edge with its highest score
+
+    Args:
+        edges: torch.Tensor shape (2, num_edges)
+        scores: torch.Tensor shape (num_edges,)
+
+    Returns:
+        unique_edges: torch.Tensor shape (2, num_unique)
+        unique_scores: torch.Tensor shape (num_unique,)
+    """
+    if edges.size(1) == 0:
+        return edges, scores
+
+    # Create dictionary: edge_tuple -> (score, index)
+    edge_dict = {}
+    for idx in range(edges.size(1)):
+        edge_tuple = (edges[0, idx].item(), edges[1, idx].item())
+        score = scores[idx].item()
+
+        # Keep the occurrence with the higher score
+        if edge_tuple not in edge_dict or score > edge_dict[edge_tuple][0]:
+            edge_dict[edge_tuple] = (score, idx)
+
+    # Extract unique edges and their scores
+    unique_indices = [v[1] for v in edge_dict.values()]
+    unique_edges = edges[:, unique_indices]
+    unique_scores = scores[unique_indices]
+
+    # Re-sort by score, descending
+    sorted_scores, sorted_indices = torch.sort(unique_scores, descending=True)
+    unique_edges = unique_edges[:, sorted_indices]

+    return unique_edges, sorted_scores

-def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
+
+def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node,
+                        is_directed, remove_loops):
     """Prediction using FAISS for fast search"""
     import faiss

     num_nodes = h_norm.size(0)
     embedding_dim = h_norm.size(1)

     # Parameter: how many candidates to take per node
-    faiss_k_per_node = min(faiss_k_per_node, num_nodes)
+    # For undirected graphs, we need more candidates to account for deduplication
+    if not is_directed:
+        faiss_k_per_node = min(faiss_k_per_node * 2, num_nodes)
+    else:
+        faiss_k_per_node = min(faiss_k_per_node, num_nodes)

     # Build FAISS index for all nodes
     h_np = h_norm.cpu().numpy().astype('float32')
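A small worked example pinning down the semantics of the two new helpers (toy tensors, made-up scores):

    import torch
    edges = torch.tensor([[2, 0, 1], [1, 3, 2]])   # (2,1), (0,3), (1,2)
    scores = torch.tensor([0.9, 0.5, 0.7])
    norm = _normalize_edges(edges)                 # -> (1,2), (0,3), (1,2)
    uniq, uniq_scores = _deduplicate_edges(norm, scores)
    # (1,2) appears twice after normalization; the higher score 0.9 wins,
    # and the output is sorted by score descending:
    # uniq == tensor([[1, 0], [2, 3]]), uniq_scores == tensor([0.9, 0.5])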
@@ -430,23 +508,37 @@ def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
     all_pairs = []
     for src_idx in range(num_nodes):
         for rank, dst_idx in enumerate(candidate_indices[src_idx]):
+            dst_idx = int(dst_idx)
+
+            # Remove loops if requested
+            if remove_loops and src_idx == dst_idx:
+                continue
+
+            # Normalize edge for undirected graph
+            if not is_directed:
+                edge = (min(src_idx, dst_idx), max(src_idx, dst_idx))
+            else:
+                edge = (src_idx, dst_idx)
+
             # Skip existing edges
-            if (src_idx, int(dst_idx)) in existing_set:
-                print(f"Score for existing ({src_idx}, {dst_idx}): {similarities[src_idx, rank]}")
+            if edge in existing_set:
                 continue

             # Save (src, dst, similarity)
-            all_pairs.append((src_idx, int(dst_idx), similarities[src_idx, rank]))
+            all_pairs.append((edge[0], edge[1], similarities[src_idx, rank]))

     print(f"Found {len(all_pairs)} candidate pairs after filtering")

     if len(all_pairs) == 0:
         print("Warning: No candidate pairs found!")
         return torch.zeros((2, 0), dtype=torch.long), torch.zeros(0)

-    # Sort by similarity and take top-k
+    # Sort by similarity and take more than k to account for deduplication
     all_pairs.sort(key=lambda x: x[2], reverse=True)
-    top_pairs = all_pairs[:min(k, len(all_pairs))]
+
+    # For undirected graphs, we already normalized, so no duplicates,
+    # but we take more candidates to be safe
+    top_pairs = all_pairs[:min(k * 2 if not is_directed else k, len(all_pairs))]
@@ -461,26 +553,40 @@ def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
     final_scores = model.decode(h_src_final, h_dst_final)
     final_scores = final_scores.sigmoid()

-    # Re-sort by exact scores
-    sorted_scores, sorted_indices = torch.sort(final_scores, descending=True)
-    final_src = final_src[sorted_indices]
-    final_dst = final_dst[sorted_indices]
-    top_edges = torch.stack([final_src, final_dst], dim=0)
+    # Re-sort by exact scores
+    sorted_scores, sorted_indices = torch.sort(final_scores, descending=True)
+    final_src = final_src[sorted_indices]
+    final_dst = final_dst[sorted_indices]
+    final_scores = sorted_scores
+
+    # Create edges tensor
+    edges = torch.stack([final_src, final_dst], dim=0)

-    print(f"Top-{len(sorted_scores)} edges found with scores from "
-          f"{sorted_scores[-1]:.4f} to {sorted_scores[0]:.4f}")
+    # Deduplicate if undirected (should already be normalized, but just in case)
+    if not is_directed:
+        edges, final_scores = _deduplicate_edges(edges, final_scores)

-    return top_edges.cpu(), sorted_scores.cpu()
+    # Take final top-k
+    final_k = min(k, edges.size(1))
+    top_edges = edges[:, :final_k]
+    top_scores = final_scores[:final_k]

+    print(f"Top-{final_k} edges found with scores from "
+          f"{top_scores[-1]:.4f} to {top_scores[0]:.4f}")

-def _predict_full_enumeration(model, data, h, existing_set, k):
+    return top_edges.cpu(), top_scores.cpu()
+
+
+def _predict_full_enumeration(model, data, h, existing_set, k, is_directed, remove_loops):
     """Prediction with full enumeration of all pairs (for small graphs)"""
     num_nodes = h.size(0)
     device = h.device

     # Request more candidates with margin for filtering
     k_candidates = k + (len(existing_set) if existing_set else 0)
-    k_candidates = min(k_candidates * 2, num_nodes * num_nodes)  # with 2x margin
+    if not is_directed:
+        k_candidates = k_candidates * 3  # More margin for undirected due to deduplication
+    k_candidates = min(k_candidates * 2, num_nodes * num_nodes)

     # Store global top-k
     global_top_scores = None
@@ -500,6 +606,22 @@ def _predict_full_enumeration(model, data, h, existing_set, k):
         src_repeated = src_batch.repeat_interleave(num_nodes)
         dst_repeated = dst_nodes.repeat(batch_src_size)

+        # Filter before computing scores
+        # Remove loops if requested
+        if remove_loops:
+            loop_mask = src_repeated != dst_repeated
+            src_repeated = src_repeated[loop_mask]
+            dst_repeated = dst_repeated[loop_mask]
+
+        # For undirected graphs, only keep pairs where src <= dst to avoid duplicates
+        if not is_directed:
+            undirected_mask = src_repeated <= dst_repeated
+            src_repeated = src_repeated[undirected_mask]
+            dst_repeated = dst_repeated[undirected_mask]
+
+        if len(src_repeated) == 0:
+            continue
+
         # Prepare embeddings for decode
         h_src_batch = h[src_repeated]
         h_dst_batch = h[dst_repeated]
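To make the batched filtering concrete, here is the pair grid it produces for a toy case (assuming num_nodes = 3 and one source batch [0, 1]; the sizes are illustrative only):

    import torch
    src_batch = torch.tensor([0, 1])
    dst_nodes = torch.arange(3)
    src_repeated = src_batch.repeat_interleave(3)   # [0, 0, 0, 1, 1, 1]
    dst_repeated = dst_nodes.repeat(2)              # [0, 1, 2, 0, 1, 2]
    # remove_loops drops (0,0) and (1,1); src <= dst then drops (1,0)
    mask = (src_repeated != dst_repeated) & (src_repeated <= dst_repeated)
    print(list(zip(src_repeated[mask].tolist(), dst_repeated[mask].tolist())))
    # [(0, 1), (0, 2), (1, 2)]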
@@ -546,16 +668,21 @@ def _predict_full_enumeration(model, data, h, existing_set, k):
         global_top_src = global_top_src[mask]
         global_top_dst = global_top_dst[mask]

+    # Create edges tensor
+    edges = torch.stack([global_top_src, global_top_dst], dim=0)
+
+    # For undirected graphs, edges should already be normalized (src <= dst),
+    # but deduplicate just in case
+    if not is_directed:
+        edges, global_top_scores = _deduplicate_edges(edges, global_top_scores)
+
     # Final top-k
-    final_k = min(k, len(global_top_scores))
+    final_k = min(k, edges.size(1))
     if final_k < k:
         print(f"Warning: Only {final_k} unique pairs available, returning all")

+    top_edges = edges[:, :final_k]
     top_scores = global_top_scores[:final_k]
-    top_src = global_top_src[:final_k]
-    top_dst = global_top_dst[:final_k]
-
-    top_edges = torch.stack([top_src, top_dst], dim=0)

     print(f"Top-{final_k} edges found with scores from {top_scores[-1]:.4f} to {top_scores[0]:.4f}")

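As a quick post-hoc sanity check of the new contract (a sketch, assuming the undirected setting; `data.edge_index` stands in for whatever edge list was passed as exclude_edges):

    top_edges, top_scores = predict_top_k_edges(
        model, data, exclude_edges=data.edge_index,
        k=50, is_directed=False, remove_loops=True)
    known = set(zip(*_normalize_edges(data.edge_index).tolist()))
    for s, d in zip(top_edges[0].tolist(), top_edges[1].tolist()):
        assert s < d and (s, d) not in known   # normalized, loop-free, novel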