Skip to content

Feat!: Tag queries with their plan id #4832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sqlmesh/core/config/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
return BuiltInPlanEvaluator(
state_sync=context.state_sync,
snapshot_evaluator=context.snapshot_evaluator,
create_scheduler=context.create_scheduler,
default_catalog=context.default_catalog,
console=context.console,
Expand Down
58 changes: 38 additions & 20 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
run_tests,
)
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict, Verbosity
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
from sqlmesh.utils.concurrency import concurrent_apply_to_values
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import (
Expand Down Expand Up @@ -418,7 +418,7 @@ def __init__(
self.config.get_state_connection(self.gateway) or self.connection_config
)

self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}

self.console = get_console()
setattr(self.console, "dialect", self.config.dialect)
Expand Down Expand Up @@ -446,18 +446,22 @@ def engine_adapter(self) -> EngineAdapter:
self._engine_adapter = self.connection_config.create_engine_adapter()
return self._engine_adapter

@property
def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
self._snapshot_evaluator = SnapshotEvaluator(
def snapshot_evaluator(
    self, correlation_id: t.Optional[CorrelationId] = None
) -> SnapshotEvaluator:
    """Return the snapshot evaluator associated with the given correlation ID.

    Evaluators are cached per correlation_id so that a stale correlation ID is
    never attached to a later, unrelated Context operation.
    """
    evaluator = self._snapshot_evaluators.get(correlation_id)
    if evaluator is None:
        # Tag every engine adapter so executed SQL carries the correlation ID.
        tagged_adapters = {
            gateway: adapter.with_settings(
                log_level=logging.INFO, correlation_id=correlation_id
            )
            for gateway, adapter in self.engine_adapters.items()
        }
        evaluator = SnapshotEvaluator(
            tagged_adapters,
            ddl_concurrent_tasks=self.concurrent_tasks,
            selected_gateway=self.selected_gateway,
        )
        self._snapshot_evaluators[correlation_id] = evaluator
    return evaluator

def execution_context(
self,
Expand Down Expand Up @@ -538,7 +542,9 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:

return self.create_scheduler(snapshots)

def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
def create_scheduler(
self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
) -> Scheduler:
"""Creates the built-in scheduler.

Args:
Expand All @@ -549,7 +555,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
"""
return Scheduler(
snapshots,
self.snapshot_evaluator,
self.snapshot_evaluator(correlation_id),
self.state_sync,
default_catalog=self.default_catalog,
max_workers=self.concurrent_tasks,
Expand Down Expand Up @@ -714,7 +720,7 @@ def run(
NotificationEvent.RUN_START, environment=environment
)
analytics_run_id = analytics.collector.on_run_start(
engine_type=self.snapshot_evaluator.adapter.dialect,
engine_type=self.snapshot_evaluator().adapter.dialect,
state_sync_type=self.state_sync.state_type(),
)
self._load_materializations()
Expand Down Expand Up @@ -1076,7 +1082,7 @@ def evaluate(
and not parent_snapshot.categorized
]

df = self.snapshot_evaluator.evaluate_and_fetch(
df = self.snapshot_evaluator().evaluate_and_fetch(
snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -1588,7 +1594,12 @@ def apply(
default_catalog=self.default_catalog,
console=self.console,
)
explainer.evaluate(plan.to_evaluatable())
explainer.evaluate(
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
)
return

self.notification_target_manager.notify(
Expand Down Expand Up @@ -1902,7 +1913,7 @@ def _table_diff(
)

return TableDiff(
adapter=adapter.with_log_level(logger.getEffectiveLevel()),
adapter=adapter.with_settings(logger.getEffectiveLevel()),
source=source,
target=target,
on=on,
Expand Down Expand Up @@ -2111,7 +2122,7 @@ def audit(
errors = []
skipped_count = 0
for snapshot in snapshots:
for audit_result in self.snapshot_evaluator.audit(
for audit_result in self.snapshot_evaluator().audit(
snapshot=snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -2143,7 +2154,7 @@ def audit(
self.console.log_status_update(f"Got {error.count} results, expected 0.")
if error.query:
self.console.show_sql(
f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
)

self.console.log_status_update("Done.")
Expand Down Expand Up @@ -2335,11 +2346,14 @@ def print_environment_names(self) -> None:

def close(self) -> None:
    """Releases all resources allocated by this context."""
    try:
        # Close every cached snapshot evaluator (one per correlation_id).
        for evaluator in self._snapshot_evaluators.values():
            evaluator.close()
    finally:
        # Clear the cache and close state sync even if an evaluator's close()
        # raised, so no resources are left open behind it.
        self._snapshot_evaluators.clear()
        if self._state_sync:
            self._state_sync.close()

def _run(
self,
environment: str,
Expand Down Expand Up @@ -2390,7 +2404,11 @@ def _run(

def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
self._scheduler.create_plan_evaluator(self).evaluate(
plan.to_evaluatable(), circuit_breaker=circuit_breaker
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
circuit_breaker=circuit_breaker,
)

@python_api_analytics
Expand Down Expand Up @@ -2683,7 +2701,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
)

# Remove the expired snapshots tables
self.snapshot_evaluator.cleanup(
self.snapshot_evaluator().cleanup(
target_snapshots=cleanup_targets,
on_complete=self.console.update_cleanup_progress,
)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
):
# Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
# which means that EngineAdapter.with_log_level() keeps this property when it makes a clone
# which means that EngineAdapter.with_settings() keeps this property when it makes a clone
super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
self.s3_warehouse_location = s3_warehouse_location

Expand Down
12 changes: 9 additions & 3 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from sqlmesh.core.model.kind import TimeColumn
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils import columns_to_types_all_known, random_id, CorrelationId
from sqlmesh.utils.connection_pool import create_connection_pool, ConnectionPool
from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
from sqlmesh.utils.errors import (
Expand Down Expand Up @@ -123,6 +123,7 @@ def __init__(
pre_ping: bool = False,
pretty_sql: bool = False,
shared_connection: bool = False,
correlation_id: t.Optional[CorrelationId] = None,
**kwargs: t.Any,
):
self.dialect = dialect.lower() or self.DIALECT
Expand All @@ -144,19 +145,21 @@ def __init__(
self._pre_ping = pre_ping
self._pretty_sql = pretty_sql
self._multithreaded = multithreaded
self.correlation_id = correlation_id

def with_log_level(self, level: int) -> EngineAdapter:
def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
    """Create a clone of this adapter with overridden settings.

    The clone shares this adapter's connection pool (``null_connection=True``)
    but can differ in execute log level and any extra constructor settings
    (e.g. ``correlation_id``).

    Args:
        log_level: Logging level used when the clone executes SQL.
        **kwargs: Additional constructor overrides for the clone.

    Returns:
        A new adapter instance of the same concrete class.
    """
    # Merge so explicit overrides win over the stored extra config; unpacking
    # both mappings inline would raise TypeError on any duplicated keyword.
    extra = {**self._extra_config, **kwargs}
    adapter = self.__class__(
        self._connection_pool,
        dialect=self.dialect,
        sql_gen_kwargs=self._sql_gen_kwargs,
        default_catalog=self._default_catalog,
        execute_log_level=log_level,
        register_comments=self._register_comments,
        null_connection=True,
        multithreaded=self._multithreaded,
        pretty_sql=self._pretty_sql,
        **extra,
    )

    return adapter
Expand Down Expand Up @@ -2211,6 +2214,9 @@ def execute(
else:
sql = t.cast(str, e)

if self.correlation_id:
sql = f"/* {self.correlation_id} */ {sql}"

self._log_sql(
sql,
expression=e if isinstance(e, exp.Expression) else None,
Expand Down
12 changes: 9 additions & 3 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
class PlanEvaluator(abc.ABC):
@abc.abstractmethod
def evaluate(
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
"""Evaluates a plan by pushing snapshots and backfilling data.

Expand All @@ -60,20 +63,20 @@ def evaluate(

Args:
plan: The plan to evaluate.
snapshot_evaluator: The snapshot evaluator to use.
circuit_breaker: The circuit breaker to use.
"""


class BuiltInPlanEvaluator(PlanEvaluator):
def __init__(
    self,
    state_sync: StateSync,
    create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
    default_catalog: t.Optional[str],
    console: t.Optional[Console] = None,
):
    """Store the collaborators needed to evaluate a plan.

    Args:
        state_sync: The state sync used to persist plan/snapshot state.
        create_scheduler: Factory producing a scheduler for a set of snapshots.
        default_catalog: The default catalog, if any.
        console: Output console; falls back to the global console when omitted.
    """
    self.state_sync = state_sync
    self.create_scheduler = create_scheduler
    self.default_catalog = default_catalog
    self.console = console or get_console()
Expand All @@ -82,9 +85,12 @@ def __init__(
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
self._circuit_breaker = circuit_breaker
self.snapshot_evaluator = snapshot_evaluator

self.console.start_plan_evaluation(plan)
analytics.collector.on_plan_apply_start(
plan=plan,
Expand Down
6 changes: 5 additions & 1 deletion sqlmesh/core/plan/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
from sqlmesh.utils.date import to_ts
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator


logger = logging.getLogger(__name__)
Expand All @@ -37,7 +38,10 @@ def __init__(
self.console = console or get_console()

def evaluate(
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
explainer_console = _get_explainer_console(
Expand Down
21 changes: 21 additions & 0 deletions sqlmesh/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import types
import typing as t
import uuid
from dataclasses import dataclass
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -382,3 +383,23 @@ def to_snake_case(name: str) -> str:
return "".join(
f"_{c.lower()}" if c.isupper() and idx != 0 else c.lower() for idx, c in enumerate(name)
)


class JobType(Enum):
    """The kind of SQLMesh job that issues queries."""

    PLAN = "SQLMESH_PLAN"
    RUN = "SQLMESH_RUN"


@dataclass(frozen=True)
class CorrelationId:
    """ID that is added to each query in order to identify the job that created it.

    Rendered as ``<job type value>: <job id>`` (see ``__str__``) so it can be
    embedded in a SQL comment ahead of each executed statement.
    """

    # The kind of job (plan application or scheduled run) that issued the query.
    job_type: JobType
    # Identifier of the specific job instance, e.g. a plan ID.
    job_id: str

    def __str__(self) -> str:
        return f"{self.job_type.value}: {self.job_id}"

    @classmethod
    def from_plan_id(cls, plan_id: str) -> CorrelationId:
        """Create a correlation ID for a plan application job."""
        return cls(JobType.PLAN, plan_id)

    @classmethod
    def from_run_id(cls, run_id: str) -> CorrelationId:
        """Create a correlation ID for a scheduled run job (parallels from_plan_id)."""
        return cls(JobType.RUN, run_id)
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SnapshotDataVersion,
SnapshotFingerprint,
)
from sqlmesh.utils import random_id
from sqlmesh.utils import random_id, CorrelationId
from sqlmesh.utils.date import TimeLike, to_date
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
from sqlmesh.core.engine_adapter.shared import CatalogSupport
Expand Down Expand Up @@ -266,10 +266,12 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
def push_plan(context: Context, plan: Plan) -> None:
plan_evaluator = BuiltInPlanEvaluator(
context.state_sync,
context.snapshot_evaluator,
context.create_scheduler,
context.default_catalog,
)
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
CorrelationId.from_plan_id(plan.plan_id)
)
deployability_index = DeployabilityIndex.create(context.snapshots.values())
evaluatable_plan = plan.to_evaluatable()
stages = plan_stages.build_plan_stages(
Expand Down
Loading