Skip to content

Commit 2b59be6

Browse files
committed
fix edge index in run_model
1 parent cc15349 commit 2b59be6

File tree

4 files changed

+21
-13
lines changed

4 files changed

+21
-13
lines changed

experiments/various_tasks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def link_prediction():
108108

109109
gen_dataset = DatasetManager.get_by_config(dc, dvc)
110110
print(gen_dataset.data)
111-
gen_dataset.train_test_split(percent_train_class=0.85, percent_test_class=0.1)
111+
gen_dataset.train_test_split(percent_train_class=0.8, percent_test_class=0.2)
112112

113113
gnn = FrameworkGNNConstructor(
114114
model_config=ModelConfig(
@@ -227,7 +227,8 @@ def link_prediction():
227227
edge_label_index = torch.tensor([[5], [6]])
228228

229229
# get embeddings for all nodes
230-
node_out = gnn(data.x, data.edge_index)
230+
train_edge_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask]
231+
node_out = gnn(data.x, train_edge_index)
231232

232233
# Get embeddings for our nodes
233234
src = node_out[edge_label_index[0]]

gnn_aid/datasets/gen_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def train_test_split(
366366
train_data.edge_label_index.size(1) + val_data.edge_label_index.size(1)] = True
367367

368368
test_mask = torch.zeros(total_edges, dtype=torch.bool)
369-
test_mask[-test_data.edge_label_index.size(1):] = True
369+
if test_data.edge_label_index.size(1) > 0:
370+
test_mask[-test_data.edge_label_index.size(1):] = True
370371
else:
371372
raise ValueError(f"Unsupported task type {task_type}")
372373

gnn_aid/models_builder/model_managers/framework_mm.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -519,13 +519,12 @@ def run_model(
519519

520520
elif task_type.is_edge_level():
521521
data = gen_dataset.data
522-
edge_label_index = mask_tensor
523-
522+
train_edge_index = gen_dataset.edge_label_index[:, gen_dataset.train_mask]
524523
data_x_copy = torch.clone(data.x)
525524

526525
# FIXME misha check, test
527526
if hasattr(self, 'mask_features'):
528-
node_ind = torch.unique(edge_label_index)
527+
node_ind = torch.unique(train_edge_index)
529528
for elem_ind in node_ind:
530529
for feature in self.mask_features:
531530
data_x_copy[elem_ind][gen_dataset.node_attr_slices[feature][0]:
@@ -535,21 +534,23 @@ def run_model(
535534
if hasattr(self, 'optimizer'):
536535
self.optimizer.zero_grad()
537536

538-
# get logits for nodes
539-
node_out = self.gnn(data_x_copy, data.edge_index)
540-
541-
src = node_out[edge_label_index[0]]
542-
dst = node_out[edge_label_index[1]]
537+
# get logits for all nodes based on train edges
538+
node_out = self.gnn(data_x_copy, train_edge_index)
543539

540+
# Get logits for edges from mask
541+
src = node_out[mask_tensor[0]]
542+
dst = node_out[mask_tensor[1]]
544543
edge_out = self.gnn.decode(src, dst)
545544

545+
# Apply different out
546546
full_out = None
547547
if out == 'logits':
548548
full_out = edge_out
549549
elif out == 'predictions':
550550
if task_type == Task.EDGE_PREDICTION:
551-
# TODO misha
552-
raise NotImplementedError
551+
# TODO misha is it ok?
552+
full_out = edge_out.softmax(dim=-1)
553+
# raise NotImplementedError
553554
elif task_type == Task.EDGE_CLASSIFICATION:
554555
full_out = edge_out.softmax(dim=-1)
555556
elif task_type == Task.EDGE_REGRESSION:

gnn_aid/models_builder/models_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ def name(
255255
res += '{' + kwargs + '}'
256256
return res
257257

258+
def __str__(
259+
self
260+
) -> str:
261+
return self.name()
262+
258263
def compute(
259264
self,
260265
y_true,

0 commit comments

Comments
 (0)