Skip to content

Commit 4005ab1

Browse files
[FEA] Support Standard Temporal Sampling Behavior (#347)
Adds support for standard temporal sampling behavior. Support for "last n" behavior is _limited_; the remaining requirements for "last n" will be added in a future PR. Merge after rapidsai/cugraph#5345.
1 parent ec608df commit 4005ab1

File tree

6 files changed

+147
-60
lines changed

6 files changed

+147
-60
lines changed

python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
batch_size: int = 16, # Refers to number of edges per batch.
6060
compression: Optional[str] = None,
6161
local_seeds_per_call: Optional[int] = None,
62+
temporal_comparison: Optional[str] = None,
6263
**kwargs,
6364
):
6465
"""
@@ -128,12 +129,21 @@ def __init__(
128129
all workers. If not provided, it will be automatically
129130
calculated.
130131
See cugraph_pyg.sampler.BaseDistributedSampler.
132+
temporal_comparison: str (optional, default='monotonically decreasing')
133+
The comparison operator for temporal sampling
134+
('strictly increasing', 'monotonically increasing',
135+
'strictly decreasing', 'monotonically decreasing', 'last').
136+
Note that this should be 'last' for temporal_strategy='last'.
137+
See cugraph_pyg.sampler.BaseDistributedSampler.
131138
**kwargs
132139
Other keyword arguments passed to the superclass.
133140
"""
134141

135142
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)
136143

144+
if temporal_comparison is None:
145+
temporal_comparison = "monotonically decreasing"
146+
137147
if not directed:
138148
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
139149
warnings.warn(
@@ -172,17 +182,15 @@ def __init__(
172182

173183
is_temporal = (edge_label_time is not None) and (time_attr is not None)
174184

185+
if (edge_label_time is None) != (time_attr is None):
186+
warnings.warn(
187+
"Edge-based temporal sampling requires that both edge_label_time and time_attr are provided. Defaulting to non-temporal sampling."
188+
)
189+
175190
if weight_attr is not None:
176191
graph_store._set_weight_attr((feature_store, weight_attr))
177192
if is_temporal:
178-
# TODO Confirm that time is an edge attribute
179-
# TODO Add support for time override (see rapidsai/cugraph#5263)
180193
graph_store._set_etime_attr((feature_store, time_attr))
181-
warnings.warn(
182-
"Temporal sampling in cuGraph-PyG is currently only forward in time"
183-
" instead of the expected backward in time. This will be fixed in a"
184-
" future release."
185-
)
186194

187195
if isinstance(num_neighbors, dict):
188196
sorted_keys, _, _ = graph_store._numeric_edge_types
@@ -209,6 +217,7 @@ def __init__(
209217
biased=(weight_attr is not None),
210218
heterogeneous=(not graph_store.is_homogeneous),
211219
temporal=is_temporal,
220+
temporal_comparison=temporal_comparison,
212221
vertex_type_offsets=graph_store._vertex_offset_array,
213222
num_edge_types=len(graph_store.get_all_edge_attrs()),
214223
),

python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
batch_size: int = 16,
5757
compression: Optional[str] = None,
5858
local_seeds_per_call: Optional[int] = None,
59+
temporal_comparison: Optional[str] = None,
5960
**kwargs,
6061
):
6162
"""
@@ -121,12 +122,21 @@ def __init__(
121122
all workers. If not provided, it will be automatically
122123
calculated.
123124
See cugraph_pyg.sampler.BaseDistributedSampler.
125+
temporal_comparison: str (optional, default='monotonically decreasing')
126+
The comparison operator for temporal sampling
127+
('strictly increasing', 'monotonically increasing',
128+
'strictly decreasing', 'monotonically decreasing', 'last').
129+
Note that this should be 'last' for temporal_strategy='last'.
130+
See cugraph_pyg.sampler.BaseDistributedSampler.
124131
**kwargs
125132
Other keyword arguments passed to the superclass.
126133
"""
127134

128135
subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type)
129136

137+
if temporal_comparison is None:
138+
temporal_comparison = "monotonically decreasing"
139+
130140
if not directed:
131141
subgraph_type = torch_geometric.sampler.base.SubgraphType.induced
132142
warnings.warn(
@@ -166,15 +176,18 @@ def __init__(
166176
is_temporal = time_attr is not None
167177

168178
if is_temporal:
169-
# TODO Confirm that time is an edge attribute
170-
# TODO Add support for time override (see rapidsai/cugraph#5263)
171179
graph_store._set_etime_attr((feature_store, time_attr))
172180

173-
warnings.warn(
174-
"Temporal sampling in cuGraph-PyG is currently only forward in time"
175-
" instead of the expected backward in time. This will be fixed in a"
176-
" future release."
177-
)
181+
if input_time is None:
182+
input_type, input_nodes, _ = (
183+
torch_geometric.loader.utils.get_input_nodes(
184+
data, input_nodes, None
185+
)
186+
)
187+
if input_type is None:
188+
input_type = list(graph_store._vertex_offsets.keys())[0]
189+
# will assume the time attribute exists for nodes as well
190+
input_time = feature_store[input_type, time_attr, None][input_nodes]
178191

179192
if weight_attr is not None:
180193
graph_store._set_weight_attr((feature_store, weight_attr))
@@ -204,6 +217,7 @@ def __init__(
204217
biased=(weight_attr is not None),
205218
heterogeneous=(not graph_store.is_homogeneous),
206219
temporal=is_temporal,
220+
temporal_comparison=temporal_comparison,
207221
vertex_type_offsets=graph_store._vertex_offset_array,
208222
num_edge_types=len(graph_store.get_all_edge_attrs()),
209223
),

python/cugraph-pyg/cugraph_pyg/loader/node_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116
if input_id is None
117117
else input_id,
118118
node=input_nodes,
119-
time=None,
119+
time=input_time,
120120
input_type=input_type,
121121
)
122122

0 commit comments

Comments
 (0)