Skip to content

Commit 04fad44

Browse files
committed
chore(core): improve types for StreamingRunnable
1 parent 721bf15 commit 04fad44

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from langchain_core.runnables.history import RunnableWithMessageHistory
4848
from langchain_core.runnables.schema import StreamEvent
49-
from langchain_core.runnables.utils import Input, Output
49+
from langchain_core.runnables.utils import Addable
5050
from langchain_core.tools import tool
5151
from langchain_core.utils.aiter import aclosing
5252
from tests.unit_tests.runnables.test_runnable_events_v1 import (
@@ -2116,39 +2116,39 @@ def add_one_proxy(x: int, config: RunnableConfig) -> int:
21162116
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
21172117

21182118

2119-
class StreamingRunnable(Runnable[Input, Output]):
2119+
class StreamingRunnable(Runnable[Any, Addable]):
21202120
"""A custom runnable used for testing purposes."""
21212121

2122-
iterable: Iterable[Any]
2122+
iterable: Iterable[Addable]
21232123

2124-
def __init__(self, iterable: Iterable[Any]) -> None:
2124+
def __init__(self, iterable: Iterable[Addable]) -> None:
21252125
"""Initialize the runnable."""
21262126
self.iterable = iterable
21272127

21282128
@override
21292129
def invoke(
2130-
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
2131-
) -> Output:
2130+
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
2131+
) -> Addable:
21322132
"""Invoke the runnable."""
21332133
msg = "Server side error"
21342134
raise ValueError(msg)
21352135

21362136
@override
21372137
def stream(
21382138
self,
2139-
input: Input,
2139+
input: Any,
21402140
config: RunnableConfig | None = None,
21412141
**kwargs: Any | None,
2142-
) -> Iterator[Output]:
2142+
) -> Iterator[Addable]:
21432143
raise NotImplementedError
21442144

21452145
@override
21462146
async def astream(
21472147
self,
2148-
input: Input,
2148+
input: Any,
21492149
config: RunnableConfig | None = None,
21502150
**kwargs: Any | None,
2151-
) -> AsyncIterator[Output]:
2151+
) -> AsyncIterator[Addable]:
21522152
config = ensure_config(config)
21532153
callback_manager = get_async_callback_manager_for_config(config)
21542154
run_manager = await callback_manager.on_chain_start(
@@ -2183,7 +2183,7 @@ async def astream(
21832183
async def test_astream_events_from_custom_runnable() -> None:
21842184
"""Test astream events from a custom runnable."""
21852185
iterator = ["1", "2", "3"]
2186-
runnable: Runnable[int, str] = StreamingRunnable(iterator)
2186+
runnable = StreamingRunnable(iterator)
21872187
chunks = [chunk async for chunk in runnable.astream(1, version="v2")]
21882188
assert chunks == ["1", "2", "3"]
21892189
events = await _collect_events(runnable.astream_events(1, version="v2"))
@@ -2386,7 +2386,7 @@ async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
23862386
yield "1"
23872387
yield "2"
23882388

2389-
runnable: Runnable[str, str] = RunnableGenerator(transform=generator)
2389+
runnable = RunnableGenerator(transform=generator)
23902390
events = await _collect_events(runnable.astream_events("hello", version="v2"))
23912391
_assert_events_equal_allow_superset_metadata(
23922392
events,

0 commit comments

Comments
 (0)