Skip to content

Commit 07c9fbd

Browse files
authored
Make Kruskal spanning tree accept weight vector instead of matrix (#486)
* accept spanning tree vector instead of matrix for Kruskal * replace equality with src and dst
1 parent acb2c4d commit 07c9fbd

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/spanningtrees/kruskal.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
11
"""
22
kruskal_mst(g, distmx=weights(g); minimize=true)
3+
kruskal_mst(g, weight_vector; minimize=true)
34
45
Return a vector of edges representing the minimum (by default) spanning tree of a connected,
56
undirected graph `g` with optional distance matrix `distmx` using [Kruskal's algorithm](https://en.wikipedia.org/wiki/Kruskal%27s_algorithm).
67
8+
Alternative to the distance matrix `distmx`, one can pass a `weight_vector` with weights ordered as `edges(g)`.
9+
710
### Optional Arguments
811
- `minimize=true`: if set to `false`, calculate the maximum spanning tree.
912
"""
1013
function kruskal_mst end
1114
# see https://github.com/mauro3/SimpleTraits.jl/issues/47#issuecomment-327880153 for syntax
1215
@traitfn function kruskal_mst(
1316
g::AG::(!IsDirected), distmx::AbstractMatrix{T}=weights(g); minimize=true
17+
) where {T<:Number,U,AG<:AbstractGraph{U}}
18+
weight_vector = Vector{T}()
19+
sizehint!(weight_vector, ne(g))
20+
for e in edges(g)
21+
push!(weight_vector, distmx[src(e), dst(e)])
22+
end
23+
return kruskal_mst(g, weight_vector; minimize=minimize)
24+
end
25+
26+
@traitfn function kruskal_mst(
27+
g::AG::(!IsDirected), weight_vector::AbstractVector{T}; minimize=true
1428
) where {T<:Number,U,AG<:AbstractGraph{U}}
1529
connected_vs = IntDisjointSet(nv(g))
1630

1731
mst = Vector{edgetype(g)}()
1832
nv(g) <= 1 && return mst
1933
sizehint!(mst, nv(g) - 1)
2034

21-
weights = Vector{T}()
22-
sizehint!(weights, ne(g))
2335
edge_list = collect(edges(g))
24-
for e in edge_list
25-
push!(weights, distmx[src(e), dst(e)])
26-
end
27-
28-
for e in edge_list[sortperm(weights; rev=(!minimize))]
36+
for e in edge_list[sortperm(weight_vector; rev=(!minimize))]
2937
if !in_same_set(connected_vs, src(e), dst(e))
3038
union!(connected_vs, src(e), dst(e))
3139
push!(mst, e)

test/spanningtrees/kruskal.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
6 10 3 0
99
]
1010

11+
weight_vector = [distmx[src(e), dst(e)] for e in edges(g4)]
12+
1113
vec_mst = Vector{Edge}([Edge(1, 2), Edge(3, 4), Edge(2, 3)])
1214
max_vec_mst = Vector{Edge}([Edge(2, 4), Edge(1, 4), Edge(1, 3)])
1315
for g in test_generic_graphs(g4)
@@ -18,7 +20,15 @@
1820
# so instead we compare tuples of source and target vertices
1921
@test sort([(src(e), dst(e)) for e in mst]) == sort([(src(e), dst(e)) for e in vec_mst])
2022
@test sort([(src(e), dst(e)) for e in max_mst]) == sort([(src(e), dst(e)) for e in max_vec_mst])
23+
# test equivalent vector form
24+
mst_vec = @inferred(kruskal_mst(g, weight_vector))
25+
max_mst_vec = @inferred(kruskal_mst(g, weight_vector, minimize=false))
26+
@test src.(mst_vec) == src.(mst)
27+
@test dst.(mst_vec) == dst.(mst)
28+
@test src.(max_mst_vec) == src.(max_mst)
29+
@test dst.(max_mst_vec) == dst.(max_mst)
2130
end
31+
2232
# second test
2333
distmx_sec = [
2434
0 0 0.26 0 0.38 0 0.58 0.16
@@ -32,6 +42,8 @@
3242
]
3343

3444
gx = SimpleGraph(distmx_sec)
45+
weight_vector_sec = [distmx_sec[src(e), dst(e)] for e in edges(gx)]
46+
3547
vec2 = Vector{Edge}([
3648
Edge(1, 8), Edge(3, 4), Edge(2, 8), Edge(1, 3), Edge(6, 8), Edge(5, 6), Edge(3, 7)
3749
])
@@ -40,9 +52,15 @@
4052
])
4153
for g in test_generic_graphs(gx)
4254
mst2 = @inferred(kruskal_mst(g, distmx_sec))
55+
mst2_vec = @inferred(kruskal_mst(g, weight_vector_sec))
4356
max_mst2 = @inferred(kruskal_mst(g, distmx_sec, minimize=false))
57+
max_mst2_vec = @inferred(kruskal_mst(g, weight_vector_sec, minimize=false))
4458
@test sort([(src(e), dst(e)) for e in mst2]) == sort([(src(e), dst(e)) for e in vec2])
4559
@test sort([(src(e), dst(e)) for e in max_mst2]) == sort([(src(e), dst(e)) for e in max_vec2])
60+
@test src.(mst2) == src.(mst2_vec)
61+
@test dst.(mst2) == dst.(mst2_vec)
62+
@test src.(max_mst2) == src.(max_mst2_vec)
63+
@test dst.(max_mst2) == dst.(max_mst2_vec)
4664
end
4765

4866
# non regression test for #362

0 commit comments

Comments
 (0)