Skip to content

Commit 87b8afd

Browse files
committed
+
1 parent 32a5b4f commit 87b8afd

File tree

1 file changed

+74
-6
lines changed

1 file changed

+74
-6
lines changed

web_interface/back_front/visible_part.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
2-
from copy import copy
3-
from typing import Union, List, Tuple, Dict
2+
from copy import copy, deepcopy
3+
from typing import Union, List, Tuple, Dict, Any
44

55
from gnn_aid.aux.custom_decorators import timing_decorator
66
from gnn_aid.aux.utils import short_str, edge_index_to_edge_list
@@ -111,21 +111,33 @@ class DatasetDiffData:
111111
"""
112112
def __init__(
113113
self,
114-
artifact: GraphModificationArtifact
115114
):
116-
self._artifact = artifact
115+
self.node: Dict[str, Any] = {
116+
"remove": None,
117+
"add": None,
118+
"feature": None,
119+
"change_f": None
120+
}
121+
self.edge: Dict[str, Any] = {
122+
"remove": None,
123+
"add": None,
124+
"feature": None,
125+
"change_f": None
126+
}
117127

118128
def __str__(
119129
self
120130
):
121-
return f"DatasetDiffData{self._artifact.to_json()}"
131+
return f"DatasetDiffData[\n" \
132+
f" node:{self.node}\n" \
133+
f" edge:{self.edge}]"
122134

123135
def to_json(
124136
self,
125137
**dump_args
126138
) -> str:
127139
""" Return json string. """
128-
return json.dumps(self._artifact.to_json(), **dump_args)
140+
return json.dumps(self.__dict__, **dump_args)
129141

130142

131143
class ViewPoint:
@@ -493,6 +505,62 @@ def get_dataset_var_data(
493505

494506
return dataset_var_data
495507

508+
# @timing_decorator
509+
def get_dataset_diff_data(
510+
self,
511+
view_point: ViewPoint = None,
512+
artifact: GraphModificationArtifact = None
513+
) -> DatasetDiffData:
514+
"""
515+
Get DatasetDiffData for a specified dataset diff.
516+
"""
517+
if self.gen_dataset.is_multi():
518+
raise NotImplementedError
519+
520+
self.update_view_point(view_point)
521+
522+
dataset_diff_data = DatasetDiffData()
523+
print("Computing dataset_diff_data for", view_point)
524+
525+
center = self.dataset_index.view_point.center
526+
if center is None: # Graph
527+
# Node level
528+
if len(artifact.nodes["remove"]) > 0:
529+
dataset_diff_data.node["remove"] = list(artifact.nodes["remove"])
530+
if len(artifact.nodes["add"]) > 0:
531+
dataset_diff_data.node["add"] = list(artifact.nodes["add"].keys())
532+
dataset_diff_data.node["feature"] = dict(artifact.nodes["add"])
533+
if len(artifact.nodes["change_f"]) > 0:
534+
dataset_diff_data.node["change_f"] = deepcopy(artifact.nodes["change_f"])
535+
# Edge level
536+
if len(artifact.edges["remove"]) > 0:
537+
dataset_diff_data.edge["remove"] = deepcopy(artifact.edges["remove"])
538+
if len(artifact.edges["add"]) > 0:
539+
dataset_diff_data.edge["add"] = [[i, j] for i, j, _ in artifact.edges["add"]]
540+
# dataset_diff_data.edge["feature"] = dict(artifact.edges["add"])
541+
# if len(artifact.edges["change_f"]) > 0:
542+
# dataset_diff_data.edge["change_f"] = deepcopy(artifact.edges["change_f"])
543+
544+
else: # Neighborhood
545+
node_index = self.dataset_index.node_index
546+
edge_index = self.dataset_index.edge_index
547+
# Node level
548+
if len(artifact.nodes["remove"]) > 0:
549+
dataset_diff_data.node["remove"] = list(artifact.nodes["remove"])
550+
if len(artifact.nodes["add"]) > 0:
551+
dataset_diff_data.node["add"] = list(artifact.nodes["add"].keys())
552+
dataset_diff_data.node["feature"] = dict(artifact.nodes["add"])
553+
if len(artifact.nodes["change_f"]) > 0:
554+
dataset_diff_data.node["change_f"] = deepcopy(artifact.nodes["change_f"])
555+
# Edge level
556+
if len(artifact.edges["remove"]) > 0:
557+
dataset_diff_data.edge["remove"] = deepcopy(artifact.edges["remove"])
558+
if len(artifact.edges["add"]) > 0:
559+
dataset_diff_data.edge["add"] = [[i, j] for i, j, _ in artifact.edges["add"]]
560+
561+
562+
return dataset_diff_data
563+
496564
def get_train_test_mask(
497565
self,
498566
) -> DatasetVarData:

0 commit comments

Comments
 (0)