Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,24 @@ def shielded(func: Func) -> Func:

@functools.wraps(func)
async def wrapped(*args: Any, **kwargs: Any) -> Any:
return await asyncio.shield(func(*args, **kwargs))
# Capture the current context to preserve context variables
ctx = copy_context()

# Create the coroutine
coro = func(*args, **kwargs)

# For Python 3.11+, create task with explicit context
# For older versions, fallback to original behavior
try:
# Create a task with the captured context to preserve context variables
task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg, unused-ignore]
# `call-arg` used to not fail 3.9 or 3.10 tests
return await asyncio.shield(task)
except TypeError:
# Python < 3.11 fallback - create task normally then shield
# This won't preserve context perfectly but is better than nothing
task = asyncio.create_task(coro)
return await asyncio.shield(task)

return cast("Func", wrapped)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,65 @@ async def on_llm_start(
2,
3,
3,
], f"Expected order of states was broken due to context loss. Got {states}"
]


async def test_shielded_callback_context_preservation() -> None:
"""Verify that shielded callbacks preserve context variables.

This test specifically addresses the issue where async callbacks decorated
with @shielded do not properly preserve context variables, breaking
instrumentation and other context-dependent functionality.

The issue manifests in callbacks that use the @shielded decorator:
* on_llm_end
* on_llm_error
* on_chain_end
* on_chain_error
* And other shielded callback methods
"""
context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")

class ContextTestHandler(AsyncCallbackHandler):
"""Handler that reads context variables in shielded callbacks."""

def __init__(self) -> None:
self.run_inline = False
self.context_values: list[str] = []

@override
async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""This method is decorated with @shielded in the run manager."""
# This should preserve the context variable value
self.context_values.append(context_var.get("not_found"))

@override
async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
"""This method is decorated with @shielded in the run manager."""
# This should preserve the context variable value
self.context_values.append(context_var.get("not_found"))

# Set up the test context
context_var.set("test_value")
handler = ContextTestHandler()
manager = AsyncCallbackManager(handlers=[handler])

# Create run managers that have the shielded methods
llm_managers = await manager.on_llm_start({}, ["test prompt"])
llm_run_manager = llm_managers[0]

chain_run_manager = await manager.on_chain_start({}, {"test": "input"})

# Test LLM end callback (which is shielded)
await llm_run_manager.on_llm_end({"response": "test"}) # type: ignore[arg-type]

# Test Chain end callback (which is shielded)
await chain_run_manager.on_chain_end({"output": "test"})

# The context should be preserved in shielded callbacks
# This was the main issue - shielded decorators were not preserving context
assert handler.context_values == ["test_value", "test_value"], (
f"Expected context values ['test_value', 'test_value'], "
f"but got {handler.context_values}. "
f"This indicates the shielded decorator is not preserving context variables."
)