|
46 | 46 | ) |
47 | 47 | from langchain_core.runnables.history import RunnableWithMessageHistory |
48 | 48 | from langchain_core.runnables.schema import StreamEvent |
49 | | -from langchain_core.runnables.utils import Input, Output |
| 49 | +from langchain_core.runnables.utils import Addable |
50 | 50 | from langchain_core.tools import tool |
51 | 51 | from langchain_core.utils.aiter import aclosing |
52 | 52 | from tests.unit_tests.runnables.test_runnable_events_v1 import ( |
@@ -2116,39 +2116,39 @@ def add_one_proxy(x: int, config: RunnableConfig) -> int: |
2116 | 2116 | _assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS) |
2117 | 2117 |
|
2118 | 2118 |
|
2119 | | -class StreamingRunnable(Runnable[Input, Output]): |
| 2119 | +class StreamingRunnable(Runnable[Any, Addable]): |
2120 | 2120 | """A custom runnable used for testing purposes.""" |
2121 | 2121 |
|
2122 | | - iterable: Iterable[Any] |
| 2122 | + iterable: Iterable[Addable] |
2123 | 2123 |
|
2124 | | - def __init__(self, iterable: Iterable[Any]) -> None: |
| 2124 | + def __init__(self, iterable: Iterable[Addable]) -> None: |
2125 | 2125 | """Initialize the runnable.""" |
2126 | 2126 | self.iterable = iterable |
2127 | 2127 |
|
2128 | 2128 | @override |
2129 | 2129 | def invoke( |
2130 | | - self, input: Input, config: RunnableConfig | None = None, **kwargs: Any |
2131 | | - ) -> Output: |
| 2130 | + self, input: Any, config: RunnableConfig | None = None, **kwargs: Any |
| 2131 | + ) -> Addable: |
2132 | 2132 | """Invoke the runnable.""" |
2133 | 2133 | msg = "Server side error" |
2134 | 2134 | raise ValueError(msg) |
2135 | 2135 |
|
2136 | 2136 | @override |
2137 | 2137 | def stream( |
2138 | 2138 | self, |
2139 | | - input: Input, |
| 2139 | + input: Any, |
2140 | 2140 | config: RunnableConfig | None = None, |
2141 | 2141 | **kwargs: Any | None, |
2142 | | - ) -> Iterator[Output]: |
| 2142 | + ) -> Iterator[Addable]: |
2143 | 2143 | raise NotImplementedError |
2144 | 2144 |
|
2145 | 2145 | @override |
2146 | 2146 | async def astream( |
2147 | 2147 | self, |
2148 | | - input: Input, |
| 2148 | + input: Any, |
2149 | 2149 | config: RunnableConfig | None = None, |
2150 | 2150 | **kwargs: Any | None, |
2151 | | - ) -> AsyncIterator[Output]: |
| 2151 | + ) -> AsyncIterator[Addable]: |
2152 | 2152 | config = ensure_config(config) |
2153 | 2153 | callback_manager = get_async_callback_manager_for_config(config) |
2154 | 2154 | run_manager = await callback_manager.on_chain_start( |
@@ -2183,7 +2183,7 @@ async def astream( |
2183 | 2183 | async def test_astream_events_from_custom_runnable() -> None: |
2184 | 2184 | """Test astream events from a custom runnable.""" |
2185 | 2185 | iterator = ["1", "2", "3"] |
2186 | | - runnable: Runnable[int, str] = StreamingRunnable(iterator) |
| 2186 | + runnable = StreamingRunnable(iterator) |
2187 | 2187 | chunks = [chunk async for chunk in runnable.astream(1, version="v2")] |
2188 | 2188 | assert chunks == ["1", "2", "3"] |
2189 | 2189 | events = await _collect_events(runnable.astream_events(1, version="v2")) |
@@ -2386,7 +2386,7 @@ async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]: |
2386 | 2386 | yield "1" |
2387 | 2387 | yield "2" |
2388 | 2388 |
|
2389 | | - runnable: Runnable[str, str] = RunnableGenerator(transform=generator) |
| 2389 | + runnable = RunnableGenerator(transform=generator) |
2390 | 2390 | events = await _collect_events(runnable.astream_events("hello", version="v2")) |
2391 | 2391 | _assert_events_equal_allow_superset_metadata( |
2392 | 2392 | events, |
|
0 commit comments