Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 116 additions & 46 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ def trim_messages(
max_tokens: int,
token_counter: Callable[[list[BaseMessage]], int]
| Callable[[BaseMessage], int]
| BaseLanguageModel,
| BaseLanguageModel
| Literal["approximate"],
strategy: Literal["first", "last"] = "last",
allow_partial: bool = False,
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
Expand Down Expand Up @@ -758,53 +759,65 @@ def trim_messages(
messages: Sequence of Message-like objects to trim.
max_tokens: Max token count of trimmed messages.
token_counter: Function or llm for counting tokens in a `BaseMessage` or a
list of `BaseMessage`. If a `BaseLanguageModel` is passed in then
`BaseLanguageModel.get_num_tokens_from_messages()` will be used.
Set to `len` to count the number of **messages** in the chat history.
list of `BaseMessage`.

!!! note
If a `BaseLanguageModel` is passed in then
`BaseLanguageModel.get_num_tokens_from_messages()` will be used. Set to
`len` to count the number of **messages** in the chat history.

You can also use string shortcuts for convenience:

- `'approximate'`: Uses `count_tokens_approximately` for fast, approximate
token counts.

Use `count_tokens_approximately` to get fast, approximate token
counts.
!!! note

This is recommended for using `trim_messages` on the hot path, where
exact token counting is not necessary.
`count_tokens_approximately` (or the shortcut `'approximate'`) is
recommended for using `trim_messages` on the hot path, where exact token
counting is not necessary.

strategy: Strategy for trimming.

- `'first'`: Keep the first `<= n_count` tokens of the messages.
- `'last'`: Keep the last `<= n_count` tokens of the messages.
allow_partial: Whether to split a message if only part of the message can be
included. If `strategy='last'` then the last partial contents of a message
are included. If `strategy='first'` then the first partial contents of a
included.

If `strategy='last'` then the last partial contents of a message are
included. If `strategy='first'` then the first partial contents of a
message are included.
end_on: The message type to end on. If specified then every message after the
last occurrence of this type is ignored. If `strategy='last'` then this
is done before we attempt to get the last `max_tokens`. If
`strategy='first'` then this is done after we get the first
`max_tokens`. Can be specified as string names (e.g. `'system'`,
`'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g.
`SystemMessage`, `HumanMessage`, `AIMessage`, ...). Can be a single
type or a list of types.

start_on: The message type to start on. Should only be specified if
`strategy='last'`. If specified then every message before
the first occurrence of this type is ignored. This is done after we trim
the initial messages to the last `max_tokens`. Does not
apply to a `SystemMessage` at index 0 if `include_system=True`. Can be
specified as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or
as `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`,
`AIMessage`, ...). Can be a single type or a list of types.
end_on: The message type to end on.

If specified then every message after the last occurrence of this type is
ignored. If `strategy='last'` then this is done before we attempt to get the
last `max_tokens`. If `strategy='first'` then this is done after we get the
first `max_tokens`. Can be specified as string names (e.g. `'system'`,
`'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. `SystemMessage`,
`HumanMessage`, `AIMessage`, ...). Can be a single type or a list of types.

start_on: The message type to start on.

Should only be specified if `strategy='last'`. If specified then every
message before the first occurrence of this type is ignored. This is done
after we trim the initial messages to the last `max_tokens`. Does not apply
to a `SystemMessage` at index 0 if `include_system=True`. Can be specified
as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or as
`BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`,
...). Can be a single type or a list of types.

include_system: Whether to keep the `SystemMessage` if there is one at index
`0`. Should only be specified if `strategy="last"`.
`0`.

Should only be specified if `strategy="last"`.
text_splitter: Function or `langchain_text_splitters.TextSplitter` for
splitting the string contents of a message. Only used if
`allow_partial=True`. If `strategy='last'` then the last split tokens
from a partial message will be included. if `strategy='first'` then the
first split tokens from a partial message will be included. Token splitter
assumes that separators are kept, so that split contents can be directly
concatenated to recreate the original text. Defaults to splitting on
newlines.
splitting the string contents of a message.

Only used if `allow_partial=True`. If `strategy='last'` then the last split
tokens from a partial message will be included. If `strategy='first'` then
the first split tokens from a partial message will be included. Token
splitter assumes that separators are kept, so that split contents can be
directly concatenated to recreate the original text. Defaults to splitting
on newlines.

Returns:
List of trimmed `BaseMessage`.
Expand All @@ -815,8 +828,8 @@ def trim_messages(

Example:
Trim chat history based on token count, keeping the `SystemMessage` if
present, and ensuring that the chat history starts with a `HumanMessage` (
or a `SystemMessage` followed by a `HumanMessage`).
present, and ensuring that the chat history starts with a `HumanMessage` (or a
`SystemMessage` followed by a `HumanMessage`).

```python
from langchain_core.messages import (
Expand Down Expand Up @@ -869,8 +882,34 @@ def trim_messages(
]
```

Trim chat history using approximate token counting with `'approximate'`:

```python
trim_messages(
messages,
max_tokens=45,
strategy="last",
# Using the "approximate" shortcut for fast token counting
token_counter="approximate",
start_on="human",
include_system=True,
)

# This is equivalent to using `count_tokens_approximately` directly
from langchain_core.messages.utils import count_tokens_approximately

trim_messages(
messages,
max_tokens=45,
strategy="last",
token_counter=count_tokens_approximately,
start_on="human",
include_system=True,
)
```

Trim chat history based on the message count, keeping the `SystemMessage` if
present, and ensuring that the chat history starts with a `HumanMessage` (
present, and ensuring that the chat history starts with a `HumanMessage` (
or a `SystemMessage` followed by a `HumanMessage`).

trim_messages(
Expand Down Expand Up @@ -992,24 +1031,44 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
raise ValueError(msg)

messages = convert_to_messages(messages)
if hasattr(token_counter, "get_num_tokens_from_messages"):
list_token_counter = token_counter.get_num_tokens_from_messages
elif callable(token_counter):

# Handle string shortcuts for token counter
if isinstance(token_counter, str):
if token_counter in _TOKEN_COUNTER_SHORTCUTS:
actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
else:
available_shortcuts = ", ".join(
f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
)
msg = (
f"Invalid token_counter shortcut '{token_counter}'. "
f"Available shortcuts: {available_shortcuts}."
)
raise ValueError(msg)
else:
# Type narrowing: at this point token_counter is not a str
actual_token_counter = token_counter # type: ignore[assignment]

if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
list_token_counter = actual_token_counter.get_num_tokens_from_messages
elif callable(actual_token_counter):
if (
next(iter(inspect.signature(token_counter).parameters.values())).annotation
next(
iter(inspect.signature(actual_token_counter).parameters.values())
).annotation
is BaseMessage
):

def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]

else:
list_token_counter = token_counter
list_token_counter = actual_token_counter
else:
msg = (
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
f"{type(token_counter)}."
f"{type(actual_token_counter)}."
)
raise ValueError(msg)

Expand Down Expand Up @@ -1807,3 +1866,14 @@ def count_tokens_approximately(

# round up once more time in case extra_tokens_per_message is a float
return math.ceil(token_count)


def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
    """Adapter exposing `count_tokens_approximately` with the list-counter signature."""
    approx_tokens = count_tokens_approximately(messages)
    return approx_tokens


# Mapping from string shortcuts to token counter functions.
_TOKEN_COUNTER_SHORTCUTS = {
    "approximate": _approximate_token_counter,
}
76 changes: 76 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,82 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
assert messages == messages_copy


def test_trim_messages_token_counter_shortcut_approximate() -> None:
    """Test that `'approximate'` shortcut works for `token_counter`."""
    history = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message", id="first"),
        AIMessage("AI response here", id="second"),
    ]
    snapshot = [msg.model_copy(deep=True) for msg in history]

    # Trim once via the string shortcut and once via the function it aliases.
    via_shortcut = trim_messages(
        history,
        max_tokens=50,
        token_counter="approximate",
        strategy="last",
    )
    via_function = trim_messages(
        history,
        max_tokens=50,
        token_counter=count_tokens_approximately,
        strategy="last",
    )

    # The shortcut must behave as a pure alias: identical output, and the
    # input messages are not mutated by either call.
    assert via_shortcut == via_function
    assert history == snapshot


def test_trim_messages_token_counter_shortcut_invalid() -> None:
    """Test that invalid `token_counter` shortcut raises `ValueError`."""
    history = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message"),
    ]

    # An unknown shortcut string is only detectable at runtime; the overloads
    # reject it statically, so the type checker is silenced deliberately.
    with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"):
        trim_messages(  # type: ignore[call-overload]
            history,
            max_tokens=50,
            token_counter="invalid",
            strategy="last",
        )


def test_trim_messages_token_counter_shortcut_with_options() -> None:
    """Test that `'approximate'` shortcut works with different trim options."""
    conversation = [
        SystemMessage("System instructions"),
        HumanMessage("First human message", id="first"),
        AIMessage("First AI response", id="ai1"),
        HumanMessage("Second human message", id="second"),
        AIMessage("Second AI response", id="ai2"),
    ]
    snapshot = [msg.model_copy(deep=True) for msg in conversation]

    # Combine the shortcut with include_system and start_on options.
    trimmed = trim_messages(
        conversation,
        max_tokens=100,
        token_counter="approximate",
        strategy="last",
        include_system=True,
        start_on="human",
    )

    # The system message survives at index 0, at least one human message
    # follows it, and the original input is left untouched.
    assert len(trimmed) >= 2
    assert isinstance(trimmed[0], SystemMessage)
    assert any(isinstance(msg, HumanMessage) for msg in trimmed[1:])
    assert conversation == snapshot


class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(
Expand Down