@@ -241,20 +241,35 @@ def clear(self) -> None:
241241
242242
class ConversationEntityMemory(BaseChatMemory):
    """Entity extractor & summarizer memory.

    Extracts named entities from the recent chat history and generates
    summaries. Entities are persisted across conversations via a swappable
    entity store. Defaults to an in-memory entity store, which can be swapped
    out for a Redis, SQLite, or other entity store.
    """

    # Prefixes used when rendering chat messages into a history string:
    human_prefix: str = "Human"
    ai_prefix: str = "AI"

    # Language model used for both entity extraction and summarization:
    llm: BaseLanguageModel

    # Prompt asking the LLM to list named entities in the recent history:
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    # Prompt asking the LLM to update an entity's running summary:
    entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT

    # Cache of recently detected entity names, if any
    # It is updated when load_memory_variables is called:
    entity_cache: List[str] = []

    # Number of recent message pairs to consider when updating entities:
    k: int = 3

    # Key under which the rendered chat history is returned:
    chat_history_key: str = "history"

    # Store to manage entity-related data:
    entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
256270 @property
257271 def buffer (self ) -> List [BaseMessage ]:
272+ """Access chat memory messages."""
258273 return self .chat_memory .messages
259274
260275 @property
@@ -266,63 +281,110 @@ def memory_variables(self) -> List[str]:
266281 return ["entities" , self .chat_history_key ]
267282
268283 def load_memory_variables (self , inputs : Dict [str , Any ]) -> Dict [str , Any ]:
269- """Return history buffer."""
284+ """
285+ Returns chat history and all generated entities with summaries if available,
286+ and updates or clears the recent entity cache.
287+
288+ New entity name can be found when calling this method, before the entity
289+ summaries are generated, so the entity cache values may be empty if no entity
290+ descriptions are generated yet.
291+ """
292+
293+ # Create an LLMChain for predicting entity names from the recent chat history:
270294 chain = LLMChain (llm = self .llm , prompt = self .entity_extraction_prompt )
295+
271296 if self .input_key is None :
272297 prompt_input_key = get_prompt_input_key (inputs , self .memory_variables )
273298 else :
274299 prompt_input_key = self .input_key
300+
301+ # Extract an arbitrary window of the last message pairs from
302+ # the chat history, where the hyperparameter k is the
303+ # number of message pairs:
275304 buffer_string = get_buffer_string (
276305 self .buffer [- self .k * 2 :],
277306 human_prefix = self .human_prefix ,
278307 ai_prefix = self .ai_prefix ,
279308 )
309+
310+ # Generates a comma-separated list of named entities,
311+ # e.g. "Jane, White House, UFO"
312+ # or "NONE" if no named entities are extracted:
280313 output = chain .predict (
281314 history = buffer_string ,
282315 input = inputs [prompt_input_key ],
283316 )
317+
318+ # If no named entities are extracted, assigns an empty list.
284319 if output .strip () == "NONE" :
285320 entities = []
286321 else :
322+ # Make a list of the extracted entities:
287323 entities = [w .strip () for w in output .split ("," )]
324+
325+ # Make a dictionary of entities with summary if exists:
288326 entity_summaries = {}
327+
289328 for entity in entities :
290329 entity_summaries [entity ] = self .entity_store .get (entity , "" )
330+
331+ # Replaces the entity name cache with the most recently discussed entities,
332+ # or if no entities were extracted, clears the cache:
291333 self .entity_cache = entities
334+
335+ # Should we return as message objects or as a string?
292336 if self .return_messages :
337+ # Get last `k` pair of chat messages:
293338 buffer : Any = self .buffer [- self .k * 2 :]
294339 else :
340+ # Reuse the string we made earlier:
295341 buffer = buffer_string
342+
296343 return {
297344 self .chat_history_key : buffer ,
298345 "entities" : entity_summaries ,
299346 }
300347
301348 def save_context (self , inputs : Dict [str , Any ], outputs : Dict [str , str ]) -> None :
302- """Save context from this conversation to buffer."""
349+ """
350+ Save context from this conversation history to the entity store.
351+
352+ Generates a summary for each entity in the entity cache by prompting
353+ the model, and saves these summaries to the entity store.
354+ """
355+
303356 super ().save_context (inputs , outputs )
304357
305358 if self .input_key is None :
306359 prompt_input_key = get_prompt_input_key (inputs , self .memory_variables )
307360 else :
308361 prompt_input_key = self .input_key
309362
363+ # Extract an arbitrary window of the last message pairs from
364+ # the chat history, where the hyperparameter k is the
365+ # number of message pairs:
310366 buffer_string = get_buffer_string (
311367 self .buffer [- self .k * 2 :],
312368 human_prefix = self .human_prefix ,
313369 ai_prefix = self .ai_prefix ,
314370 )
371+
315372 input_data = inputs [prompt_input_key ]
373+
374+ # Create an LLMChain for predicting entity summarization from the context
316375 chain = LLMChain (llm = self .llm , prompt = self .entity_summarization_prompt )
317376
377+ # Generate new summaries for entities and save them in the entity store
318378 for entity in self .entity_cache :
379+ # Get existing summary if it exists
319380 existing_summary = self .entity_store .get (entity , "" )
320381 output = chain .predict (
321382 summary = existing_summary ,
322383 entity = entity ,
323384 history = buffer_string ,
324385 input = input_data ,
325386 )
387+ # Save the updated summary to the entity store
326388 self .entity_store .set (entity , output .strip ())
327389
328390 def clear (self ) -> None :
0 commit comments