Skip to content

Commit a39959e

Browse files
committed
Fixing gen_answer failover leaving raw_answer blank (#1077)
1 parent d1fc92b commit a39959e

File tree

4 files changed

+19
-18
lines changed

4 files changed

+19
-18
lines changed

src/paperqa/agents/tools.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,6 @@ async def gen_answer(self, state: EnvironmentState) -> str:
309309
Args:
310310
state: Current state.
311311
"""
312-
if not state.docs.docs:
313-
raise EmptyDocsError("Not generating an answer due to having no papers.")
314-
315312
logger.info(f"Generating answer for '{state.session.question}'.")
316313

317314
if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks:

src/paperqa/docs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
NumpyVectorStore,
2727
VectorStore,
2828
)
29-
from paperqa.prompts import CANNOT_ANSWER_PHRASE
29+
from paperqa.prompts import CANNOT_ANSWER_PHRASE, EMPTY_CONTEXTS
3030
from paperqa.readers import read_doc
3131
from paperqa.settings import MaybeSettings, get_settings
3232
from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
@@ -738,7 +738,7 @@ async def aquery(
738738
contexts = session.contexts
739739
if answer_config.get_evidence_if_no_contexts and not contexts:
740740
session = await self.aget_evidence(
741-
query=session,
741+
session,
742742
callbacks=callbacks,
743743
settings=settings,
744744
embedding_model=embedding_model,
@@ -770,9 +770,10 @@ async def aquery(
770770
pre_str=pre_str,
771771
)
772772

773-
if len(context_str.strip()) < 10: # noqa: PLR2004
773+
if len(context_str.strip()) <= EMPTY_CONTEXTS:
774774
answer_text = (
775-
f"{CANNOT_ANSWER_PHRASE} this question due to insufficient information."
775+
f"{CANNOT_ANSWER_PHRASE} this question due to"
776+
f" {'having no papers' if not self.docs else 'insufficient information.'}."
776777
)
777778
answer_reasoning = None
778779
else:

src/paperqa/prompts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,6 @@
136136
)
137137

138138
CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"
139+
EMPTY_CONTEXTS = len(CONTEXT_OUTER_PROMPT.format(context_str="", valid_keys="").strip())
139140
CONTEXT_INNER_PROMPT_NOT_DETAILED = "{name}: {text}"
140141
CONTEXT_INNER_PROMPT = f"{CONTEXT_INNER_PROMPT_NOT_DETAILED}\nFrom {{citation}}"

tests/test_agents.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import wraps
1313
from pathlib import Path
1414
from typing import cast
15-
from unittest.mock import AsyncMock, MagicMock, patch
15+
from unittest.mock import AsyncMock, patch
1616
from uuid import uuid4
1717

1818
import ldp.agent
@@ -21,6 +21,7 @@
2121
Environment,
2222
Tool,
2323
ToolRequestMessage,
24+
ToolResponseMessage,
2425
ToolsAdapter,
2526
ToolSelector,
2627
)
@@ -469,26 +470,27 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) ->
469470
agent_test_settings.agent.timeout = 0.05 # Give time for Environment.reset()
470471
agent_test_settings.llm = "gpt-4o-mini"
471472
agent_test_settings.agent.tool_names = {"gen_answer", "complete"}
472-
docs = Docs()
473+
orig_exec_tool_calls = PaperQAEnvironment.exec_tool_calls
474+
tool_responses: list[list[ToolResponseMessage]] = []
473475

474-
async def custom_aget_evidence(*_, **kwargs) -> PQASession: # noqa: RUF029
475-
return kwargs["query"]
476+
async def spy_exec_tool_calls(*args, **kwargs) -> list[ToolResponseMessage]:
477+
responses = await orig_exec_tool_calls(*args, **kwargs)
478+
tool_responses.append(responses)
479+
return responses
476480

477-
with (
478-
patch.object(docs, "docs", {"stub_key": MagicMock(spec_set=Doc)}),
479-
patch.multiple(
480-
Docs, clear_docs=MagicMock(), aget_evidence=custom_aget_evidence
481-
),
482-
):
481+
with patch.object(PaperQAEnvironment, "exec_tool_calls", spy_exec_tool_calls):
483482
response = await agent_query(
484483
query="Are COVID-19 vaccines effective?",
485484
settings=agent_test_settings,
486-
docs=docs,
487485
agent_type=agent_type,
488486
)
489487
# Ensure that GenerateAnswerTool was called in truncation's failover
490488
assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout"
491489
assert CANNOT_ANSWER_PHRASE in response.session.answer
490+
(last_response,) = tool_responses[-1]
491+
assert (
492+
"no papers" in last_response.content
493+
), "Expecting agent to been shown specifics on the failure"
492494

493495

494496
@pytest.mark.flaky(reruns=5, only_rerun=["AssertionError"])

0 commit comments

Comments
 (0)