Skip to content

Commit 3f11d45

Browse files
[BUG] Fix Calculation of # Sampled Nodes/Edges with Zero Input Size (#283)
Fixes a bug found when running `movielens_mnmg.py` on input with a large number of negative edges. Some batches end up with no positive edges of the input type, which causes PyTorch to throw an exception when trying to calculate the number of nodes, since the returned tensor is empty. The fix is to check if the returned tensor is empty, and just set the number of output nodes to zero if it is. Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Tingyu Wang (https://github.com/tingyu66) URL: #283
1 parent f2b7f50 commit 3f11d45

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

python/cugraph-pyg/cugraph_pyg/sampler/sampler.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,23 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
371371
# heterogeneous edges as input, two node types per edge type
372372
ux = col[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]]
373373
uy = row[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]]
374+
uxn = (
375+
(ux.max() + 1)
376+
if ux.numel() > 0
377+
else torch.tensor(0, device=ux.device)
378+
)
374379
num_sampled_nodes[self.__dst_types[etype]][0] = torch.max(
375380
num_sampled_nodes[self.__dst_types[etype]][0],
376-
(ux.max() + 1).reshape((1,)),
381+
uxn.reshape((1,)),
382+
)
383+
uyn = (
384+
(uy.max() + 1)
385+
if uy.numel() > 0
386+
else torch.tensor(0, device=uy.device)
377387
)
378388
num_sampled_nodes[self.__src_types[etype]][0] = torch.max(
379389
num_sampled_nodes[self.__src_types[etype]][0],
380-
(uy.max() + 1).reshape((1,)),
390+
uyn.reshape((1,)),
381391
)
382392
elif isinstance(input_type, str) and input_type == pyg_can_etype[2]:
383393
integer_input_type = self.__src_types[etype]

0 commit comments

Comments
 (0)