Skip to content

Commit 8123b0e

Browse files
committed
WIP
1 parent 9e017cd commit 8123b0e

12 files changed

+216
-83
lines changed

src/zenml/config/compiler.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Class for compiling ZenML pipelines into a serializable format."""
1515

1616
import copy
17+
import os
1718
import string
1819
from typing import (
1920
TYPE_CHECKING,
@@ -36,6 +37,11 @@
3637
StepConfigurationUpdate,
3738
StepSpec,
3839
)
40+
from zenml.constants import (
41+
ENV_ZENML_ACTIVE_STACK_ID,
42+
ENV_ZENML_ACTIVE_WORKSPACE_ID,
43+
ENV_ZENML_STORE_PREFIX,
44+
)
3945
from zenml.environment import get_run_environment_dict
4046
from zenml.exceptions import StackValidationError
4147
from zenml.models import PipelineDeploymentBase
@@ -50,6 +56,8 @@
5056

5157
from zenml.logger import get_logger
5258

59+
ENVIRONMENT_VARIABLE_PREFIX = "__ZENML__"
60+
5361
logger = get_logger(__file__)
5462

5563

@@ -104,6 +112,9 @@ def compile(
104112
pipeline.configuration.substitutions,
105113
)
106114

115+
pipeline_environment = finalize_environment_variables(
116+
pipeline.configuration.environment
117+
)
107118
pipeline_settings = self._filter_and_validate_settings(
108119
settings=pipeline.configuration.settings,
109120
configuration_level=ConfigurationLevel.PIPELINE,
@@ -121,6 +132,7 @@ def compile(
121132
steps = {
122133
invocation_id: self._compile_step_invocation(
123134
invocation=invocation,
135+
pipeline_environment=pipeline_environment,
124136
pipeline_settings=settings_to_passdown,
125137
pipeline_extra=pipeline.configuration.extra,
126138
stack=stack,
@@ -427,6 +439,7 @@ def _get_step_spec(
427439
def _compile_step_invocation(
428440
self,
429441
invocation: "StepInvocation",
442+
pipeline_environment: Optional[Dict[str, Any]],
430443
pipeline_settings: Dict[str, "BaseSettings"],
431444
pipeline_extra: Dict[str, Any],
432445
stack: "Stack",
@@ -438,7 +451,9 @@ def _compile_step_invocation(
438451
439452
Args:
440453
invocation: The step invocation to compile.
441-
pipeline_settings: settings configured on the
454+
pipeline_environment: Environment variables configured for the
455+
pipeline.
456+
pipeline_settings: Settings configured on the
442457
pipeline of the step.
443458
pipeline_extra: Extra values configured on the pipeline of the step.
444459
stack: The stack on which the pipeline will be run.
@@ -463,6 +478,9 @@ def _compile_step_invocation(
463478
step.configuration.settings, stack=stack
464479
)
465480
step_spec = self._get_step_spec(invocation=invocation)
481+
step_environment = finalize_environment_variables(
482+
step.configuration.environment
483+
)
466484
step_settings = self._filter_and_validate_settings(
467485
settings=step.configuration.settings,
468486
configuration_level=ConfigurationLevel.STEP,
@@ -473,13 +491,15 @@ def _compile_step_invocation(
473491
step_on_success_hook_source = step.configuration.success_hook_source
474492

475493
step.configure(
494+
environment=pipeline_environment,
476495
settings=pipeline_settings,
477496
extra=pipeline_extra,
478497
on_failure=pipeline_failure_hook_source,
479498
on_success=pipeline_success_hook_source,
480499
merge=False,
481500
)
482501
step.configure(
502+
environment=step_environment,
483503
settings=step_settings,
484504
extra=step_extra,
485505
on_failure=step_on_failure_hook_source,
@@ -635,3 +655,49 @@ def convert_component_shortcut_settings_keys(
635655
)
636656

637657
settings[key] = component_settings
658+
659+
660+
def finalize_environment_variables(
661+
environment: Dict[str, Any],
662+
) -> Dict[str, str]:
663+
"""Finalize the user environment variables.
664+
665+
This function adds all __ZENML__ prefixed environment variables from the
666+
local client environment to the explicit user-defined variables.
667+
668+
Args:
669+
environment: The explicit user-defined environment variables.
670+
671+
Returns:
672+
The finalized user environment variables.
673+
"""
674+
environment = {key: str(value) for key, value in environment.items()}
675+
676+
for key, value in os.environ.items():
677+
if key.startswith(ENVIRONMENT_VARIABLE_PREFIX):
678+
key_without_prefix = key[len(ENVIRONMENT_VARIABLE_PREFIX) :]
679+
680+
if (
681+
key_without_prefix in environment
682+
and value != environment[key_without_prefix]
683+
):
684+
logger.warning(
685+
"Got multiple values for environment variable `%s`.",
686+
key_without_prefix,
687+
)
688+
else:
689+
environment[key_without_prefix] = value
690+
691+
finalized_env = {}
692+
693+
for key, value in environment:
694+
if key.upper().startswith(ENV_ZENML_STORE_PREFIX) or key.upper() in [
695+
ENV_ZENML_ACTIVE_WORKSPACE_ID,
696+
ENV_ZENML_ACTIVE_STACK_ID,
697+
]:
698+
logger.warning(
699+
"Not allowed to set `%s` config environment variable.", key
700+
)
701+
continue
702+
703+
return finalized_env

src/zenml/config/pipeline_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
4040
enable_artifact_metadata: Optional[bool] = None
4141
enable_artifact_visualization: Optional[bool] = None
4242
enable_step_logs: Optional[bool] = None
43+
environment: Dict[str, Any] = {}
4344
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4445
tags: Optional[List[str]] = None
4546
extra: Dict[str, Any] = {}

src/zenml/config/pipeline_run_configuration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class PipelineRunConfiguration(
4444
default=None, union_mode="left_to_right"
4545
)
4646
steps: Dict[str, StepConfigurationUpdate] = {}
47+
environment: Dict[str, Any] = {}
4748
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4849
tags: Optional[List[str]] = None
4950
extra: Dict[str, Any] = {}

src/zenml/config/step_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class StepConfigurationUpdate(StrictBaseModel):
148148
step_operator: Optional[str] = None
149149
experiment_tracker: Optional[str] = None
150150
parameters: Dict[str, Any] = {}
151+
environment: Dict[str, Any] = {}
151152
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
152153
extra: Dict[str, Any] = {}
153154
failure_hook_source: Optional[SourceWithValidator] = None

src/zenml/orchestrators/cache_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def generate_cache_key(
9999
hash_.update(key.encode())
100100
hash_.update(str(value).encode())
101101

102+
# User-defined environment variables
103+
for key, value in sorted(step.config.environment.items()):
104+
hash_.update(key.encode())
105+
hash_.update(str(value).encode())
106+
102107
return hash_.hexdigest()
103108

104109

src/zenml/orchestrators/step_runner.py

Lines changed: 80 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@
5656
parse_return_type_annotations,
5757
resolve_type_annotation,
5858
)
59-
from zenml.utils import materializer_utils, source_utils, string_utils
59+
from zenml.utils import (
60+
env_utils,
61+
materializer_utils,
62+
source_utils,
63+
string_utils,
64+
)
6065
from zenml.utils.typing_utils import get_origin, is_union
6166

6267
if TYPE_CHECKING:
@@ -183,86 +188,90 @@ def run(
183188
)
184189

185190
step_failed = False
186-
try:
187-
return_values = step_instance.call_entrypoint(
188-
**function_params
189-
)
190-
except BaseException as step_exception: # noqa: E722
191-
step_failed = True
192-
if not handle_bool_env_var(
193-
ENV_ZENML_IGNORE_FAILURE_HOOK, False
194-
):
195-
if (
196-
failure_hook_source
197-
:= self.configuration.failure_hook_source
198-
):
199-
logger.info("Detected failure hook. Running...")
200-
self.load_and_run_hook(
201-
failure_hook_source,
202-
step_exception=step_exception,
203-
)
204-
raise
205-
finally:
191+
with env_utils.temporary_environment(step_run.config.environment):
206192
try:
207-
step_run_metadata = self._stack.get_step_run_metadata(
208-
info=step_run_info,
209-
)
210-
publish_step_run_metadata(
211-
step_run_id=step_run_info.step_run_id,
212-
step_run_metadata=step_run_metadata,
213-
)
214-
self._stack.cleanup_step_run(
215-
info=step_run_info, step_failed=step_failed
193+
return_values = step_instance.call_entrypoint(
194+
**function_params
216195
)
217-
if not step_failed:
196+
except BaseException as step_exception: # noqa: E722
197+
step_failed = True
198+
if not handle_bool_env_var(
199+
ENV_ZENML_IGNORE_FAILURE_HOOK, False
200+
):
218201
if (
219-
success_hook_source
220-
:= self.configuration.success_hook_source
202+
failure_hook_source
203+
:= self.configuration.failure_hook_source
221204
):
222-
logger.info("Detected success hook. Running...")
205+
logger.info("Detected failure hook. Running...")
223206
self.load_and_run_hook(
224-
success_hook_source,
225-
step_exception=None,
207+
failure_hook_source,
208+
step_exception=step_exception,
226209
)
227-
228-
# Store and publish the output artifacts of the step function.
229-
output_data = self._validate_outputs(
230-
return_values, output_annotations
231-
)
232-
artifact_metadata_enabled = is_setting_enabled(
233-
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
234-
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
210+
raise
211+
finally:
212+
try:
213+
step_run_metadata = self._stack.get_step_run_metadata(
214+
info=step_run_info,
235215
)
236-
artifact_visualization_enabled = is_setting_enabled(
237-
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
238-
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
216+
publish_step_run_metadata(
217+
step_run_id=step_run_info.step_run_id,
218+
step_run_metadata=step_run_metadata,
239219
)
240-
output_artifacts = self._store_output_artifacts(
241-
output_data=output_data,
242-
output_artifact_uris=output_artifact_uris,
243-
output_materializers=output_materializers,
244-
output_annotations=output_annotations,
245-
artifact_metadata_enabled=artifact_metadata_enabled,
246-
artifact_visualization_enabled=artifact_visualization_enabled,
220+
self._stack.cleanup_step_run(
221+
info=step_run_info, step_failed=step_failed
247222
)
248-
249-
if (
250-
model_version := step_run.model_version
251-
or pipeline_run.model_version
252-
):
253-
from zenml.orchestrators import step_run_utils
254-
255-
step_run_utils.link_output_artifacts_to_model_version(
256-
artifacts={
257-
k: [v] for k, v in output_artifacts.items()
258-
},
259-
model_version=model_version,
223+
if not step_failed:
224+
if (
225+
success_hook_source
226+
:= self.configuration.success_hook_source
227+
):
228+
logger.info(
229+
"Detected success hook. Running..."
230+
)
231+
self.load_and_run_hook(
232+
success_hook_source,
233+
step_exception=None,
234+
)
235+
236+
# Store and publish the output artifacts of the step function.
237+
output_data = self._validate_outputs(
238+
return_values, output_annotations
260239
)
261-
finally:
262-
step_context._cleanup_registry.execute_callbacks(
263-
raise_on_exception=False
264-
)
265-
StepContext._clear() # Remove the step context singleton
240+
artifact_metadata_enabled = is_setting_enabled(
241+
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
242+
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
243+
)
244+
artifact_visualization_enabled = is_setting_enabled(
245+
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
246+
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
247+
)
248+
output_artifacts = self._store_output_artifacts(
249+
output_data=output_data,
250+
output_artifact_uris=output_artifact_uris,
251+
output_materializers=output_materializers,
252+
output_annotations=output_annotations,
253+
artifact_metadata_enabled=artifact_metadata_enabled,
254+
artifact_visualization_enabled=artifact_visualization_enabled,
255+
)
256+
257+
if (
258+
model_version := step_run.model_version
259+
or pipeline_run.model_version
260+
):
261+
from zenml.orchestrators import step_run_utils
262+
263+
step_run_utils.link_output_artifacts_to_model_version(
264+
artifacts={
265+
k: [v]
266+
for k, v in output_artifacts.items()
267+
},
268+
model_version=model_version,
269+
)
270+
finally:
271+
step_context._cleanup_registry.execute_callbacks(
272+
raise_on_exception=False
273+
)
274+
StepContext._clear() # Remove the step context singleton
266275

267276
# Update the status and output artifacts of the step run.
268277
output_artifact_ids = {

src/zenml/pipelines/pipeline_decorator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def pipeline(
4949
enable_cache: Optional[bool] = None,
5050
enable_artifact_metadata: Optional[bool] = None,
5151
enable_step_logs: Optional[bool] = None,
52+
environment: Optional[Dict[str, Any]] = None,
5253
settings: Optional[Dict[str, "SettingsOrDict"]] = None,
5354
tags: Optional[List[str]] = None,
5455
extra: Optional[Dict[str, Any]] = None,
@@ -66,6 +67,7 @@ def pipeline(
6667
enable_cache: Optional[bool] = None,
6768
enable_artifact_metadata: Optional[bool] = None,
6869
enable_step_logs: Optional[bool] = None,
70+
environment: Optional[Dict[str, Any]] = None,
6971
settings: Optional[Dict[str, "SettingsOrDict"]] = None,
7072
tags: Optional[List[str]] = None,
7173
extra: Optional[Dict[str, Any]] = None,
@@ -83,6 +85,7 @@ def pipeline(
8385
enable_cache: Whether to use caching or not.
8486
enable_artifact_metadata: Whether to enable artifact metadata or not.
8587
enable_step_logs: If step logs should be enabled for this pipeline.
88+
environment: Environment variables to set when running this pipeline.
8689
settings: Settings for this pipeline.
8790
tags: Tags to apply to runs of the pipeline.
8891
extra: Extra configurations for this pipeline.
@@ -107,6 +110,7 @@ def inner_decorator(func: "F") -> "Pipeline":
107110
enable_cache=enable_cache,
108111
enable_artifact_metadata=enable_artifact_metadata,
109112
enable_step_logs=enable_step_logs,
113+
environment=environment,
110114
settings=settings,
111115
tags=tags,
112116
extra=extra,

0 commit comments

Comments
 (0)