Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ will return much faster than the first query and we'll be certain the authors ma
| `batch_size` | `1` | Batch size for calling LLMs. |
| `texts_index_mmr_lambda` | `1.0` | Lambda for MMR in text index. |
| `verbosity` | `0` | Integer verbosity level for logging (0-3). 3 = all LLM/Embeddings calls logged. |
| `custom_context_serializer` | `None` | Custom function (see typing for signature) to override the default answer context serialization. |
| `answer.evidence_k` | `10` | Number of evidence pieces to retrieve. |
| `answer.evidence_detailed_citations` | `True` | Include detailed citations in summaries. |
| `answer.evidence_retrieval` | `True` | Use retrieval vs processing all docs. |
Expand Down
77 changes: 7 additions & 70 deletions src/paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import tempfile
import urllib.request
import warnings
from collections import defaultdict
from collections.abc import Callable, Sequence
from datetime import datetime
from io import BytesIO
Expand All @@ -30,7 +29,7 @@
from paperqa.prompts import CANNOT_ANSWER_PHRASE
from paperqa.readers import read_doc
from paperqa.settings import MaybeSettings, get_settings
from paperqa.types import Context, Doc, DocDetails, DocKey, PQASession, Text
from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
from paperqa.utils import (
citation_to_docname,
get_loop,
Expand Down Expand Up @@ -710,7 +709,7 @@ def query(
)
)

async def aquery( # noqa: PLR0912
async def aquery(
self,
query: PQASession | str,
settings: MaybeSettings = None,
Expand Down Expand Up @@ -765,75 +764,13 @@ async def aquery( # noqa: PLR0912
session.add_tokens(pre)
pre_str = pre.text

# sort by first score, then name
filtered_contexts = sorted(
contexts,
key=lambda x: (-x.score, x.text.name),
)[: answer_config.answer_max_sources]
# remove any contexts with a score below the cutoff
filtered_contexts = [
c
for c in filtered_contexts
if c.score >= answer_config.evidence_relevance_score_cutoff
]

# shim deprecated flag
# TODO: remove in v6
context_inner_prompt = prompt_config.context_inner
if (
not answer_config.evidence_detailed_citations
and "\nFrom {citation}" in context_inner_prompt
):
# Only keep "\nFrom {citation}" if we are showing detailed citations
context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")

context_str_body = ""
if answer_config.group_contexts_by_question:
contexts_by_question: dict[str, list[Context]] = defaultdict(list)
for c in filtered_contexts:
# Fallback to the main session question if not available.
# question attribute is optional, so if a user
# sets contexts externally, it may not have a question.
question = getattr(c, "question", session.question)
contexts_by_question[question].append(c)

context_sections = []
for question, contexts_in_group in contexts_by_question.items():
inner_strs = [
context_inner_prompt.format(
name=c.id,
text=c.context,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in contexts_in_group
]
# Create a section with a question heading
section_header = f'Contexts related to the question: "{question}"'
section = f"{section_header}\n\n" + "\n\n".join(inner_strs)
context_sections.append(section)
context_str_body = "\n\n----\n\n".join(context_sections)
else:
inner_context_strs = [
context_inner_prompt.format(
name=c.id,
text=c.context,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in filtered_contexts
]
context_str_body = "\n\n".join(inner_context_strs)

if pre_str:
context_str_body += f"\n\nExtra background information: {pre_str}"

context_str = prompt_config.context_outer.format(
context_str=context_str_body,
valid_keys=", ".join([c.id for c in filtered_contexts]),
context_str = query_settings.context_serializer(
contexts=contexts,
question=session.question,
pre_str=pre_str,
)

if len(context_str_body.strip()) < 10: # noqa: PLR2004
if len(context_str.strip()) < 10: # noqa: PLR2004
answer_text = (
f"{CANNOT_ANSWER_PHRASE} this question due to insufficient information."
)
Expand Down
115 changes: 113 additions & 2 deletions src/paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
import os
import pathlib
import warnings
from collections import defaultdict
from collections.abc import Callable, Mapping, Sequence
from enum import StrEnum
from pydoc import locate
from typing import Any, ClassVar, Self, TypeAlias, assert_never, cast
from typing import (
Any,
ClassVar,
Protocol,
Self,
TypeAlias,
assert_never,
cast,
runtime_checkable,
)

import anyio
from aviary.core import Tool, ToolSelector
Expand Down Expand Up @@ -55,6 +65,7 @@
summary_prompt,
)
from paperqa.readers import PDFParserFn
from paperqa.types import Context
from paperqa.utils import hexdigest, pqa_directory
from paperqa.version import __version__

Expand All @@ -63,8 +74,21 @@
_EnvironmentState: TypeAlias = Any


@runtime_checkable
class ContextSerializer(Protocol):
    """Protocol for generating a context string from settings and context.

    Implementations are invoked with keyword arguments by
    ``Settings.context_serializer`` when ``Settings.custom_context_serializer``
    is set, receiving the active settings, the candidate contexts, the session
    question, and any pre-query background text. The returned string becomes
    the serialized answer context.
    """

    def __call__(
        self,
        settings: "Settings",
        contexts: list[Context],
        question: str,
        pre_str: str | None,
    ) -> str: ...


class AnswerSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

evidence_k: int = Field(
default=10, description="Number of evidence pieces to retrieve."
Expand Down Expand Up @@ -791,6 +815,11 @@ class Settings(BaseSettings):
exclude=True,
frozen=True,
)
custom_context_serializer: ContextSerializer | None = Field(
default=None,
description="Function to turn settings and contexts into an answer context str.",
exclude=True,
)

@model_validator(mode="after")
def _deprecated_field(self) -> Self:
Expand Down Expand Up @@ -1026,6 +1055,88 @@ def adjust_tools_for_agent_llm(self, tools: list[Tool]) -> None:
# Gemini fixed this server-side by mid-April 2025,
# so this method is now just available for use

def context_serializer(
self, contexts: list[Context], question: str, pre_str: str | None
) -> str:
"""Default function for sorting ranked contexts and inserting into a context string."""
if self.custom_context_serializer:
return self.custom_context_serializer(
settings=self, contexts=contexts, question=question, pre_str=pre_str
)

answer_config = self.answer
prompt_config = self.prompts

# sort by first score, then name
filtered_contexts = sorted(
contexts,
key=lambda x: (-x.score, x.text.name),
)[: answer_config.answer_max_sources]
# remove any contexts with a score below the cutoff
filtered_contexts = [
c
for c in filtered_contexts
if c.score >= answer_config.evidence_relevance_score_cutoff
]

# shim deprecated flag
# TODO: remove in v6
context_inner_prompt = prompt_config.context_inner
if (
not answer_config.evidence_detailed_citations
and "\nFrom {citation}" in context_inner_prompt
):
# Only keep "\nFrom {citation}" if we are showing detailed citations
context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")

context_str_body = ""
if answer_config.group_contexts_by_question:
contexts_by_question: dict[str, list[Context]] = defaultdict(list)
for c in filtered_contexts:
# Fallback to the main session question if not available.
# question attribute is optional, so if a user
# sets contexts externally, it may not have a question.
context_question = getattr(c, "question", question)
contexts_by_question[context_question].append(c)

context_sections = []
for context_question, contexts_in_group in contexts_by_question.items():
inner_strs = [
context_inner_prompt.format(
name=c.id,
text=c.context,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in contexts_in_group
]
# Create a section with a question heading
section_header = (
f'Contexts related to the question: "{context_question}"'
)
section = f"{section_header}\n\n" + "\n\n".join(inner_strs)
context_sections.append(section)
context_str_body = "\n\n----\n\n".join(context_sections)
else:
inner_context_strs = [
context_inner_prompt.format(
name=c.id,
text=c.context,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in filtered_contexts
]
context_str_body = "\n\n".join(inner_context_strs)

if pre_str:
context_str_body += f"\n\nExtra background information: {pre_str}"

return prompt_config.context_outer.format(
context_str=context_str_body,
valid_keys=", ".join([c.id for c in filtered_contexts]),
)


# Settings: already Settings
# dict[str, Any]: serialized Settings
Expand Down
29 changes: 29 additions & 0 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from paperqa.prompts import CANNOT_ANSWER_PHRASE
from paperqa.prompts import qa_prompt as default_qa_prompt
from paperqa.readers import PDFParserFn, read_doc
from paperqa.settings import ContextSerializer
from paperqa.types import ChunkMetadata, Context
from paperqa.utils import (
clean_possessives,
Expand Down Expand Up @@ -566,6 +567,34 @@ async def test_query(docs_fixture) -> None:
await docs_fixture.aquery("Is XAI usable in chemistry?", settings=settings)


@pytest.mark.asyncio
async def test_custom_context_str_fn(docs_fixture) -> None:
    """Verify custom_context_serializer overrides default context serialization."""

    def custom_context_str_fn(
        settings: Settings,  # noqa: ARG001
        contexts: list[Context],  # noqa: ARG001
        question: str,  # noqa: ARG001
        pre_str: str | None = None,  # noqa: ARG001
    ) -> str:
        # Sentinel return value lets us assert the override was actually used.
        return "TEST OVERRIDE"

    # The plain function should structurally satisfy the runtime-checkable
    # ContextSerializer protocol.
    assert isinstance(custom_context_str_fn, ContextSerializer)

    settings = Settings(
        custom_context_serializer=custom_context_str_fn,
        prompts={"answer_iteration_prompt": None},
    )

    session = await docs_fixture.aquery(
        "Is XAI usable in chemistry?", settings=settings
    )
    # The session's serialized context must be the sentinel, proving the
    # custom serializer was called instead of the default implementation.
    assert (
        session.context == "TEST OVERRIDE"
    ), "Expected custom context string to be returned."

    # Evidence gathering still runs even though serialization was overridden.
    assert session.contexts, "Expected contexts to be present in session."


@pytest.mark.asyncio
async def test_aquery_groups_contexts_by_question(docs_fixture) -> None:

Expand Down
Loading