Skip to content

Commit 9fc1f15

Browse files
committed
Merge branch 'develop' of https://github.com/ispras/GNN-AID into fgsm_pgd_for_link_pred
# Conflicts: # experiments/various_tasks.py
2 parents 0bdc206 + 534e04e commit 9fc1f15

File tree

4 files changed

+171
-41
lines changed

4 files changed

+171
-41
lines changed

gnn_aid/datasets/gen_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def train_test_split(
366366
train_data.edge_label_index.size(1) + val_data.edge_label_index.size(1)] = True
367367

368368
test_mask = torch.zeros(total_edges, dtype=torch.bool)
369-
test_mask[-test_data.edge_label_index.size(1):] = True
369+
if test_data.edge_label_index.size(1) > 0:
370+
test_mask[-test_data.edge_label_index.size(1):] = True
370371
else:
371372
raise ValueError(f"Unsupported task type {task_type}")
372373

gnn_aid/models_builder/model_managers/framework_mm.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,12 @@ def run_model(
521521

522522
elif task_type.is_edge_level():
523523
data = gen_dataset.data
524-
edge_label_index = mask_tensor
525-
524+
train_edge_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask]
526525
data_x_copy = torch.clone(data.x)
527526

528527
# FIXME misha check, test
529528
if hasattr(self, 'mask_features'):
530-
node_ind = torch.unique(edge_label_index)
529+
node_ind = torch.unique(train_edge_index)
531530
for elem_ind in node_ind:
532531
for feature in self.mask_features:
533532
data_x_copy[elem_ind][gen_dataset.node_attr_slices[feature][0]:
@@ -537,21 +536,23 @@ def run_model(
537536
if hasattr(self, 'optimizer'):
538537
self.optimizer.zero_grad()
539538

540-
# get logits for nodes
541-
node_out = self.gnn(data_x_copy, data.edge_index)
542-
543-
src = node_out[edge_label_index[0]]
544-
dst = node_out[edge_label_index[1]]
539+
# get logits for all nodes based on train edges
540+
node_out = self.gnn(data_x_copy, train_edge_index)
545541

542+
# Get logits for edges from mask
543+
src = node_out[mask_tensor[0]]
544+
dst = node_out[mask_tensor[1]]
546545
edge_out = self.gnn.decode(src, dst)
547546

547+
# Apply different out
548548
full_out = None
549549
if out == 'logits':
550550
full_out = edge_out
551551
elif out == 'predictions':
552552
if task_type == Task.EDGE_PREDICTION:
553-
# TODO misha
554-
raise NotImplementedError
553+
# TODO misha is it ok?
554+
full_out = edge_out.softmax(dim=-1)
555+
# raise NotImplementedError
555556
elif task_type == Task.EDGE_CLASSIFICATION:
556557
full_out = edge_out.softmax(dim=-1)
557558
elif task_type == Task.EDGE_REGRESSION:
@@ -710,10 +711,13 @@ def evaluate_model(
710711
if any(m.needs_all_node_pairs() for m in metrics):
711712
assert gen_dataset.dataset_var_config.task == Task.EDGE_PREDICTION
712713

713-
exclude_edges = None # TODO
714+
# By default we exclude train edges from prediction
715+
exclude_edges = gen_dataset.edge_label_index[:, gen_dataset.train_mask]
714716
k = max(m.kwargs.get('k', 0) for m in metrics)
715717
top_edges, top_scores = predict_top_k_edges(
716-
self.gnn, gen_dataset.data, exclude_edges, k=k, use_faiss=False)
718+
self.gnn, gen_dataset.data, exclude_edges, k=k, use_faiss=True,
719+
is_directed=gen_dataset.is_directed(), remove_loops=True
720+
)
717721
# y_pred_edges = list(zip(map(int, top_edges[0]), map(int, top_edges[1])))
718722
model_outputs['all_pairs'] = top_edges
719723

@@ -733,7 +737,8 @@ def evaluate_model(
733737
if metric.needs_logits():
734738
y_pred = model_outputs[mask]['logits']
735739
if metric.needs_all_node_pairs():
736-
y_pred = model_outputs['all_pairs']
740+
k = metric.kwargs.get('k')
741+
y_pred = model_outputs['all_pairs'][:k]
737742
y_true = model_outputs[mask]['true_edges']
738743

739744
if mask not in metrics_values:

gnn_aid/models_builder/models_utils.py

Lines changed: 148 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ def name(
255255
res += '{' + kwargs + '}'
256256
return res
257257

258+
def __str__(
259+
self
260+
) -> str:
261+
return self.name()
262+
258263
def compute(
259264
self,
260265
y_true,
@@ -343,7 +348,8 @@ def forward(self, tensors):
343348
# ================================================
344349

345350

346-
def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss_k_per_node=200):
351+
def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True,
352+
faiss_k_per_node=100, is_directed=True, remove_loops=True):
347353
"""
348354
Predict top-k new edges with FAISS support for large graphs
349355
@@ -354,6 +360,8 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
354360
k: number of top edges to return (globally)
355361
use_faiss: use FAISS for fast search (recommended for large graphs)
356362
faiss_k_per_node: how many candidates to search per node via FAISS
363+
is_directed: if False, normalize edges to (i,j) where i < j and remove duplicates
364+
remove_loops: if True, exclude self-loops (i,i) from predictions
357365
358366
Returns:
359367
top_edges: torch.Tensor shape (2, k) - indices of top-k node pairs
@@ -374,9 +382,15 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
374382
# Create set of existing edges for fast lookup
375383
existing_set = set()
376384
if exclude_edges is not None and exclude_edges.size(1) > 0:
385+
# Normalize exclude_edges for undirected graphs
386+
if not is_directed:
387+
exclude_edges_normalized = _normalize_edges(exclude_edges)
388+
else:
389+
exclude_edges_normalized = exclude_edges
390+
377391
existing_set = set(zip(
378-
exclude_edges[0].cpu().tolist(),
379-
exclude_edges[1].cpu().tolist()
392+
exclude_edges_normalized[0].cpu().tolist(),
393+
exclude_edges_normalized[1].cpu().tolist()
380394
))
381395
print(f"Excluding {len(existing_set)} existing edges")
382396

@@ -397,22 +411,86 @@ def predict_top_k_edges(model, data, exclude_edges, k=100, use_faiss=True, faiss
397411
total_pairs = num_nodes * num_nodes
398412
if use_faiss and total_pairs > 100e6:
399413
print(f"Using FAISS for large graph: {num_nodes} x {num_nodes} = {total_pairs} pairs")
400-
return _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node)
414+
return _predict_with_faiss(model, data, h_norm, existing_set, k,
415+
faiss_k_per_node, is_directed, remove_loops)
401416
else:
402417
print(
403418
f"Using full enumeration for small graph: {num_nodes} x {num_nodes} = {total_pairs} pairs")
404-
return _predict_full_enumeration(model, data, h_norm, existing_set, k)
419+
return _predict_full_enumeration(model, data, h_norm, existing_set, k,
420+
is_directed, remove_loops)
421+
422+
423+
def _normalize_edges(edges):
424+
"""
425+
Normalize edges for undirected graph: ensure i < j for each edge (i,j)
426+
427+
Args:
428+
edges: torch.Tensor shape (2, num_edges)
429+
430+
Returns:
431+
normalized_edges: torch.Tensor shape (2, num_edges) with i < j
432+
"""
433+
src, dst = edges[0], edges[1]
434+
435+
# Swap where src > dst
436+
mask = src > dst
437+
src_new = torch.where(mask, dst, src)
438+
dst_new = torch.where(mask, src, dst)
439+
440+
return torch.stack([src_new, dst_new], dim=0)
441+
442+
443+
def _deduplicate_edges(edges, scores):
444+
"""
445+
Remove duplicate edges and keep only unique ones with highest scores
446+
447+
Args:
448+
edges: torch.Tensor shape (2, num_edges)
449+
scores: torch.Tensor shape (num_edges,)
450+
451+
Returns:
452+
unique_edges: torch.Tensor shape (2, num_unique)
453+
unique_scores: torch.Tensor shape (num_unique,)
454+
"""
455+
if edges.size(1) == 0:
456+
return edges, scores
457+
458+
# Create dictionary: edge_tuple -> (score, index)
459+
edge_dict = {}
460+
for idx in range(edges.size(1)):
461+
edge_tuple = (edges[0, idx].item(), edges[1, idx].item())
462+
score = scores[idx].item()
463+
464+
# Keep edge with higher score
465+
if edge_tuple not in edge_dict or score > edge_dict[edge_tuple][0]:
466+
edge_dict[edge_tuple] = (score, idx)
467+
468+
# Extract unique edges and their scores
469+
unique_indices = [v[1] for v in edge_dict.values()]
470+
unique_edges = edges[:, unique_indices]
471+
unique_scores = scores[unique_indices]
472+
473+
# Re-sort by scores
474+
sorted_scores, sorted_indices = torch.sort(unique_scores, descending=True)
475+
unique_edges = unique_edges[:, sorted_indices]
405476

477+
return unique_edges, sorted_scores
406478

407-
def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
479+
480+
def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node,
481+
is_directed, remove_loops):
408482
"""Prediction using FAISS for fast search"""
409483
import faiss
410484

411485
num_nodes = h_norm.size(0)
412486
embedding_dim = h_norm.size(1)
413487

414488
# Parameter: how many candidates to take per node
415-
faiss_k_per_node = min(faiss_k_per_node, num_nodes)
489+
# For undirected graphs, we need more candidates to account for deduplication
490+
if not is_directed:
491+
faiss_k_per_node = min(faiss_k_per_node * 2, num_nodes)
492+
else:
493+
faiss_k_per_node = min(faiss_k_per_node, num_nodes)
416494

417495
# Build FAISS index for all nodes
418496
h_np = h_norm.cpu().numpy().astype('float32')
@@ -430,23 +508,37 @@ def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
430508
all_pairs = []
431509
for src_idx in range(num_nodes):
432510
for rank, dst_idx in enumerate(candidate_indices[src_idx]):
511+
dst_idx = int(dst_idx)
512+
513+
# Remove loops if requested
514+
if remove_loops and src_idx == dst_idx:
515+
continue
516+
517+
# Normalize edge for undirected graph
518+
if not is_directed:
519+
edge = (min(src_idx, dst_idx), max(src_idx, dst_idx))
520+
else:
521+
edge = (src_idx, dst_idx)
522+
433523
# Skip existing edges
434-
if (src_idx, int(dst_idx)) in existing_set:
435-
print(f"Score for existing ({src_idx}, {dst_idx}): {similarities[src_idx, rank]}")
524+
if edge in existing_set:
436525
continue
437526

438527
# Save (src, dst, similarity)
439-
all_pairs.append((src_idx, int(dst_idx), similarities[src_idx, rank]))
528+
all_pairs.append((edge[0], edge[1], similarities[src_idx, rank]))
440529

441530
print(f"Found {len(all_pairs)} candidate pairs after filtering")
442531

443532
if len(all_pairs) == 0:
444533
print("Warning: No candidate pairs found!")
445534
return torch.zeros((2, 0), dtype=torch.long), torch.zeros(0)
446535

447-
# Sort by similarity and take top-k
536+
# Sort by similarity and take more than k to account for deduplication
448537
all_pairs.sort(key=lambda x: x[2], reverse=True)
449-
top_pairs = all_pairs[:min(k, len(all_pairs))]
538+
539+
# For undirected graphs, we already normalized, so no duplicates
540+
# But we take more candidates to be safe
541+
top_pairs = all_pairs[:min(k * 2 if not is_directed else k, len(all_pairs))]
450542

451543
# Now compute exact scores via model for final candidates
452544
h = model(data.x, data.edge_index)
@@ -461,26 +553,34 @@ def _predict_with_faiss(model, data, h_norm, existing_set, k, faiss_k_per_node):
461553
final_scores = model.decode(h_src_final, h_dst_final)
462554
final_scores = final_scores.sigmoid()
463555

464-
# Re-sort by exact scores
465-
sorted_scores, sorted_indices = torch.sort(final_scores, descending=True)
466-
final_src = final_src[sorted_indices]
467-
final_dst = final_dst[sorted_indices]
468-
top_edges = torch.stack([final_src, final_dst], dim=0)
556+
# Create edges tensor
557+
edges = torch.stack([final_src, final_dst], dim=0)
469558

470-
print(f"Top-{len(sorted_scores)} edges found with scores from "
471-
f"{sorted_scores[-1]:.4f} to {sorted_scores[0]:.4f}")
559+
# Deduplicate if undirected (should already be normalized, but just in case)
560+
if not is_directed:
561+
edges, final_scores = _deduplicate_edges(edges, final_scores)
472562

473-
return top_edges.cpu(), sorted_scores.cpu()
563+
# Take final top-k
564+
final_k = min(k, edges.size(1))
565+
top_edges = edges[:, :final_k]
566+
top_scores = final_scores[:final_k]
474567

568+
print(f"Top-{final_k} edges found with scores from "
569+
f"{top_scores[-1]:.4f} to {top_scores[0]:.4f}")
475570

476-
def _predict_full_enumeration(model, data, h, existing_set, k):
571+
return top_edges.cpu(), top_scores.cpu()
572+
573+
574+
def _predict_full_enumeration(model, data, h, existing_set, k, is_directed, remove_loops):
477575
"""Prediction with full enumeration of all pairs (for small graphs)"""
478576
num_nodes = h.size(0)
479577
device = h.device
480578

481579
# Request more candidates with margin for filtering
482580
k_candidates = k + (len(existing_set) if existing_set else 0)
483-
k_candidates = min(k_candidates * 2, num_nodes * num_nodes) # with 2x margin
581+
if not is_directed:
582+
k_candidates = k_candidates * 3 # More margin for undirected due to deduplication
583+
k_candidates = min(k_candidates * 2, num_nodes * num_nodes)
484584

485585
# Store global top-k
486586
global_top_scores = None
@@ -500,6 +600,22 @@ def _predict_full_enumeration(model, data, h, existing_set, k):
500600
src_repeated = src_batch.repeat_interleave(num_nodes)
501601
dst_repeated = dst_nodes.repeat(batch_src_size)
502602

603+
# Filter before computing scores
604+
# Remove loops if requested
605+
if remove_loops:
606+
loop_mask = src_repeated != dst_repeated
607+
src_repeated = src_repeated[loop_mask]
608+
dst_repeated = dst_repeated[loop_mask]
609+
610+
# For undirected graphs, only keep pairs where src <= dst to avoid duplicates
611+
if not is_directed:
612+
undirected_mask = src_repeated <= dst_repeated
613+
src_repeated = src_repeated[undirected_mask]
614+
dst_repeated = dst_repeated[undirected_mask]
615+
616+
if len(src_repeated) == 0:
617+
continue
618+
503619
# Prepare embeddings for decode
504620
h_src_batch = h[src_repeated]
505621
h_dst_batch = h[dst_repeated]
@@ -546,16 +662,21 @@ def _predict_full_enumeration(model, data, h, existing_set, k):
546662
global_top_src = global_top_src[mask]
547663
global_top_dst = global_top_dst[mask]
548664

665+
# Create edges tensor
666+
edges = torch.stack([global_top_src, global_top_dst], dim=0)
667+
668+
# For undirected graphs, edges should already be normalized (src <= dst)
669+
# But deduplicate just in case
670+
if not is_directed:
671+
edges, global_top_scores = _deduplicate_edges(edges, global_top_scores)
672+
549673
# Final top-k
550-
final_k = min(k, len(global_top_scores))
674+
final_k = min(k, edges.size(1))
551675
if final_k < k:
552676
print(f"Warning: Only {final_k} unique pairs available, returning all")
553677

678+
top_edges = edges[:, :final_k]
554679
top_scores = global_top_scores[:final_k]
555-
top_src = global_top_src[:final_k]
556-
top_dst = global_top_dst[:final_k]
557-
558-
top_edges = torch.stack([top_src, top_dst], dim=0)
559680

560681
print(f"Top-{final_k} edges found with scores from {top_scores[-1]:.4f} to {top_scores[0]:.4f}")
561682

requirements3.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ matplotlib
1919
# For other datasets
2020
dgl==1.1.0
2121
mat73 # powergraph
22+
23+
# For edge prediction on big graphs
24+
faiss-cpu==1.13.2

0 commit comments

Comments
 (0)