|
6 | 6 |
|
7 | 7 | from functools import partial |
8 | 8 |
|
9 | | -from typing import Any, Dict, Mapping, Optional, Union |
| 9 | +from typing import Any, Dict, Optional, Union |
| 10 | + |
| 11 | +from torchtune.data._messages import AlpacaToMessages |
10 | 12 |
|
11 | | -from torchtune.data._messages import Message |
12 | 13 | from torchtune.datasets._packed import PackedDataset |
13 | 14 | from torchtune.datasets._sft import SFTDataset |
14 | 15 | from torchtune.modules.tokenizers import ModelTokenizer |
15 | | -from torchtune.modules.transforms import Transform |
16 | | - |
17 | | - |
18 | | -class AlpacaToMessages(Transform): |
19 | | - """ |
20 | | - Message transform class for Alpaca-style datasets with "instruction", "input", and "output" |
21 | | - (or equivalent fields specified in column_map) columns. User messages are formed from the |
22 | | - instruction + input columns and assistant messages are formed from the output column. Prompt |
23 | | - templating is conditional on the presence of the "input" column, and thus is handled directly |
24 | | - in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class |
25 | | - due to this custom logic. |
26 | | -
|
27 | | - Args: |
28 | | - train_on_input (bool): Whether the model is trained on the user prompt or not. |
29 | | - Default is True. |
30 | | - column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input", |
31 | | - and "output" column names to the actual column names in the dataset. Default is None, |
32 | | - keeping the default column names. |
33 | | - """ |
34 | | - |
35 | | - def __init__( |
36 | | - self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None |
37 | | - ): |
38 | | - self.train_on_input = train_on_input |
39 | | - self.column_map = column_map |
40 | | - self.template = { |
41 | | - "prompt_input": ( |
42 | | - "Below is an instruction that describes a task, paired with an input that provides further context. " |
43 | | - "Write a response that appropriately completes the request.\n\n" |
44 | | - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" |
45 | | - ), |
46 | | - "prompt_no_input": ( |
47 | | - "Below is an instruction that describes a task. " |
48 | | - "Write a response that appropriately completes the request.\n\n" |
49 | | - "### Instruction:\n{instruction}\n\n### Response:\n" |
50 | | - ), |
51 | | - } |
52 | | - |
53 | | - def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: |
54 | | - column_map = self.column_map or {} |
55 | | - key_input = column_map.get("input", "input") |
56 | | - key_instruction = column_map.get("instruction", "instruction") |
57 | | - key_output = column_map.get("output", "output") |
58 | | - |
59 | | - if key_input in sample and sample[key_input]: |
60 | | - prompt = self.template["prompt_input"].format( |
61 | | - instruction=sample[key_instruction], input=sample[key_input] |
62 | | - ) |
63 | | - else: |
64 | | - prompt = self.template["prompt_no_input"].format( |
65 | | - instruction=sample[key_instruction] |
66 | | - ) |
67 | | - |
68 | | - messages = [ |
69 | | - Message( |
70 | | - role="user", |
71 | | - content=prompt, |
72 | | - masked=not self.train_on_input, |
73 | | - eot=True, |
74 | | - ), |
75 | | - Message( |
76 | | - role="assistant", |
77 | | - content=sample[key_output], |
78 | | - masked=False, |
79 | | - eot=True, |
80 | | - ), |
81 | | - ] |
82 | | - return {"messages": messages} |
83 | 16 |
|
84 | 17 |
|
85 | 18 | def alpaca_dataset( |
|
0 commit comments