@@ -688,6 +688,7 @@ def trim_messages(
688688 * ,
689689 max_tokens : int ,
690690 token_counter : Union [
691+ Literal ["approximate" ],
691692 Callable [[list [BaseMessage ]], int ],
692693 Callable [[BaseMessage ], int ],
693694 BaseLanguageModel ,
@@ -738,11 +739,16 @@ def trim_messages(
738739 BaseMessage. If a BaseLanguageModel is passed in then
739740 BaseLanguageModel.get_num_tokens_from_messages() will be used.
740741 Set to `len` to count the number of **messages** in the chat history.
742+ You can also use string shortcuts for convenience:
743+
744+ - ``"approximate"``: Uses `count_tokens_approximately` for fast, approximate
745+ token counts.
741746
742747 .. note::
743- Use `count_tokens_approximately` to get fast, approximate token counts.
744- This is recommended for using `trim_messages` on the hot path, where
745- exact token counting is not necessary.
748+ Use `count_tokens_approximately` (or the shortcut ``"approximate"``) to get
749+ fast, approximate token counts. This is recommended for using
750+ `trim_messages` on the hot path, where exact token counting is not
751+ necessary.
746752
747753 strategy: Strategy for trimming.
748754
@@ -849,6 +855,35 @@ def trim_messages(
849855 HumanMessage(content="what do you call a speechless parrot"),
850856 ]
851857
858+ Trim chat history using approximate token counting with the "approximate" shortcut:
859+
860+ .. code-block:: python
861+
862+ trim_messages(
863+ messages,
864+ max_tokens=45,
865+ strategy="last",
866+ # Using the "approximate" shortcut for fast approximate token counting
867+ token_counter="approximate",
868+ start_on="human",
869+ include_system=True,
870+ )
871+
872+ This is equivalent to using `count_tokens_approximately` directly:
873+
874+ .. code-block:: python
875+
876+ from langchain_core.messages.utils import count_tokens_approximately
877+
878+ trim_messages(
879+ messages,
880+ max_tokens=45,
881+ strategy="last",
882+ token_counter=count_tokens_approximately,
883+ start_on="human",
884+ include_system=True,
885+ )
886+
852887 Trim chat history based on the message count, keeping the SystemMessage if
853888 present, and ensuring that the chat history starts with a HumanMessage (
854889 or a SystemMessage followed by a HumanMessage).
@@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
9771012 raise ValueError (msg )
9781013
9791014 messages = convert_to_messages (messages )
980- if hasattr (token_counter , "get_num_tokens_from_messages" ):
981- list_token_counter = token_counter .get_num_tokens_from_messages
982- elif callable (token_counter ):
1015+
1016+ # Handle string shortcuts for token counter
1017+ if isinstance (token_counter , str ):
1018+ if token_counter in _TOKEN_COUNTER_SHORTCUTS :
1019+ actual_token_counter = _TOKEN_COUNTER_SHORTCUTS [token_counter ]
1020+ else :
1021+ available_shortcuts = ", " .join (
1022+ f"'{ key } '" for key in _TOKEN_COUNTER_SHORTCUTS
1023+ )
1024+ msg = (
1025+ f"Invalid token_counter shortcut '{ token_counter } '. "
1026+ f"Available shortcuts: { available_shortcuts } ."
1027+ )
1028+ raise ValueError (msg )
1029+ else :
1030+ actual_token_counter = token_counter
1031+
1032+ if hasattr (actual_token_counter , "get_num_tokens_from_messages" ):
1033+ list_token_counter = actual_token_counter .get_num_tokens_from_messages # type: ignore[assignment]
1034+ elif callable (actual_token_counter ):
9831035 if (
984- next (iter (inspect .signature (token_counter ).parameters .values ())).annotation
1036+ next (
1037+ iter (inspect .signature (actual_token_counter ).parameters .values ())
1038+ ).annotation
9851039 is BaseMessage
9861040 ):
9871041
9881042 def list_token_counter (messages : Sequence [BaseMessage ]) -> int :
989- return sum (token_counter (msg ) for msg in messages ) # type: ignore[arg-type, misc]
1043+ return sum (actual_token_counter (msg ) for msg in messages ) # type: ignore[arg-type, misc]
9901044
9911045 else :
992- list_token_counter = token_counter
1046+ list_token_counter = actual_token_counter # type: ignore[assignment]
9931047 else :
9941048 msg = (
9951049 f"'token_counter' expected to be a model that implements "
9961050 f"'get_num_tokens_from_messages()' or a function. Received object of type "
997- f"{ type (token_counter )} ."
1051+ f"{ type (actual_token_counter )} ."
9981052 )
9991053 raise ValueError (msg )
10001054
@@ -1754,3 +1808,14 @@ def count_tokens_approximately(
17541808
17551809 # round up once more time in case extra_tokens_per_message is a float
17561810 return math .ceil (token_count )
1811+
1812+
1813+ # Mapping from string shortcuts to token counter functions
1814+ def _approximate_token_counter (messages : Sequence [BaseMessage ]) -> int :
1815+ """Wrapper for count_tokens_approximately that matches expected signature."""
1816+ return count_tokens_approximately (messages )
1817+
1818+
1819+ _TOKEN_COUNTER_SHORTCUTS = {
1820+ "approximate" : _approximate_token_counter ,
1821+ }
0 commit comments