|
1 | 1 | import base64 |
2 | 2 | import json |
| 3 | +import math |
3 | 4 | import re |
4 | 5 | from collections.abc import Callable, Sequence |
5 | 6 | from typing import Any, TypedDict |
@@ -1594,6 +1595,103 @@ def test_count_tokens_approximately_mixed_content_types() -> None: |
1594 | 1595 | assert sum(count_tokens_approximately([m]) for m in messages) == token_count |
1595 | 1596 |
|
1596 | 1597 |
|
def test_count_tokens_approximately_usage_metadata_scaling() -> None:
    """Test that usage metadata rescales the approximate token count.

    Two AI messages from the same provider report `total_tokens` of 100 and
    200. With `use_usage_metadata_scaling=True` the count for the conversation
    is expected to equal the most recently reported total, and messages that
    are appended afterwards are expected to be scaled by that same ratio.
    """

    def ai_reply(reported_total: int) -> AIMessage:
        # Build a fresh AI message each time so no metadata dict is shared.
        return AIMessage(
            "text",
            response_metadata={"model_provider": "openai"},
            usage_metadata={
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": reported_total,
            },
        )

    messages = [
        HumanMessage("text"),
        ai_reply(100),
        HumanMessage("text"),
        ai_reply(200),
    ]

    plain_count = count_tokens_approximately(messages)
    scaled_count = count_tokens_approximately(
        messages, use_usage_metadata_scaling=True
    )

    assert scaled_count == 200
    assert plain_count < 100

    # Append the same tool message three times after the last AI message.
    tool_reply = ToolMessage("text", tool_call_id="abc123")
    messages.extend([tool_reply, tool_reply, tool_reply])

    plain_count_after = count_tokens_approximately(messages)
    scaled_count_after = count_tokens_approximately(
        messages, use_usage_metadata_scaling=True
    )

    # Both counts must grow; the scaling reference is still the most recent
    # AIMessage, i.e. the one reporting total_tokens=200.
    assert plain_count_after > plain_count
    assert scaled_count_after > scaled_count

    # The ratio is the reported total divided by the approximate count as of
    # the last AI message; the scaled total is the new approximate count
    # multiplied by that ratio, rounded up.
    ratio = 200 / plain_count
    assert scaled_count_after == math.ceil(plain_count_after * ratio)
| 1635 | + |
| 1636 | + |
def test_count_tokens_approximately_usage_metadata_scaling_model_provider() -> None:
    """Test that scaling changes nothing when AI messages differ in provider.

    The first AI message is tagged `openai` and the second `anthropic`; the
    scaled count is expected to equal the unscaled count.
    """
    first_reply = AIMessage(
        "Hi",
        response_metadata={"model_provider": "openai"},
        usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
    )
    second_reply = AIMessage(
        "More response",
        response_metadata={"model_provider": "anthropic"},
        usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
    )
    messages = [
        HumanMessage("Hello"),
        first_reply,
        HumanMessage("More text"),
        second_reply,
    ]

    baseline = count_tokens_approximately(messages)
    with_scaling = count_tokens_approximately(
        messages, use_usage_metadata_scaling=True
    )

    assert with_scaling == baseline
| 1656 | + |
| 1657 | + |
def test_count_tokens_approximately_usage_metadata_scaling_total_tokens() -> None:
    """Test that scaling changes nothing when no usage metadata is present.

    The only AI message carries a `model_provider` but no `usage_metadata`, so
    the scaled count is expected to equal the unscaled count.
    """
    # Deliberately built without usage metadata.
    reply_without_usage = AIMessage(
        "Hi",
        response_metadata={"model_provider": "openai"},
    )
    messages = [HumanMessage("Hello"), reply_without_usage]

    baseline = count_tokens_approximately(messages, chars_per_token=5)
    with_scaling = count_tokens_approximately(
        messages, chars_per_token=5, use_usage_metadata_scaling=True
    )

    assert with_scaling == baseline
| 1674 | + |
| 1675 | + |
def test_count_tokens_approximately_usage_metadata_scaling_floor_at_one() -> None:
    """Test that scaling never lowers the approximate token count.

    The AI message reports `total_tokens=1`, which is lower than the
    approximate count up through that message, so the scaled count is expected
    to equal the unscaled count.
    """
    underreporting_reply = AIMessage(
        "text",
        response_metadata={"model_provider": "openai"},
        # Reported total is below the approximate count up to this message.
        usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 1},
    )
    messages = [
        HumanMessage("text"),
        underreporting_reply,
        HumanMessage("text"),
    ]

    baseline = count_tokens_approximately(messages)
    with_scaling = count_tokens_approximately(
        messages, use_usage_metadata_scaling=True
    )

    # A scale factor below 1 would shrink the count; it is floored at 1.0, so
    # the result must match the unscaled count.
    assert with_scaling == baseline
| 1693 | + |
| 1694 | + |
1597 | 1695 | def test_get_buffer_string_with_structured_content() -> None: |
1598 | 1696 | """Test get_buffer_string with structured content in messages.""" |
1599 | 1697 | messages = [ |
|
0 commit comments