
Commit a9ab8b4

[FEA] Support Heterogeneous Sampling in cuGraph-PyG (#82)

Allows sampling of heterogeneous graphs. Removes unbuffered sampling from the PyG examples and disables it entirely in DGL. A future PR will drop PyG support for unbuffered sampling completely, and a future `cugraph` PR will drop support for unbuffered sampling in the distributed sampler.

Merge after rapidsai/cugraph#4795. Closes rapidsai/cugraph#4402.

Authors:
- Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
- Tingyu Wang (https://github.com/tingyu66)
- James Lamb (https://github.com/jameslamb)

URL: #82
1 parent 87455cf commit a9ab8b4

File tree

13 files changed: +515 −196 lines

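For orientation, the sketch below shows what heterogeneous sampling looks like from the cuGraph-PyG side after this change. It is a minimal, hypothetical example: it assumes PyG's standard (FeatureStore, GraphStore) remote-backend pattern, and the store constructors, dict-valued num_neighbors, and ("paper", indices) input_nodes form shown here are illustrative rather than taken from this commit.

import torch

from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
from cugraph_pyg.loader import NeighborLoader

# Toy heterogeneous graph: "author" vertices write "paper" vertices.
num_authors, num_papers, num_edges = 50, 100, 500
graph_store = GraphStore()
feature_store = TensorDictFeatureStore()

edge_index = torch.stack(
    [
        torch.randint(0, num_authors, (num_edges,)),  # source author IDs
        torch.randint(0, num_papers, (num_edges,)),  # destination paper IDs
    ]
)
graph_store.put_edge_index(
    edge_index,
    edge_type=("author", "writes", "paper"),
    layout="coo",
    size=(num_authors, num_papers),
)

# Seed batches on "paper" vertices; each batch is a heterogeneous sample
# covering the vertex and edge types reached by the fanout.
loader = NeighborLoader(
    (feature_store, graph_store),
    num_neighbors={("author", "writes", "paper"): [10, 10]},
    input_nodes=("paper", torch.arange(num_papers)),
    batch_size=16,
)
for batch in loader:
    pass  # train or evaluate on the sampled subgraph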

ci/test_notebooks.sh

Lines changed: 10 additions & 12 deletions
@@ -1,17 +1,25 @@
 #!/bin/bash
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 
 set -Eeuo pipefail
 
 . /opt/conda/etc/profile.d/conda.sh
 
 RAPIDS_VERSION="$(rapids-version)"
 
+rapids-logger "Downloading artifacts from previous jobs"
+CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp)
+PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python)
+
 rapids-logger "Generate notebook testing dependencies"
 rapids-dependency-file-generator \
   --output conda \
   --file-key test_notebooks \
-  --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml
+  --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \
+  --prepend-channel dglteam/label/th23_cu118 \
+  --prepend-channel "${CPP_CHANNEL}" \
+  --prepend-channel "${PYTHON_CHANNEL}" \
+  | tee env.yaml
 
 rapids-mamba-retry env create --yes -f env.yaml -n test
 
@@ -22,16 +30,6 @@ set -u
 
 rapids-print-env
 
-rapids-logger "Downloading artifacts from previous jobs"
-CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp)
-PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python)
-
-rapids-mamba-retry install \
-  --channel "${CPP_CHANNEL}" \
-  --channel "${PYTHON_CHANNEL}" \
-  --channel dglteam/label/th23_cu118 \
-  "cugraph-dgl=${RAPIDS_VERSION}"
-
 NBTEST="$(realpath "$(dirname "$0")/utils/nbtest.sh")"
 NOTEBOOK_LIST="$(realpath "$(dirname "$0")/notebook_list.py")"
 EXITCODE=0

dependencies.yaml

Lines changed: 7 additions & 0 deletions
@@ -54,6 +54,7 @@ files:
     includes:
       - cuda_version
       - depends_on_pytorch
+      - depends_on_cugraph_dgl
       - py_version
       - test_notebook
   test_python:
@@ -540,6 +541,12 @@ dependencies:
           - cugraph-cu11==25.2.*,>=0.0.0a0
       - {matrix: null, packages: [*cugraph_unsuffixed]}
 
+  depends_on_cugraph_dgl:
+    common:
+      - output_types: conda
+        packages:
+          - cugraph-dgl==25.2.*,>=0.0.0a0
+
   depends_on_cudf:
     common:
       - output_types: conda

python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py

Lines changed: 4 additions & 17 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2024, NVIDIA CORPORATION.
+# Copyright (c) 2022-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -14,11 +14,10 @@
 from __future__ import annotations
 
 import warnings
-import tempfile
 
 from typing import Sequence, Optional, Union, List, Tuple, Iterator
 
-from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler, DistSampleWriter
+from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler
 from cugraph.utilities.utils import import_optional
 
 import cugraph_dgl
@@ -124,7 +123,7 @@ def __init__(
             Can be either "dgl.Block" (default), or "cugraph_dgl.nn.SparseGraph".
         **kwargs
             Keyword arguments for the underlying cuGraph distributed sampler
-            and writer (directory, batches_per_partition, format,
+            and writer (batches_per_partition, format,
             local_seeds_per_call).
         """
 
@@ -165,18 +164,6 @@ def sample(
     ) -> Iterator[DGLSamplerOutput]:
         kwargs = dict(**self.__kwargs)
 
-        directory = kwargs.pop("directory", None)
-        if directory is None:
-            warnings.warn("Setting a directory to store samples is recommended.")
-            self._tempdir = tempfile.TemporaryDirectory()
-            directory = self._tempdir.name
-
-        writer = DistSampleWriter(
-            directory=directory,
-            batches_per_partition=kwargs.pop("batches_per_partition", 256),
-            format=kwargs.pop("format", "parquet"),
-        )
-
         sampling_clx = (
             UniformNeighborSampler
             if self.__prob_attr is None
@@ -185,7 +172,7 @@ def sample(
 
         ds = sampling_clx(
             g._graph(self.edge_dir, prob_attr=self.__prob_attr),
-            writer,
+            writer=None,
             compression="CSR",
             fanout=self._reversed_fanout_vals,
             prior_sources_behavior="carryover",
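With the writer gone, constructing the DGL sampler no longer involves a spool directory. A minimal usage sketch under that assumption (the fanout values and the batches_per_partition kwarg below are illustrative, per the docstring above):

import cugraph_dgl

# No `directory` argument: samples are kept in memory by the underlying
# cuGraph distributed sampler instead of being spooled to disk.
sampler = cugraph_dgl.dataloading.NeighborSampler(
    [25, 10],  # fanout: 25 neighbors in the first hop, 10 in the second
    batches_per_partition=256,  # forwarded to the underlying sampler
)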

python/cugraph-pyg/cugraph_pyg/data/graph_store.py

Lines changed: 55 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -70,6 +70,7 @@ def __clear_graph(self):
         self.__graph = None
         self.__vertex_offsets = None
         self.__weight_attr = None
+        self.__numeric_edge_types = None
 
     def _put_edge_index(
         self,
@@ -240,6 +241,27 @@ def _vertex_offsets(self) -> Dict[str, int]:
 
         return dict(self.__vertex_offsets)
 
+    @property
+    def _vertex_offset_array(self) -> "torch.Tensor":
+        off = torch.tensor(
+            [self._vertex_offsets[k] for k in sorted(self._vertex_offsets.keys())],
+            dtype=torch.int64,
+            device="cuda",
+        )
+
+        return torch.concat(
+            [
+                off,
+                torch.tensor(
+                    list(self._num_vertices().values()),
+                    device="cuda",
+                    dtype=torch.int64,
+                )
+                .sum()
+                .reshape((1,)),
+            ]
+        )
+
     @property
     def is_homogeneous(self) -> bool:
         return len(self._vertex_offsets) == 1
@@ -270,6 +292,38 @@ def __get_weight_tensor(
 
         return torch.concat(weights)
 
+    @property
+    def _numeric_edge_types(self) -> Tuple[List, "torch.Tensor", "torch.Tensor"]:
+        """
+        Returns the canonical edge types in order (the 0th canonical type corresponds
+        to numeric edge type 0, etc.), along with the numeric source and destination
+        vertex types for each edge type.
+        """
+
+        if self.__numeric_edge_types is None:
+            sorted_keys = sorted(
+                list(self.__edge_indices.keys(leaves_only=True, include_nested=True))
+            )
+
+            vtype_table = {
+                k: i for i, k in enumerate(sorted(self._vertex_offsets.keys()))
+            }
+
+            srcs = []
+            dsts = []
+
+            for can_etype in sorted_keys:
+                srcs.append(vtype_table[can_etype[0]])
+                dsts.append(vtype_table[can_etype[2]])
+
+            self.__numeric_edge_types = (
+                sorted_keys,
+                torch.tensor(srcs, device="cuda", dtype=torch.int32),
+                torch.tensor(dsts, device="cuda", dtype=torch.int32),
+            )
+
+        return self.__numeric_edge_types
+
     def __get_edgelist(self):
         """
         Returns
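The two new properties encode the heterogeneous structure numerically: vertex types are sorted and packed into a single contiguous ID space, and each canonical edge type is reduced to a (numeric source type, numeric destination type) pair. The following self-contained toy mirrors that mapping on CPU tensors (illustration only, not cugraph code):

import torch

# Per-type vertex counts for a toy graph.
num_vertices = {"author": 4, "paper": 7}

# Mirror of _vertex_offset_array: sorted types get cumulative offsets,
# with the total vertex count appended as a final sentinel.
vtypes = sorted(num_vertices)  # ["author", "paper"]
offsets, total = {}, 0
for vt in vtypes:
    offsets[vt] = total
    total += num_vertices[vt]
offset_array = torch.tensor([offsets[vt] for vt in vtypes] + [total])
print(offset_array)  # tensor([ 0,  4, 11])

# Mirror of _numeric_edge_types: sorted canonical edge types, plus the
# numeric source/destination vertex type for each.
canonical = sorted([("paper", "cites", "paper"), ("author", "writes", "paper")])
vtype_id = {vt: i for i, vt in enumerate(vtypes)}
srcs = torch.tensor([vtype_id[s] for s, _, _ in canonical], dtype=torch.int32)
dsts = torch.tensor([vtype_id[d] for _, _, d in canonical], dtype=torch.int32)
print(srcs, dsts)  # tensor([0, 1], ...) tensor([1, 1], ...): author->paper, paper->paper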

python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py

Lines changed: 17 additions & 34 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -17,7 +17,6 @@
 import argparse
 import os
 import warnings
-import tempfile
 import time
 import json
 
@@ -179,7 +178,6 @@ def run_train(
     fan_out,
     num_classes,
     wall_clock_start,
-    tempdir=None,
     num_layers=3,
     in_memory=False,
     seeds_per_call=-1,
@@ -194,41 +192,29 @@ def run_train(
     from cugraph_pyg.loader import NeighborLoader
 
     ix_train = split_idx["train"].cuda()
-    train_path = None if in_memory else os.path.join(tempdir, f"train_{global_rank}")
-    if train_path:
-        os.mkdir(train_path)
     train_loader = NeighborLoader(
         data,
         input_nodes=ix_train,
-        directory=train_path,
         shuffle=True,
         drop_last=True,
         local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
         **kwargs,
     )
 
     ix_test = split_idx["test"].cuda()
-    test_path = None if in_memory else os.path.join(tempdir, f"test_{global_rank}")
-    if test_path:
-        os.mkdir(test_path)
     test_loader = NeighborLoader(
         data,
         input_nodes=ix_test,
-        directory=test_path,
         shuffle=True,
         drop_last=True,
         local_seeds_per_call=80000,
         **kwargs,
     )
 
     ix_valid = split_idx["valid"].cuda()
-    valid_path = None if in_memory else os.path.join(tempdir, f"valid_{global_rank}")
-    if valid_path:
-        os.mkdir(valid_path)
     valid_loader = NeighborLoader(
         data,
         input_nodes=ix_valid,
-        directory=valid_path,
         shuffle=True,
         drop_last=True,
         local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
@@ -347,7 +333,6 @@ def parse_args():
     parser.add_argument("--epochs", type=int, default=4)
     parser.add_argument("--batch_size", type=int, default=1024)
     parser.add_argument("--fan_out", type=int, default=30)
-    parser.add_argument("--tempdir_root", type=str, default=None)
     parser.add_argument("--dataset_root", type=str, default="datasets")
     parser.add_argument("--dataset", type=str, default="ogbn-products")
     parser.add_argument("--skip_partition", action="store_true")
@@ -427,23 +412,21 @@ def parse_args():
         ).to(device)
         model = DistributedDataParallel(model, device_ids=[local_rank])
 
-        with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir:
-            run_train(
-                global_rank,
-                data,
-                split_idx,
-                world_size,
-                device,
-                model,
-                args.epochs,
-                args.batch_size,
-                args.fan_out,
-                meta["num_classes"],
-                wall_clock_start,
-                tempdir,
-                args.num_layers,
-                args.in_memory,
-                args.seeds_per_call,
-            )
+        run_train(
+            global_rank,
+            data,
+            split_idx,
+            world_size,
+            device,
+            model,
+            args.epochs,
+            args.batch_size,
+            args.fan_out,
+            meta["num_classes"],
+            wall_clock_start,
+            args.num_layers,
+            args.in_memory,
+            args.seeds_per_call,
+        )
     else:
         warnings.warn("This script should be run with 'torchrun`. Exiting.")
