Skip to content

Commit 7847448

Browse files
author
Charles Gallay
committed
Cleaning code and correct the indentation for of the doc
1 parent 4d90887 commit 7847448

File tree

1 file changed

+53
-57
lines changed

1 file changed

+53
-57
lines changed

pygsp/graphs/graph.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def to_graphtool(self, edge_prop_name='weight', directed=True):
160160
"""
161161
##from graph_tool.all import *
162162
import graph_tool
163-
g_gt = graph_tool.Graph(directed=directed) #TODO check for undirected graph
163+
g_gt = graph_tool.Graph(directed=directed)
164164
nonzero = self.W.nonzero()
165165
g_gt.add_edge_list(np.transpose(nonzero))
166166
edge_weight = g_gt.new_edge_property('double')
@@ -188,16 +188,18 @@ def from_networkx(cls, graph_nx, singals_names = []):
188188
-------
189189
g : :class:`~pygsp.graphs.Graph`
190190
"""
191-
192191
import networkx as nx
192+
#keep a consistent order of nodes for the agency matrix and the signal array
193193
nodelist = graph_nx.nodes()
194194
A = nx.to_scipy_sparse_matrix(graph_nx, nodelist)
195195
G = cls(A)
196+
#Adding the signals
196197
for s_name in singals_names:
197198
s_dict = nx.get_node_attributes(graph_nx, s_name)
198199
if len(s_dict.keys()) == 0:
199200
raise ValueError("Signal {} is not present in the networkx graph".format(s_name))
200-
s_value = np.array([s_dict[n] for n in nodelist]) #force the order to be same as for the agency matrix
201+
#The signal is set to zero for node not present in the networkx signal
202+
s_value = np.array([s_dict[n] if n in s_dict else 0 for n in nodelist])
201203
G.set_signal(s_value, s_name)
202204
return G
203205

@@ -214,22 +216,22 @@ def from_graphtool(cls, graph_gt, edge_prop_name='weight', aggr_fun=sum, singals
214216
to be loaded as weight for the graph. If the property is not found a graph with default weight set to 1 is created.
215217
On the other hand if the property is found but not set for a specific edge the weight of zero will be set
216218
therefore for single edge this will result in a none existing edge. If you want to set to a default value please
217-
use `set_value<https://graph-tool.skewed.de/static/doc/graph_tool.html?highlight=propertyarray#graph_tool.PropertyMap.set_value>`_
219+
use `set_value <https://graph-tool.skewed.de/static/doc/graph_tool.html?highlight=propertyarray#graph_tool.PropertyMap.set_value>`_
218220
from the graph_tool object.
219221
aggr_fun : function
220222
When the graph as multiple edge connecting the same two nodes the aggragate function is called to merge the
221223
edges. By default the sum is taken.
222224
singals_names : list[String] or 'all'
223225
List of signals names to import from the graph_tool graph or if set to 'all' import all signal present
224226
in the graph
227+
225228
Returns
226229
-------
227230
g : :class:`~pygsp.graphs.Graph`
228231
The weight of the graph are loaded from the edge property named ``edge_prop_name``
229232
230233
"""
231234
nb_vertex = len(graph_gt.get_vertices())
232-
edge_weight = np.ones(nb_vertex)
233235
W = np.zeros(shape=(nb_vertex, nb_vertex))
234236

235237
props_names = graph_gt.edge_properties.keys()
@@ -240,14 +242,16 @@ def from_graphtool(cls, graph_gt, edge_prop_name='weight', aggr_fun=sum, singals
240242
else:
241243
warnings.warn("""{} property not found in the graph, \
242244
weights of 1 for the edges are set""".format(edge_prop_name))
243-
edge_weight = np.ones(graph_gt.edge_index_range)
245+
edge_weight = np.ones(nb_vertex)
246+
244247
# merging multi-edge
245248
merged_edge_weight = []
246249
for k, grp in groupby(graph_gt.get_edges(), key=lambda e: (e[0], e[1])):
247250
merged_edge_weight.append((k[0], k[1], aggr_fun([edge_weight[e[2]] for e in grp])))
248251
for e in merged_edge_weight:
249252
W[e[0], e[1]] = e[2]
250253
g = cls(W)
254+
251255
#Adding signals
252256
if singals_names == 'all':
253257
singals_names = graph_gt.vertex_properties.keys()
@@ -266,14 +270,14 @@ def load(cls, path, fmt='auto', lib='networkx'):
266270
267271
Parameters
268272
----------
269-
path : String
270-
Where the file is located on the disk.
271-
fmt : String
272-
Format in which the graph is encoded. Currently supported format are:
273-
GML, gpickle.
274-
lib : String
275-
Python library used in background to load the graph.
276-
Supported library are networkx and graph_tool
273+
path : String
274+
Where the file is located on the disk.
275+
fmt : String
276+
Format in which the graph is encoded. Currently supported format are:
277+
GML and gpickle.
278+
lib : String
279+
Python library used in background to load the graph.
280+
Supported library are networkx and graph_tool
277281
278282
Returns
279283
-------
@@ -283,69 +287,61 @@ def load(cls, path, fmt='auto', lib='networkx'):
283287
if fmt == 'auto':
284288
fmt = path.split('.')[-1]
285289

286-
if lib == 'networkx':
287-
import networkx
288-
if lib == 'graph_tool':
289-
import graph_tool
290-
291290
err = NotImplementedError('{} can not be load with {}. \
292-
Try another background library'.format(fmt, lib))
291+
Try another background library'.format(fmt, lib))
293292

294-
if fmt == 'gml':
295-
if lib == 'networkx':
293+
if lib == 'networkx':
294+
import networkx
295+
if fmt == 'gml':
296296
g = networkx.read_gml(path)
297297
return cls.from_networkx(g)
298-
if lib == 'graph_tool':
299-
g = graph_tool.load_graph(path, fmt=fmt)
300-
return cls.from_graphtool(g)
301-
raise err
302-
303-
if fmt in ['gpickle', 'p', 'pkl', 'pickle']:
304-
if lib == 'networkx':
298+
if fmt in ['gpickle', 'p', 'pkl', 'pickle']:
305299
g = networkx.read_gpickle(path)
306300
return cls.from_networkx(g)
307301
raise err
308-
302+
if lib == 'graph_tool':
303+
import graph_tool
304+
g = graph_tool.load_graph(path, fmt=fmt)
305+
return cls.from_graphtool(g)
306+
309307
raise NotImplementedError('the format {} is not suported'.format(fmt))
310308

311309
def save(self, path, fmt='auto', lib='networkx'):
312310
r"""Save the graph into a file
313311
314312
Parameters
315313
----------
316-
path : String
317-
Where to save file on the disk.
318-
fmt : String
319-
Format in which the graph will be encoded. The format is guessed from
320-
the `path` extention when fmt is set to 'auto'
321-
Currently supported format are:
322-
GML, gpickle.
323-
lib : String
324-
Python library used in background to save the graph.
325-
Supported library are networkx and graph_tool
326-
327-
314+
path : String
315+
Where to save file on the disk.
316+
fmt : String
317+
Format in which the graph will be encoded. The format is guessed from
318+
the `path` extention when fmt is set to 'auto'
319+
Currently supported format are:
320+
GML and gpickle.
321+
lib : String
322+
Python library used in background to save the graph.
323+
Supported library are networkx and graph_tool
328324
"""
329325
if fmt == 'auto':
330326
fmt = path.split('.')[-1]
331327

332-
if lib == 'networkx':
333-
import networkx
334-
if lib == 'graph_tool':
335-
import graph_tool
336-
337328
err = NotImplementedError('{} can not be save with {}. \
338329
Try another background library'.format(fmt, lib))
339-
if fmt == 'gml':
340-
if lib == 'networkx':
330+
331+
if lib == 'networkx':
332+
import networkx
333+
if fmt == 'gml':
341334
g = self.to_networkx()
342335
networkx.write_gml(g, path)
343336
return
344-
if lib == 'graph_tool':
345-
g = self.to_graphtool()
346-
g.save(path, fmt=fmt)
347337
raise err
348-
338+
339+
if lib == 'graph_tool':
340+
import graph_tool
341+
g = self.to_graphtool()
342+
g.save(path, fmt=fmt)
343+
return
344+
349345
raise NotImplementedError('the format {} is not suported'.format(fmt))
350346

351347
def set_signal(self, signal, signal_name):
@@ -354,10 +350,10 @@ def set_signal(self, signal, signal_name):
354350
355351
Parameters
356352
----------
357-
signal : numpy.array
358-
An array maping from node to his value. For example the value of the singal at node i is signal[i]
359-
signal_name : String
360-
Name associated to the signal.
353+
signal : numpy.array
354+
An array maping from node to his value. For example the value of the singal at node i is signal[i]
355+
signal_name : String
356+
Name associated to the signal.
361357
"""
362358
assert len(signal) == self.N, "A value must be attached to every vertex in the graph"
363359
self.signals[signal_name] = np.array(signal)

0 commit comments

Comments
 (0)