|
1 | 1 | import json |
2 | | -from copy import copy |
3 | | -from typing import Union, List, Tuple, Dict |
| 2 | +from copy import copy, deepcopy |
| 3 | +from typing import Union, List, Tuple, Dict, Any |
4 | 4 |
|
5 | 5 | from gnn_aid.aux.custom_decorators import timing_decorator |
6 | 6 | from gnn_aid.aux.utils import short_str, edge_index_to_edge_list |
@@ -111,21 +111,33 @@ class DatasetDiffData: |
111 | 111 | """ |
112 | 112 | def __init__( |
113 | 113 | self, |
114 | | - artifact: GraphModificationArtifact |
115 | 114 | ): |
116 | | - self._artifact = artifact |
| 115 | + self.node: Dict[str, Any] = { |
| 116 | + "remove": None, |
| 117 | + "add": None, |
| 118 | + "feature": None, |
| 119 | + "change_f": None |
| 120 | + } |
| 121 | + self.edge: Dict[str, Any] = { |
| 122 | + "remove": None, |
| 123 | + "add": None, |
| 124 | + "feature": None, |
| 125 | + "change_f": None |
| 126 | + } |
117 | 127 |
|
118 | 128 | def __str__( |
119 | 129 | self |
120 | 130 | ): |
121 | | - return f"DatasetDiffData{self._artifact.to_json()}" |
| 131 | + return f"DatasetDiffData[\n" \ |
| 132 | + f" node:{self.node}\n" \ |
| 133 | + f" edge:{self.edge}]" |
122 | 134 |
|
123 | 135 | def to_json( |
124 | 136 | self, |
125 | 137 | **dump_args |
126 | 138 | ) -> str: |
127 | 139 | """ Return json string. """ |
128 | | - return json.dumps(self._artifact.to_json(), **dump_args) |
| 140 | + return json.dumps(self.__dict__, **dump_args) |
129 | 141 |
|
130 | 142 |
|
131 | 143 | class ViewPoint: |
@@ -493,6 +505,62 @@ def get_dataset_var_data( |
493 | 505 |
|
494 | 506 | return dataset_var_data |
495 | 507 |
|
| 508 | + # @timing_decorator |
| 509 | + def get_dataset_diff_data( |
| 510 | + self, |
| 511 | + view_point: ViewPoint = None, |
| 512 | + artifact: GraphModificationArtifact = None |
| 513 | + ) -> DatasetDiffData: |
| 514 | + """ |
| 515 | + Get DatasetDiffData for a specified dataset diff. |
| 516 | + """ |
| 517 | + if self.gen_dataset.is_multi(): |
| 518 | + raise NotImplementedError |
| 519 | + |
| 520 | + self.update_view_point(view_point) |
| 521 | + |
| 522 | + dataset_diff_data = DatasetDiffData() |
| 523 | + print("Computing dataset_diff_data for", view_point) |
| 524 | + |
| 525 | + center = self.dataset_index.view_point.center |
| 526 | + if center is None: # Graph |
| 527 | + # Node level |
| 528 | + if len(artifact.nodes["remove"]) > 0: |
| 529 | + dataset_diff_data.node["remove"] = list(artifact.nodes["remove"]) |
| 530 | + if len(artifact.nodes["add"]) > 0: |
| 531 | + dataset_diff_data.node["add"] = list(artifact.nodes["add"].keys()) |
| 532 | + dataset_diff_data.node["feature"] = dict(artifact.nodes["add"]) |
| 533 | + if len(artifact.nodes["change_f"]) > 0: |
| 534 | + dataset_diff_data.node["change_f"] = deepcopy(artifact.nodes["change_f"]) |
| 535 | + # Edge level |
| 536 | + if len(artifact.edges["remove"]) > 0: |
| 537 | + dataset_diff_data.edge["remove"] = deepcopy(artifact.edges["remove"]) |
| 538 | + if len(artifact.edges["add"]) > 0: |
| 539 | + dataset_diff_data.edge["add"] = [[i, j] for i, j, _ in artifact.edges["add"]] |
| 540 | + # dataset_diff_data.edge["feature"] = dict(artifact.edges["add"]) |
| 541 | + # if len(artifact.edges["change_f"]) > 0: |
| 542 | + # dataset_diff_data.edge["change_f"] = deepcopy(artifact.edges["change_f"]) |
| 543 | + |
| 544 | + else: # Neighborhood |
| 545 | + node_index = self.dataset_index.node_index |
| 546 | + edge_index = self.dataset_index.edge_index |
| 547 | + # Node level |
| 548 | + if len(artifact.nodes["remove"]) > 0: |
| 549 | + dataset_diff_data.node["remove"] = list(artifact.nodes["remove"]) |
| 550 | + if len(artifact.nodes["add"]) > 0: |
| 551 | + dataset_diff_data.node["add"] = list(artifact.nodes["add"].keys()) |
| 552 | + dataset_diff_data.node["feature"] = dict(artifact.nodes["add"]) |
| 553 | + if len(artifact.nodes["change_f"]) > 0: |
| 554 | + dataset_diff_data.node["change_f"] = deepcopy(artifact.nodes["change_f"]) |
| 555 | + # Edge level |
| 556 | + if len(artifact.edges["remove"]) > 0: |
| 557 | + dataset_diff_data.edge["remove"] = deepcopy(artifact.edges["remove"]) |
| 558 | + if len(artifact.edges["add"]) > 0: |
| 559 | + dataset_diff_data.edge["add"] = [[i, j] for i, j, _ in artifact.edges["add"]] |
| 560 | + |
| 561 | + |
| 562 | + return dataset_diff_data |
| 563 | + |
496 | 564 | def get_train_test_mask( |
497 | 565 | self, |
498 | 566 | ) -> DatasetVarData: |
|
0 commit comments