Skip to content

Commit 09654f4

Browse files
authored
feat(core): allow scaling by reported usage when counting tokens approximately (#34996)
1 parent 8072a51 commit 09654f4

2 files changed

Lines changed: 137 additions & 1 deletion

File tree

libs/core/langchain_core/messages/utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,7 @@ def count_tokens_approximately(
21882188
extra_tokens_per_message: float = 3.0,
21892189
count_name: bool = True,
21902190
tokens_per_image: int = 85,
2191+
use_usage_metadata_scaling: bool = False,
21912192
) -> int:
21922193
"""Approximate the total number of tokens in messages.
21932194
@@ -2211,6 +2212,11 @@ def count_tokens_approximately(
22112212
count_name: Whether to include message names in the count.
22122213
tokens_per_image: Fixed token cost per image (default: 85, aligned with
22132214
OpenAI's low-resolution image token cost).
2215+
use_usage_metadata_scaling: If True, and all AI messages have consistent
2216+
`response_metadata['model_provider']`, scale the approximate token count
2217+
using the **most recent** AI message that has
2218+
`usage_metadata['total_tokens']`. The scaling factor is:
2219+
`AI_total_tokens / approx_tokens_up_to_that_AI_message`
22142220
22152221
Returns:
22162222
Approximate number of tokens in the messages.
@@ -2225,8 +2231,16 @@ def count_tokens_approximately(
22252231
22262232
!!! version-added "Added in `langchain-core` 0.3.46"
22272233
"""
2234+
converted_messages = convert_to_messages(messages)
2235+
22282236
token_count = 0.0
2229-
for message in convert_to_messages(messages):
2237+
2238+
ai_model_provider: str | None = None
2239+
invalid_model_provider = False
2240+
last_ai_total_tokens: int | None = None
2241+
approx_at_last_ai: float | None = None
2242+
2243+
for message in converted_messages:
22302244
message_chars = 0
22312245

22322246
if isinstance(message.content, str):
@@ -2284,6 +2298,30 @@ def count_tokens_approximately(
22842298
# add extra tokens per message
22852299
token_count += extra_tokens_per_message
22862300

2301+
if use_usage_metadata_scaling and isinstance(message, AIMessage):
2302+
model_provider = message.response_metadata.get("model_provider")
2303+
if ai_model_provider is None:
2304+
ai_model_provider = model_provider
2305+
elif model_provider != ai_model_provider:
2306+
invalid_model_provider = True
2307+
2308+
if message.usage_metadata and isinstance(
2309+
(total_tokens := message.usage_metadata.get("total_tokens")), int
2310+
):
2311+
last_ai_total_tokens = total_tokens
2312+
approx_at_last_ai = token_count
2313+
2314+
if (
2315+
use_usage_metadata_scaling
2316+
and not invalid_model_provider
2317+
and ai_model_provider is not None
2318+
and last_ai_total_tokens is not None
2319+
and approx_at_last_ai
2320+
and approx_at_last_ai > 0
2321+
):
2322+
scale_factor = last_ai_total_tokens / approx_at_last_ai
2323+
token_count *= max(1.0, scale_factor)
2324+
22872325
# round up once more time in case extra_tokens_per_message is a float
22882326
return math.ceil(token_count)
22892327

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import json
3+
import math
34
import re
45
from collections.abc import Callable, Sequence
56
from typing import Any, TypedDict
@@ -1594,6 +1595,103 @@ def test_count_tokens_approximately_mixed_content_types() -> None:
15941595
assert sum(count_tokens_approximately([m]) for m in messages) == token_count
15951596

15961597

1598+
def test_count_tokens_approximately_usage_metadata_scaling() -> None:
1599+
messages = [
1600+
HumanMessage("text"),
1601+
AIMessage(
1602+
"text",
1603+
response_metadata={"model_provider": "openai"},
1604+
usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
1605+
),
1606+
HumanMessage("text"),
1607+
AIMessage(
1608+
"text",
1609+
response_metadata={"model_provider": "openai"},
1610+
usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
1611+
),
1612+
]
1613+
1614+
unscaled = count_tokens_approximately(messages)
1615+
scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
1616+
1617+
assert scaled == 200
1618+
assert unscaled < 100
1619+
1620+
messages.extend([ToolMessage("text", tool_call_id="abc123")] * 3)
1621+
1622+
unscaled_extended = count_tokens_approximately(messages)
1623+
scaled_extended = count_tokens_approximately(
1624+
messages, use_usage_metadata_scaling=True
1625+
)
1626+
1627+
# scaling should still be based on the most recent AIMessage with total_tokens=200
1628+
assert unscaled_extended > unscaled
1629+
assert scaled_extended > scaled
1630+
1631+
# And the scaled total should be the unscaled total multiplied by the same ratio.
1632+
# ratio = 200 / unscaled (as of last AI message)
1633+
expected_scaled_extended = math.ceil(unscaled_extended * (200 / unscaled))
1634+
assert scaled_extended == expected_scaled_extended
1635+
1636+
1637+
def test_count_tokens_approximately_usage_metadata_scaling_model_provider() -> None:
1638+
messages = [
1639+
HumanMessage("Hello"),
1640+
AIMessage(
1641+
"Hi",
1642+
response_metadata={"model_provider": "openai"},
1643+
usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
1644+
),
1645+
HumanMessage("More text"),
1646+
AIMessage(
1647+
"More response",
1648+
response_metadata={"model_provider": "anthropic"},
1649+
usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
1650+
),
1651+
]
1652+
1653+
unscaled = count_tokens_approximately(messages)
1654+
scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
1655+
assert scaled == unscaled
1656+
1657+
1658+
def test_count_tokens_approximately_usage_metadata_scaling_total_tokens() -> None:
1659+
messages = [
1660+
HumanMessage("Hello"),
1661+
AIMessage(
1662+
"Hi",
1663+
response_metadata={"model_provider": "openai"},
1664+
# no usage metadata -> skip
1665+
),
1666+
]
1667+
1668+
unscaled = count_tokens_approximately(messages, chars_per_token=5)
1669+
scaled = count_tokens_approximately(
1670+
messages, chars_per_token=5, use_usage_metadata_scaling=True
1671+
)
1672+
1673+
assert scaled == unscaled
1674+
1675+
1676+
def test_count_tokens_approximately_usage_metadata_scaling_floor_at_one() -> None:
1677+
messages = [
1678+
HumanMessage("text"),
1679+
AIMessage(
1680+
"text",
1681+
response_metadata={"model_provider": "openai"},
1682+
# Set total_tokens lower than the approximate count up through this message.
1683+
usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 1},
1684+
),
1685+
HumanMessage("text"),
1686+
]
1687+
1688+
unscaled = count_tokens_approximately(messages)
1689+
scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
1690+
1691+
# scale factor would be < 1, but we floor it at 1.0 to avoid decreasing counts
1692+
assert scaled == unscaled
1693+
1694+
15971695
def test_get_buffer_string_with_structured_content() -> None:
15981696
"""Test get_buffer_string with structured content in messages."""
15991697
messages = [

0 commit comments

Comments
 (0)