Skip to content

Commit dc91e1a

Browse files
committed
fix: Fix reconstruct physically merged state nodes
Fixes #118
1 parent 60e7ab3 commit dc91e1a

File tree

3 files changed

+112
-37
lines changed

3 files changed

+112
-37
lines changed

src/core/InfoNode.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ class InfoNode
4242

4343
// typedef SiblingIterator<InfoNode*> sibling_iterator;
4444
// typedef SiblingIterator<InfoNode const*> const_sibling_iterator;
45-
// typedef LeafNodeIterator<InfoNode*> leaf_iterator;
46-
// typedef LeafNodeIterator<InfoNode const*> const_leaf_iterator;
47-
// typedef LeafModuleIterator<InfoNode*> leaf_module_iterator;
48-
// typedef LeafModuleIterator<InfoNode const*> const_leaf_module_iterator;
45+
typedef LeafNodeIterator<InfoNode*> leaf_node_iterator;
46+
typedef LeafNodeIterator<InfoNode const*> const_leaf_node_iterator;
47+
typedef LeafModuleIterator<InfoNode*> leaf_module_iterator;
48+
typedef LeafModuleIterator<InfoNode const*> const_leaf_module_iterator;
4949

5050
// typedef DepthFirstIterator<InfoNode*, true> pre_depth_first_iterator;
5151
// typedef DepthFirstIterator<InfoNode const*, true> const_pre_depth_first_iterator;
@@ -260,6 +260,12 @@ class InfoNode
260260
post_depth_first_iterator begin_post_depth_first()
261261
{ return post_depth_first_iterator(this); }
262262

263+
leaf_node_iterator begin_leaf_nodes()
264+
{ return leaf_node_iterator(this); }
265+
266+
leaf_module_iterator begin_leaf_modules()
267+
{ return leaf_module_iterator(this); }
268+
263269
tree_iterator begin_tree(unsigned int maxClusterLevel = std::numeric_limits<unsigned int>::max())
264270
{ return tree_iterator(this, maxClusterLevel); }
265271

src/core/InfomapBase.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ InfomapBase& InfomapBase::initTree(const NodePaths& tree)
552552

553553
}
554554
aggregateFlowValuesFromLeafToRoot();
555+
initNetwork();
555556

556557
m_hierarchicalCodelength = calcCodelengthOnTree(true);
557558
Log(4) << " => " << maxDepth << " levels with codelength: " << m_hierarchicalCodelength << "\n";

src/core/MemMapEquation.cpp

Lines changed: 101 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -63,62 +63,130 @@ void MemMapEquation::initPartition(std::vector<InfoNode*>& nodes)
6363

6464
void MemMapEquation::initPhysicalNodes(InfoNode& root)
6565
{
66-
bool notInitiated = root.firstChild->physicalNodes.empty();
67-
if (notInitiated)
66+
// bool notInitiated = root.firstChild->physicalNodes.empty();
67+
auto firstLeafIt = root.begin_leaf_nodes();
68+
auto depth = firstLeafIt.depth();
69+
bool notInitiatedOnLeafNodes = firstLeafIt->physicalNodes.empty();
70+
if (notInitiatedOnLeafNodes)
6871
{
6972
Log(3) << "MemMapEquation::initPhysicalNodesOnOriginalNetwork()...\n";
7073
std::set<unsigned int> setOfPhysicalNodes;
71-
// Collect all physical nodes in this network
72-
for (InfoNode& node : root)
73-
{
74+
unsigned int maxPhysicalId = 0;
75+
unsigned int minPhysicalId = std::numeric_limits<unsigned int>::max();
76+
for (auto it(root.begin_leaf_nodes()); !it.isEnd(); ++it) {
77+
InfoNode& node = *it;
7478
setOfPhysicalNodes.insert(node.physicalId);
79+
maxPhysicalId = std::max(maxPhysicalId, node.physicalId);
80+
minPhysicalId = std::min(minPhysicalId, node.physicalId);
7581
}
7682

7783
m_numPhysicalNodes = setOfPhysicalNodes.size();
7884

79-
// Re-index physical nodes
85+
// Re-index physical nodes if necessary
8086
std::map<unsigned int, unsigned int> toZeroBasedIndex;
81-
unsigned int zeroBasedPhysicalId = 0;
82-
for (unsigned int physIndex : setOfPhysicalNodes)
83-
{
84-
toZeroBasedIndex.insert(std::make_pair(physIndex, zeroBasedPhysicalId++));
87+
if (maxPhysicalId - minPhysicalId + 1 > m_numPhysicalNodes) {
88+
unsigned int zeroBasedPhysicalId = 0;
89+
for (unsigned int physIndex : setOfPhysicalNodes)
90+
{
91+
toZeroBasedIndex.insert(std::make_pair(physIndex, zeroBasedPhysicalId++));
92+
}
8593
}
8694

87-
for (InfoNode& node : root)
95+
for (auto it(root.begin_leaf_nodes()); !it.isEnd(); ++it)
8896
{
89-
unsigned int zeroBasedIndex = toZeroBasedIndex[node.physicalId];
97+
InfoNode& node = *it;
98+
// Log() << "Leaf node " << node.stateId << " (phys " << node.physicalId << ") physicalNodes: ";
99+
// unsigned int zeroBasedIndex = toZeroBasedIndex[node.physicalId];
100+
// unsigned int zeroBasedIndex = getPhysIndex[node.physicalId];
101+
unsigned int zeroBasedIndex = !toZeroBasedIndex.empty() ? toZeroBasedIndex[node.physicalId] : (node.physicalId - minPhysicalId);
90102
node.physicalNodes.push_back(PhysData(zeroBasedIndex, node.data.flow));
103+
// Log() << "(" << zeroBasedIndex << ", " << node.data.flow << "), length: " << node.physicalNodes.size() << "\n";
104+
}
105+
106+
// If leaf nodes was not directly under root, make sure leaf modules have
107+
// physical nodes defined also
108+
if (depth > 1) {
109+
for (auto it(root.begin_leaf_modules()); !it.isEnd(); ++it)
110+
{
111+
InfoNode& module = *it;
112+
std::map<unsigned int, double> physToFlow;
113+
for (auto& node : module)
114+
{
115+
for (PhysData& physData : node.physicalNodes)
116+
{
117+
physToFlow[physData.physNodeIndex] += physData.sumFlowFromM2Node;
118+
}
119+
}
120+
for (auto& physFlow : physToFlow)
121+
{
122+
module.physicalNodes.push_back(PhysData(physFlow.first, physFlow.second));
123+
}
124+
}
91125
}
92126
}
93127
else
94128
{
95-
Log(3) << "MemMapEquation::initPhysicalNodesOnSubNetwork()...\n";
96-
std::set<unsigned int> setOfPhysicalNodes;
97-
98-
// Collect all physical nodes in this sub network
99-
for (InfoNode& node : root)
129+
// Either a sub-network (without modules) or the whole network with reconstructed tree
130+
if (depth == 1)
100131
{
101-
for (PhysData& physData : node.physicalNodes)
132+
// new sub-network
133+
Log(3) << "MemMapEquation::initPhysicalNodesOnSubNetwork()...\n";
134+
std::set<unsigned int> setOfPhysicalNodes;
135+
// std::cout << "_*!|!*_";
136+
unsigned int maxPhysNodeIndex = 0;
137+
unsigned int minPhysNodeIndex = std::numeric_limits<unsigned int>::max();
138+
139+
// Collect all physical nodes in this sub network
140+
for (InfoNode& node : root)
102141
{
103-
setOfPhysicalNodes.insert(physData.physNodeIndex);
142+
for (PhysData& physData : node.physicalNodes)
143+
{
144+
setOfPhysicalNodes.insert(physData.physNodeIndex);
145+
maxPhysNodeIndex = std::max(maxPhysNodeIndex, physData.physNodeIndex);
146+
minPhysNodeIndex = std::min(minPhysNodeIndex, physData.physNodeIndex);
147+
}
104148
}
105-
}
106-
107-
m_numPhysicalNodes = setOfPhysicalNodes.size();
108149

109-
// Re-index physical nodes
110-
std::map<unsigned int, unsigned int> toZeroBasedIndex;
111-
unsigned int zeroBasedPhysicalId = 0;
112-
for (unsigned int physIndex : setOfPhysicalNodes)
113-
{
114-
toZeroBasedIndex.insert(std::make_pair(physIndex, zeroBasedPhysicalId++));
150+
m_numPhysicalNodes = setOfPhysicalNodes.size();
151+
152+
// Re-index physical nodes if needed (not required when reconstructing tree)
153+
if (maxPhysNodeIndex >= m_numPhysicalNodes) {
154+
std::map<unsigned int, unsigned int> toZeroBasedIndex;
155+
if (maxPhysNodeIndex - minPhysNodeIndex + 1 > m_numPhysicalNodes) {
156+
unsigned int zeroBasedPhysicalId = 0;
157+
for (unsigned int physIndex : setOfPhysicalNodes)
158+
{
159+
toZeroBasedIndex.insert(std::make_pair(physIndex, zeroBasedPhysicalId++));
160+
}
161+
}
162+
163+
for (InfoNode& node : root)
164+
{
165+
for (PhysData& physData : node.physicalNodes)
166+
{
167+
unsigned int zeroBasedIndex = !toZeroBasedIndex.empty() ? toZeroBasedIndex[physData.physNodeIndex] : (physData.physNodeIndex - minPhysNodeIndex);
168+
physData.physNodeIndex = zeroBasedIndex;
169+
}
170+
}
171+
}
115172
}
116-
117-
for (InfoNode& node : root)
118-
{
119-
for (PhysData& physData : node.physicalNodes)
173+
else {
174+
// whole network with reconstructed tree
175+
for (auto it(root.begin_leaf_modules()); !it.isEnd(); ++it)
120176
{
121-
physData.physNodeIndex = toZeroBasedIndex[physData.physNodeIndex];
177+
InfoNode& module = *it;
178+
std::map<unsigned int, double> physToFlow;
179+
for (auto& node : module)
180+
{
181+
for (PhysData& physData : node.physicalNodes)
182+
{
183+
physToFlow[physData.physNodeIndex] += physData.sumFlowFromM2Node;
184+
}
185+
}
186+
for (auto& physFlow : physToFlow)
187+
{
188+
module.physicalNodes.push_back(PhysData(physFlow.first, physFlow.second));
189+
}
122190
}
123191
}
124192
}

0 commit comments

Comments
 (0)