Skip to content

Commit 364f8e7

Browse files
authored
Better Entity Memory code documentation (#6318)
Just adds some comments and docstring improvements. Some behaviour was quite unclear to me at first, such as: "when do things get updated?", "why are there only entity names and no summaries?", and "why do the entity names disappear?" Now it should be much more obvious to many. I am lukestanley on Twitter.
1 parent af18413 commit 364f8e7

File tree

1 file changed

+65
-3
lines changed

1 file changed

+65
-3
lines changed

langchain/memory/entity.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,35 @@ def clear(self) -> None:
241241

242242

243243
class ConversationEntityMemory(BaseChatMemory):
244-
"""Entity extractor & summarizer to memory."""
244+
"""Entity extractor & summarizer memory.
245+
246+
Extracts named entities from the recent chat history and generates summaries.
247+
With a swappable entity store, persisting entities across conversations.
248+
Defaults to an in-memory entity store, and can be swapped out for a Redis,
249+
SQLite, or other entity store.
250+
"""
245251

246252
human_prefix: str = "Human"
247253
ai_prefix: str = "AI"
248254
llm: BaseLanguageModel
249255
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
250256
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
257+
258+
# Cache of recently detected entity names, if any
259+
# It is updated when load_memory_variables is called:
251260
entity_cache: List[str] = []
261+
262+
# Number of recent message pairs to consider when updating entities:
252263
k: int = 3
264+
253265
chat_history_key: str = "history"
266+
267+
# Store to manage entity-related data:
254268
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
255269

256270
@property
257271
def buffer(self) -> List[BaseMessage]:
272+
"""Access chat memory messages."""
258273
return self.chat_memory.messages
259274

260275
@property
@@ -266,63 +281,110 @@ def memory_variables(self) -> List[str]:
266281
return ["entities", self.chat_history_key]
267282

268283
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
269-
"""Return history buffer."""
284+
"""
285+
Returns chat history and all generated entities with summaries if available,
286+
and updates or clears the recent entity cache.
287+
288+
New entity names can be found when calling this method, before the entity
289+
summaries are generated, so the entity cache values may be empty if no entity
290+
descriptions are generated yet.
291+
"""
292+
293+
# Create an LLMChain for predicting entity names from the recent chat history:
270294
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
295+
271296
if self.input_key is None:
272297
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
273298
else:
274299
prompt_input_key = self.input_key
300+
301+
# Extract an arbitrary window of the last message pairs from
302+
# the chat history, where the hyperparameter k is the
303+
# number of message pairs:
275304
buffer_string = get_buffer_string(
276305
self.buffer[-self.k * 2 :],
277306
human_prefix=self.human_prefix,
278307
ai_prefix=self.ai_prefix,
279308
)
309+
310+
# Generates a comma-separated list of named entities,
311+
# e.g. "Jane, White House, UFO"
312+
# or "NONE" if no named entities are extracted:
280313
output = chain.predict(
281314
history=buffer_string,
282315
input=inputs[prompt_input_key],
283316
)
317+
318+
# If no named entities are extracted, assigns an empty list.
284319
if output.strip() == "NONE":
285320
entities = []
286321
else:
322+
# Make a list of the extracted entities:
287323
entities = [w.strip() for w in output.split(",")]
324+
325+
# Make a dictionary of entities with summary if exists:
288326
entity_summaries = {}
327+
289328
for entity in entities:
290329
entity_summaries[entity] = self.entity_store.get(entity, "")
330+
331+
# Replaces the entity name cache with the most recently discussed entities,
332+
# or if no entities were extracted, clears the cache:
291333
self.entity_cache = entities
334+
335+
# Should we return as message objects or as a string?
292336
if self.return_messages:
337+
# Get the last `k` pairs of chat messages:
293338
buffer: Any = self.buffer[-self.k * 2 :]
294339
else:
340+
# Reuse the string we made earlier:
295341
buffer = buffer_string
342+
296343
return {
297344
self.chat_history_key: buffer,
298345
"entities": entity_summaries,
299346
}
300347

301348
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
302-
"""Save context from this conversation to buffer."""
349+
"""
350+
Save context from this conversation history to the entity store.
351+
352+
Generates a summary for each entity in the entity cache by prompting
353+
the model, and saves these summaries to the entity store.
354+
"""
355+
303356
super().save_context(inputs, outputs)
304357

305358
if self.input_key is None:
306359
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
307360
else:
308361
prompt_input_key = self.input_key
309362

363+
# Extract an arbitrary window of the last message pairs from
364+
# the chat history, where the hyperparameter k is the
365+
# number of message pairs:
310366
buffer_string = get_buffer_string(
311367
self.buffer[-self.k * 2 :],
312368
human_prefix=self.human_prefix,
313369
ai_prefix=self.ai_prefix,
314370
)
371+
315372
input_data = inputs[prompt_input_key]
373+
374+
# Create an LLMChain for predicting entity summarization from the context
316375
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
317376

377+
# Generate new summaries for entities and save them in the entity store
318378
for entity in self.entity_cache:
379+
# Get existing summary if it exists
319380
existing_summary = self.entity_store.get(entity, "")
320381
output = chain.predict(
321382
summary=existing_summary,
322383
entity=entity,
323384
history=buffer_string,
324385
input=input_data,
325386
)
387+
# Save the updated summary to the entity store
326388
self.entity_store.set(entity, output.strip())
327389

328390
def clear(self) -> None:

0 commit comments

Comments
 (0)