Skip to content

Commit 4980623

Browse files
bastoneroagoscinski
andcommitted
Add tools property and entrypoint for workflows
This is an implementation of feature request in issue #6865. It also includes: * Fixes the typing hints as `tools` as it can return `None`. * Add tests for `CalcJobNode.tools` and `WorkChainNode.tools` --------- Co-authored-by: Alexander Goscinski <[email protected]>
1 parent 43176cb commit 4980623

File tree

8 files changed

+165
-2
lines changed

8 files changed

+165
-2
lines changed

src/aiida/orm/nodes/process/calculation/calcjob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class Model(CalculationNode.Model):
106106
_tools = None
107107

108108
@property
109-
def tools(self) -> 'CalculationTools':
109+
def tools(self) -> Optional['CalculationTools']:
110110
"""Return the calculation tools that are registered for the process type associated with this calculation.
111111
112112
If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the

src/aiida/orm/nodes/process/workflow/workchain.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
###########################################################################
99
"""Module with `Node` sub class for workchain processes."""
1010

11-
from typing import Optional, Tuple
11+
from typing import TYPE_CHECKING, Optional, Tuple
1212

13+
from aiida.common import exceptions
1314
from aiida.common.lang import classproperty
1415

1516
from .workflow import WorkflowNode
1617

18+
if TYPE_CHECKING:
19+
from aiida.tools.workflows import WorkflowTools
20+
1721
__all__ = ('WorkChainNode',)
1822

1923

@@ -22,6 +26,40 @@ class WorkChainNode(WorkflowNode):
2226

2327
STEPPER_STATE_INFO_KEY = 'stepper_state_info'
2428

29+
# An optional entry point for a CalculationTools instance
30+
_tools = None
31+
32+
@property
33+
def tools(self) -> Optional['WorkflowTools']:
34+
"""Return the calculation tools that are registered for the process type associated with this calculation.
35+
36+
If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the
37+
`aiida.tools.calculations` entry point category, it will attempt to load the entry point and instantiate it
38+
passing the node to the constructor. If the entry point does not exist, cannot be resolved or loaded, a warning
39+
will be logged and the base CalculationTools class will be instantiated and returned.
40+
41+
:return: CalculationTools instance
42+
"""
43+
from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point
44+
from aiida.tools.workflows import WorkflowTools
45+
46+
if self._tools is None:
47+
entry_point_string = self.process_type
48+
49+
if entry_point_string and is_valid_entry_point_string(entry_point_string):
50+
entry_point = get_entry_point_from_string(entry_point_string)
51+
52+
try:
53+
tools_class = load_entry_point('aiida.tools.workflows', entry_point.name)
54+
self._tools = tools_class(self)
55+
except exceptions.EntryPointError as exception:
56+
self._tools = WorkflowTools(self)
57+
self.logger.warning(
58+
f'could not load the workflow tools entry point {entry_point.name}: {exception}'
59+
)
60+
61+
return self._tools
62+
2563
@classproperty
2664
def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # noqa: N805
2765
return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,)

src/aiida/tools/workflows/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Workflow tool plugins for Workflow classes."""
10+
11+
# AUTO-GENERATED
12+
13+
# fmt: off
14+
15+
from .base import *
16+
17+
__all__ = (
18+
'WorkflowTools',
19+
)
20+
21+
# fmt: on

src/aiida/tools/workflows/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Base class for WorkflowTools
10+
11+
Sub-classes can be registered in the `aiida.tools.calculations` category to enable the `CalcJobNode` class from being
12+
able to find the tools plugin, load it and expose it through the `tools` property of the `CalcJobNode`.
13+
"""
14+
15+
__all__ = ('WorkflowTools',)
16+
17+
18+
class WorkflowTools:
19+
"""Base class for WorkflowTools."""
20+
21+
def __init__(self, node):
22+
self._node = node

tests/tools/calculations/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################

tests/tools/calculations/test_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Tests for GroupPath"""
10+
11+
12+
class MockCalculationTools:
13+
def __init__(self, node):
14+
self._node = node
15+
self._mock = True
16+
17+
18+
def test_mock_calculation_tools(entry_points, generate_calcjob_node):
19+
"""Test if the calculation tools is correctly loaded from the entry point."""
20+
entry_points.add(MockCalculationTools, 'aiida.tools.calculations:MockCalculationTools')
21+
node = generate_calcjob_node(entry_point='aiida.tools.calculations:MockCalculationTools')
22+
assert node.tools._node == node
23+
assert node.tools._mock
24+
25+
26+
def test_failback_calculation_tools(entry_points, generate_calcjob_node):
27+
"""Test if the calculation tools is falling back to `CalculationTools` if it cannot be loaded from entry point."""
28+
from aiida.tools.calculations import CalculationTools
29+
30+
entry_points.add('DoesNotExist', 'aiida.tools.calculations:MockCalculationTools')
31+
node = generate_calcjob_node(entry_point='aiida.tools.calculations:MockCalculationTools')
32+
assert isinstance(node.tools, CalculationTools)
33+
assert node.tools._node == node

tests/tools/workflows/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################

tests/tools/workflows/test_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Tests for GroupPath"""
10+
11+
12+
class MockWorkflowTools:
13+
def __init__(self, node):
14+
self._node = node
15+
self._mock = True
16+
17+
18+
def test_mock_calculation_tools(entry_points, generate_work_chain):
19+
"""Test if the calculation tools is correctly loaded from the entry point."""
20+
entry_points.add(MockWorkflowTools, 'aiida.tools.workflows:MockWorkflowTools')
21+
node = generate_work_chain(entry_point='aiida.tools.workflows:MockWorkflowTools')
22+
assert node.tools._node == node
23+
assert node.tools._mock
24+
25+
26+
def test_failback_calculation_tools(entry_points, generate_calcjob_node):
27+
"""Test if the calculation tools is falling back to `WorkflowTools` if it cannot be loaded from entry point."""
28+
from aiida.tools.workflows import WorkflowTools
29+
30+
entry_points.add('DoesNotExist', 'aiida.tools.workflows:MockWorkflowTools')
31+
node = generate_calcjob_node(entry_point='aiida.tools.workflows:MockWorkflowTools')
32+
assert isinstance(node.tools, WorkflowTools)
33+
assert node.tools._node == node

0 commit comments

Comments
 (0)