-
Notifications
You must be signed in to change notification settings - Fork 23k
Expand file tree
/
Copy pathbase.py
More file actions
472 lines (384 loc) · 15.7 KB
/
base.py
File metadata and controls
472 lines (384 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
"""Base class for prompt templates."""
from __future__ import annotations
import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
)
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self, override
from langchain_core.exceptions import ErrorCode, create_message
from langchain_core.load import dumpd
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.utils.pydantic import create_model_v2
if TYPE_CHECKING:
from langchain_core.documents import Document
# Type variable for the value returned by ``BasePromptTemplate.format`` and
# ``BasePromptTemplate.aformat``; ``BasePromptTemplate`` is generic over it.
FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
    RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
):
    """Base class for all prompt templates, returning a prompt.

    Subclasses must implement ``format_prompt`` and ``format``. The class is a
    runnable: ``invoke``/``ainvoke`` validate the input mapping against
    ``input_variables`` and then delegate to ``format_prompt``/``aformat_prompt``.
    """

    input_variables: list[str]
    """A list of the names of the variables whose values are required as inputs to the
    prompt."""
    optional_variables: list[str] = Field(default=[])
    """optional_variables: A list of the names of the variables for placeholder
    or MessagePlaceholder that are optional. These variables are auto inferred
    from the prompt and user need not provide them."""
    input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True)  # noqa: UP006
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Any] = Field(default_factory=dict)
    """A dictionary of the partial variables the prompt template carries.
    Partial variables populate the template so that you don't need to
    pass them in every time you call the prompt."""
    metadata: Optional[typing.Dict[str, Any]] = None  # noqa: UP006
    """Metadata to be used for tracing."""
    tags: Optional[list[str]] = None
    """Tags to be used for tracing."""

    @model_validator(mode="after")
    def validate_variable_names(self) -> Self:
        """Validate variable names do not include restricted names.

        Returns:
            The validated instance.

        Raises:
            ValueError: If ``stop`` is used as an input or partial variable, or
                if a name appears in both ``input_variables`` and
                ``partial_variables``.
        """
        if "stop" in self.input_variables:
            msg = (
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        if "stop" in self.partial_variables:
            msg = (
                "Cannot have an partial variable named 'stop', as it is used "
                "internally, please rename."
            )
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )

        # A name may be supplied either by the caller or as a partial, not both.
        overall = set(self.input_variables).intersection(self.partial_variables)
        if overall:
            msg = f"Found overlapping input and partial variables: {overall}"
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        return self

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        Returns ["langchain", "schema", "prompt_template"].
        """
        return ["langchain", "schema", "prompt_template"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable.

        Returns True.
        """
        return True

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @cached_property
    def _serialized(self) -> dict[str, Any]:
        """Return ``dumpd(self)``, computed on first access and then cached."""
        return dumpd(self)

    @property
    @override
    def OutputType(self) -> Any:
        """Return the output type of the prompt."""
        return Union[StringPromptValue, ChatPromptValueConcrete]

    @override
    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> type[BaseModel]:
        """Get the input schema for the prompt.

        Every name in ``input_variables`` becomes a required field and every
        name in ``optional_variables`` becomes a field defaulting to ``None``.
        Field types come from ``input_types`` and fall back to ``str``.

        Args:
            config: RunnableConfig, configuration for the prompt.

        Returns:
            Type[BaseModel]: The input schema for the prompt.
        """
        # This is correct, but pydantic typings/mypy don't think so.
        required_input_variables = {
            k: (self.input_types.get(k, str), ...) for k in self.input_variables
        }
        optional_input_variables = {
            k: (self.input_types.get(k, str), None) for k in self.optional_variables
        }
        return create_model_v2(
            "PromptInput",
            field_definitions={**required_input_variables, **optional_input_variables},
        )

    def _validate_input(self, inner_input: Any) -> dict:
        """Normalize and check the input passed to ``invoke``/``ainvoke``.

        A non-dict input is accepted only when the template has exactly one
        input variable, in which case it is wrapped as ``{that_variable: input}``.

        Args:
            inner_input: The raw input given to the runnable.

        Returns:
            The input as a dict containing at least every input variable.

        Raises:
            TypeError: If the input is not a dict and the template does not
                have exactly one input variable.
            KeyError: If any input variable is missing from the input.
        """
        if not isinstance(inner_input, dict):
            if len(self.input_variables) == 1:
                var_name = self.input_variables[0]
                inner_input = {var_name: inner_input}
            else:
                msg = (
                    f"Expected mapping type as input to {self.__class__.__name__}. "
                    f"Received {type(inner_input)}."
                )
                raise TypeError(
                    create_message(
                        message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT
                    )
                )
        missing = set(self.input_variables).difference(inner_input)
        if missing:
            msg = (
                f"Input to {self.__class__.__name__} is missing variables {missing}. "
                f" Expected: {self.input_variables}"
                f" Received: {list(inner_input.keys())}"
            )
            # Use one of the missing names to illustrate brace escaping.
            example_key = missing.pop()
            msg += (
                f"\nNote: if you intended {{{example_key}}} to be part of the string"
                " and not a variable, please escape it with double curly braces like: "
                f"'{{{{{example_key}}}}}'."
            )
            raise KeyError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        return inner_input

    def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
        """Validate ``inner_input`` and pass it to ``format_prompt``."""
        _inner_input = self._validate_input(inner_input)
        return self.format_prompt(**_inner_input)

    async def _aformat_prompt_with_error_handling(
        self, inner_input: dict
    ) -> PromptValue:
        """Validate ``inner_input`` and pass it to ``aformat_prompt``."""
        _inner_input = self._validate_input(inner_input)
        return await self.aformat_prompt(**_inner_input)

    @override
    def invoke(
        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Invoke the prompt.

        Args:
            input: Dict, input to the prompt.
            config: RunnableConfig, configuration for the prompt.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            PromptValue: The output of the prompt.
        """
        config = ensure_config(config)
        # The template's own metadata/tags are layered on top of the run config.
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return self._call_with_config(
            self._format_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
            serialized=self._serialized,
        )

    @override
    async def ainvoke(
        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Async invoke the prompt.

        Args:
            input: Dict, input to the prompt.
            config: RunnableConfig, configuration for the prompt.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            PromptValue: The output of the prompt.
        """
        config = ensure_config(config)
        # Merge exactly as ``invoke`` does: build new containers rather than
        # mutating whatever objects the config currently holds.
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return await self._acall_with_config(
            self._aformat_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
            serialized=self._serialized,
        )

    @abstractmethod
    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Prompt Value.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            PromptValue: The output of the prompt.
        """

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Async create Prompt Value.

        The default implementation calls the synchronous ``format_prompt``.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            PromptValue: The output of the prompt.
        """
        return self.format_prompt(**kwargs)

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template.

        The names given in ``kwargs`` are removed from ``input_variables``
        (the remaining names keep their original relative order) and merged
        into ``partial_variables``, with ``kwargs`` taking precedence.

        Args:
            kwargs: Union[str, Callable[[], str]], partial variables to set.

        Returns:
            BasePromptTemplate: A partial of the prompt template.
        """
        prompt_dict = self.__dict__.copy()
        # dict.fromkeys de-duplicates while keeping declaration order, so the
        # resulting list is deterministic (a set difference would not be).
        prompt_dict["input_variables"] = list(
            dict.fromkeys(
                name for name in self.input_variables if name not in kwargs
            )
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
        """Combine partial variables with caller-supplied ones.

        Callable partial values are invoked (with no arguments) to obtain
        their value. Caller-supplied ``kwargs`` override partials on conflict.
        """
        # Get partial params:
        partial_kwargs = {
            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> FormatOutputType:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:
            .. code-block:: python

                prompt.format(variable1="foo")
        """

    async def aformat(self, **kwargs: Any) -> FormatOutputType:
        """Async format the prompt with the inputs.

        The default implementation calls the synchronous ``format``.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:
            .. code-block:: python

                await prompt.aformat(variable1="foo")
        """
        return self.format(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> dict:
        """Return dictionary representation of prompt.

        The result of ``model_dump`` is extended with a ``"_type"`` entry
        holding ``_prompt_type``. If ``_prompt_type`` raises
        ``NotImplementedError`` the entry is simply omitted.

        Args:
            kwargs: Any additional arguments to pass to the dictionary.

        Returns:
            Dict: Dictionary representation of the prompt.
        """
        prompt_dict = super().model_dump(**kwargs)
        with contextlib.suppress(NotImplementedError):
            prompt_dict["_type"] = self._prompt_type
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt.

        Args:
            file_path: Path of the file to write. The suffix selects the
                format: ``.json``, ``.yaml`` or ``.yml``. Missing parent
                directories are created.

        Raises:
            ValueError: If the prompt has partial variables.
            ValueError: If the file path is not json or yaml.
            NotImplementedError: If the prompt type is not implemented.

        Example:
            .. code-block:: python

                prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            msg = "Cannot save prompt with partial variables."
            raise ValueError(msg)

        # Fetch dictionary to save
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            msg = f"Prompt {self} does not support saving."
            raise NotImplementedError(msg)

        # Convert file to Path object.
        save_path = Path(file_path)
        suffix = save_path.suffix

        # Reject unsupported formats before touching the filesystem so a bad
        # path does not leave freshly created directories behind.
        if suffix not in {".json", ".yaml", ".yml"}:
            msg = f"{save_path} must be json or yaml"
            raise ValueError(msg)

        save_path.parent.mkdir(parents=True, exist_ok=True)

        with save_path.open("w", encoding="utf-8") as f:
            if suffix == ".json":
                json.dump(prompt_dict, f, indent=4)
            else:
                yaml.dump(prompt_dict, f, default_flow_style=False)
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
    """Collect the values from ``doc`` that ``prompt`` needs as inputs.

    The candidate values are the document's ``page_content`` (under the key
    ``"page_content"``) plus every entry of ``doc.metadata``; a metadata entry
    with the same key takes precedence.

    Args:
        doc: Document supplying ``page_content`` and ``metadata``.
        prompt: Prompt template whose ``input_variables`` select the values.

    Returns:
        A dict with exactly one entry per name in ``prompt.input_variables``.

    Raises:
        ValueError: If any input variable is found neither in the metadata nor
            as ``page_content``.
    """
    wanted = prompt.input_variables

    available = {"page_content": doc.page_content}
    available.update(doc.metadata)

    absent = set(wanted).difference(available)
    if absent:
        # Everything other than page_content has to come from the metadata.
        metadata_names = [name for name in wanted if name != "page_content"]
        msg = (
            f"Document prompt requires documents to have metadata variables: "
            f"{metadata_names}. Received document with missing metadata: "
            f"{list(absent)}."
        )
        raise ValueError(
            create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
        )

    return {name: available[name] for name in wanted}
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
    """Render a document as a string using a prompt template.

    The template inputs are gathered from two places on the document:

    1. page_content:
        The value of ``document.page_content`` is bound to the variable
        named ``page_content``.
    2. metadata:
        Each entry of ``document.metadata`` is bound to a variable of the
        same name.

    These variables are then handed to ``prompt`` to build the final string.

    Args:
        doc: Document, whose page_content and metadata provide the values
            for the template.
        prompt: BasePromptTemplate, used to format the page_content and
            metadata into the final string.

    Returns:
        string of the document formatted.

    Example:
        .. code-block:: python

            from langchain_core.documents import Document
            from langchain_core.prompts import PromptTemplate

            doc = Document(page_content="This is a joke", metadata={"page": "1"})
            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
            format_document(doc, prompt)
            >>> "Page 1: This is a joke"
    """
    template_inputs = _get_document_info(doc, prompt)
    return prompt.format(**template_inputs)
async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
    """Asynchronously render a document as a string using a prompt template.

    The template inputs are gathered from two places on the document:

    1. page_content:
        The value of ``document.page_content`` is bound to the variable
        named ``page_content``.
    2. metadata:
        Each entry of ``document.metadata`` is bound to a variable of the
        same name.

    These variables are then handed to ``prompt`` to build the final string.

    Args:
        doc: Document, whose page_content and metadata provide the values
            for the template.
        prompt: BasePromptTemplate, used to format the page_content and
            metadata into the final string.

    Returns:
        string of the document formatted.
    """
    template_inputs = _get_document_info(doc, prompt)
    return await prompt.aformat(**template_inputs)