Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/spanningtrees/kruskal.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
"""
kruskal_mst(g, distmx=weights(g); minimize=true)
kruskal_mst(g, weight_vector; minimize=true)

Return a vector of edges representing the minimum (by default) spanning tree of a connected,
undirected graph `g` with optional distance matrix `distmx` using [Kruskal's algorithm](https://en.wikipedia.org/wiki/Kruskal%27s_algorithm).

Alternative to the distance matrix `distmx`, one can pass a `weight_vector` with weights ordered as `edges(g)`.

### Optional Arguments
- `minimize=true`: if set to `false`, calculate the maximum spanning tree.
"""
function kruskal_mst end
# see https://github.com/mauro3/SimpleTraits.jl/issues/47#issuecomment-327880153 for syntax
@traitfn function kruskal_mst(
g::AG::(!IsDirected), distmx::AbstractMatrix{T}=weights(g); minimize=true
) where {T<:Number,U,AG<:AbstractGraph{U}}
weight_vector = Vector{T}()
sizehint!(weight_vector, ne(g))
for e in edges(g)
push!(weight_vector, distmx[src(e), dst(e)])
end
return kruskal_mst(g, weight_vector; minimize=minimize)
end

@traitfn function kruskal_mst(
g::AG::(!IsDirected), weight_vector::AbstractVector{T}; minimize=true
) where {T<:Number,U,AG<:AbstractGraph{U}}
connected_vs = IntDisjointSet(nv(g))

mst = Vector{edgetype(g)}()
nv(g) <= 1 && return mst
sizehint!(mst, nv(g) - 1)

weights = Vector{T}()
sizehint!(weights, ne(g))
edge_list = collect(edges(g))
for e in edge_list
push!(weights, distmx[src(e), dst(e)])
end

for e in edge_list[sortperm(weights; rev=(!minimize))]
for e in edge_list[sortperm(weight_vector; rev=(!minimize))]
if !in_same_set(connected_vs, src(e), dst(e))
union!(connected_vs, src(e), dst(e))
push!(mst, e)
Expand Down
18 changes: 18 additions & 0 deletions test/spanningtrees/kruskal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
6 10 3 0
]

weight_vector = [distmx[src(e), dst(e)] for e in edges(g4)]

vec_mst = Vector{Edge}([Edge(1, 2), Edge(3, 4), Edge(2, 3)])
max_vec_mst = Vector{Edge}([Edge(2, 4), Edge(1, 4), Edge(1, 3)])
for g in test_generic_graphs(g4)
Expand All @@ -18,7 +20,15 @@
# so instead we compare tuples of source and target vertices
@test sort([(src(e), dst(e)) for e in mst]) == sort([(src(e), dst(e)) for e in vec_mst])
@test sort([(src(e), dst(e)) for e in max_mst]) == sort([(src(e), dst(e)) for e in max_vec_mst])
# test equivalent vector form
mst_vec = @inferred(kruskal_mst(g, weight_vector))
max_mst_vec = @inferred(kruskal_mst(g, weight_vector, minimize=false))
@test src.(mst_vec) == src.(mst)
@test dst.(mst_vec) == dst.(mst)
@test src.(max_mst_vec) == src.(max_mst)
@test dst.(max_mst_vec) == dst.(max_mst)
end

# second test
distmx_sec = [
0 0 0.26 0 0.38 0 0.58 0.16
Expand All @@ -32,6 +42,8 @@
]

gx = SimpleGraph(distmx_sec)
weight_vector_sec = [distmx_sec[src(e), dst(e)] for e in edges(gx)]

vec2 = Vector{Edge}([
Edge(1, 8), Edge(3, 4), Edge(2, 8), Edge(1, 3), Edge(6, 8), Edge(5, 6), Edge(3, 7)
])
Expand All @@ -40,9 +52,15 @@
])
for g in test_generic_graphs(gx)
mst2 = @inferred(kruskal_mst(g, distmx_sec))
mst2_vec = @inferred(kruskal_mst(g, weight_vector_sec))
max_mst2 = @inferred(kruskal_mst(g, distmx_sec, minimize=false))
max_mst2_vec = @inferred(kruskal_mst(g, weight_vector_sec, minimize=false))
@test sort([(src(e), dst(e)) for e in mst2]) == sort([(src(e), dst(e)) for e in vec2])
@test sort([(src(e), dst(e)) for e in max_mst2]) == sort([(src(e), dst(e)) for e in max_vec2])
@test src.(mst2) == src.(mst2_vec)
@test dst.(mst2) == dst.(mst2_vec)
@test src.(max_mst2) == src.(max_mst2_vec)
@test dst.(max_mst2) == dst.(max_mst2_vec)
end

# non regression test for #362
Expand Down
Loading