Skip to content

Commit 0bb7639

Browse files
authored
Updates to Benchmark Script (#51)
## Changes - Skip functions not implemented by nx-cugraph - Handle exception where katz_centrality fails to converge and support storing benchmark time - Verify that input & output graph for ego_graph is consistent. Authors: - Ralph Liu (https://github.com/nv-rliu) Approvers: - Erik Welch (https://github.com/eriknw) URL: #51
1 parent b8c4a7d commit 0bb7639

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

benchmarks/pytest-based/bench_algos.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
1415
from collections.abc import Mapping
1516

1617
import networkx as nx
@@ -225,6 +226,19 @@ def build_personalization_dict(pagerank_dict):
225226
return pers_dict
226227

227228

229+
# Used to return a function that calls the original function inside a try-except block
230+
# which is useful because it allows us to save pytest-benchmark numbers if failure is
231+
# the correct behavior for certain graphs
232+
def possible_to_fail(exception, function):
233+
def nested_func(*args, **kwargs):
234+
try:
235+
return function(*args, **kwargs)
236+
except exception:
237+
print(f"{function.__name__} raised {exception}")
238+
239+
return nested_func
240+
241+
228242
################################################################################
229243
# Benchmarks
230244
def bench_from_networkx(benchmark, graph_obj):
@@ -366,7 +380,11 @@ def bench_in_degree_centrality(benchmark, graph_obj, backend_wrapper):
366380
def bench_katz_centrality(benchmark, graph_obj, backend_wrapper, normalized):
367381
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
368382
result = benchmark.pedantic(
369-
target=backend_wrapper(nx.katz_centrality),
383+
# calling katz_centrality this way because the algorithm may fail to
384+
# converge for some graphs, which is expected
385+
target=possible_to_fail(
386+
nx.PowerIterationFailedConvergence, backend_wrapper(nx.katz_centrality)
387+
),
370388
args=(G,),
371389
kwargs=dict(
372390
normalized=normalized,
@@ -375,7 +393,7 @@ def bench_katz_centrality(benchmark, graph_obj, backend_wrapper, normalized):
375393
iterations=iterations,
376394
warmup_rounds=warmup_rounds,
377395
)
378-
assert type(result) is dict
396+
assert type(result) is dict or result is None
379397

380398

381399
def bench_k_truss(benchmark, graph_obj, backend_wrapper):
@@ -692,6 +710,7 @@ def bench_descendants_at_distance(benchmark, graph_obj, backend_wrapper):
692710
assert type(result) is set
693711

694712

713+
@pytest.mark.skip(reason="benchmark not implemented")
695714
def bench_is_bipartite(benchmark, graph_obj, backend_wrapper):
696715
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
697716
result = benchmark.pedantic(
@@ -704,6 +723,7 @@ def bench_is_bipartite(benchmark, graph_obj, backend_wrapper):
704723
assert type(result) is bool
705724

706725

726+
@pytest.mark.skip(reason="benchmark not implemented")
707727
def bench_is_strongly_connected(benchmark, graph_obj, backend_wrapper):
708728
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
709729
result = benchmark.pedantic(
@@ -728,6 +748,7 @@ def bench_is_weakly_connected(benchmark, graph_obj, backend_wrapper):
728748
assert type(result) is bool
729749

730750

751+
@pytest.mark.skip(reason="benchmark not implemented")
731752
def bench_number_strongly_connected_components(benchmark, graph_obj, backend_wrapper):
732753
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
733754
result = benchmark.pedantic(
@@ -780,6 +801,7 @@ def bench_reciprocity(benchmark, graph_obj, backend_wrapper):
780801
assert type(result) is float
781802

782803

804+
@pytest.mark.skip(reason="benchmark not implemented")
783805
def bench_strongly_connected_components(benchmark, graph_obj, backend_wrapper):
784806
G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper)
785807
result = benchmark.pedantic(
@@ -850,7 +872,7 @@ def bench_ego_graph(benchmark, graph_obj, backend_wrapper):
850872
iterations=iterations,
851873
warmup_rounds=warmup_rounds,
852874
)
853-
assert isinstance(result, (nx.Graph, nxcg.Graph))
875+
assert type(result) is type(G)
854876

855877

856878
@pytest.mark.skip(reason="benchmark not implemented")

0 commit comments

Comments
 (0)