Commit 73647e2

Configure max_seq_len in InstructDataset (#620)
1 parent 6d9368f · commit 73647e2

6 files changed: +82 -12 lines changed

tests/test_utils.py

Lines changed: 11 additions & 2 deletions

@@ -33,14 +33,23 @@
 
 
 class DummyTokenizer:
-    def encode(self, text, **kwargs):
+    def encode(self, text, add_bos=True, add_eos=True, **kwargs):
         words = text.split()
-        return [len(word) for word in words]
+        tokens = [len(word) for word in words]
+        if add_bos:
+            tokens = [self.bos_id] + tokens
+        if add_eos:
+            tokens = tokens + [self.eos_id]
+        return tokens
 
     @property
     def eos_id(self):
         return -1
 
+    @property
+    def bos_id(self):
+        return 0
+
 
 def get_assets_path():
     return Path(__file__).parent / "assets"
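
To make the new behavior concrete, here is what the updated `DummyTokenizer.encode` returns (an illustrative snippet using the class above, not part of the diff):

```python
tokenizer = DummyTokenizer()  # as defined above: bos_id == 0, eos_id == -1

# Each "token" is the length of a whitespace-split word,
# now wrapped in BOS/EOS ids by default:
print(tokenizer.encode("hello world"))                                # [0, 5, 5, -1]
print(tokenizer.encode("hello world", add_bos=False, add_eos=False))  # [5, 5]
```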

tests/torchtune/data/test_data_utils.py

Lines changed: 4 additions & 1 deletion

@@ -13,8 +13,9 @@ def test_tokenize_prompt_and_response():
     tokenizer = DummyTokenizer()
     prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
     response = "I always know what I'm doing, do you?"
-    prompt_length = 11
+    prompt_length = 12
     expected_tokenized_prompt = [
+        0,
         12,
         4,
         2,
@@ -34,6 +35,7 @@ def test_tokenize_prompt_and_response():
         6,
         2,
         4,
+        -1,
     ]
     expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
         1,
@@ -44,6 +46,7 @@ def test_tokenize_prompt_and_response():
         6,
         2,
         4,
+        -1,
     ]
 
     tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
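
The updated expectations follow directly from the tokenizer change: the prompt splits into 11 words (11 length-tokens), the prepended BOS id brings `prompt_length` to 12, and the appended EOS id (-1) now closes both the prompt and the label lists. A quick check (illustrative, not part of the diff):

```python
prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
assert len(prompt.split()) == 11  # 11 word tokens + 1 BOS -> prompt_length of 12
```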

tests/torchtune/datasets/test_chat_dataset.py

Lines changed: 7 additions & 2 deletions

@@ -126,6 +126,7 @@ def test_get_item(self, mock_load_dataset, template, dialogue):
         mock_load_dataset.return_value = dialogue
         expected_tokenized_prompts = [
             [
+                0,
                 7,
                 3,
                 3,
@@ -146,15 +147,18 @@ def test_get_item(self, mock_load_dataset, template, dialogue):
                 4,
                 2,
                 3,
+                -1,
+                0,
                 5,
                 6,
                 11,
                 10,
                 1,
+                6,
                 -1,
             ]
         ]
-        prompt_lengths = (14, 4)
+        prompt_lengths = (15, 5)
         expected_labels = [
             [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
             + [
@@ -164,9 +168,10 @@ def test_get_item(self, mock_load_dataset, template, dialogue):
                 4,
                 2,
                 3,
+                -1,
             ]
             + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
-            + [1, -1]
+            + [1, 6, -1]
         ]
 
         ds = ChatDataset(
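
The same BOS/EOS accounting applies per message here: each of the two user prompts gains a leading 0 (hence `prompt_lengths` moving from (14, 4) to (15, 5)), and each assistant response now ends in the EOS id -1. A small illustrative check of the second turn, assuming `CROSS_ENTROPY_IGNORE_IDX` is PyTorch's default ignore index of -100:

```python
IGNORE = -100  # assumed value of CROSS_ENTROPY_IGNORE_IDX

# Per the diff, the second response's label segment is now [1, 6, -1] (was [1, -1]):
second_turn_labels = [IGNORE] * 5 + [1, 6, -1]
assert len(second_turn_labels) == 5 + 3  # masked prompt + EOS-terminated response
```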

tests/torchtune/datasets/test_instruct_dataset.py

Lines changed: 33 additions & 5 deletions

@@ -32,8 +32,34 @@ class TestInstructDataset:
         "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse: "
     )
     expected_tokenized_prompts = [
-        [12, 4, 2, 3, 2, 12, 10, 6, 4, 2, 3, 2, 6, 10, 9, 1, 5, 4, 4, 3, 6, 2, 4],
-        [12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4],
+        [
+            0,
+            12,
+            4,
+            2,
+            3,
+            2,
+            12,
+            10,
+            6,
+            4,
+            2,
+            3,
+            2,
+            6,
+            10,
+            9,
+            1,
+            5,
+            4,
+            4,
+            3,
+            6,
+            2,
+            4,
+            -1,
+        ],
+        [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1],
     ]
 
     def get_samples(self):
@@ -53,10 +79,12 @@ def get_samples(self):
     @mock.patch("torchtune.datasets._instruct.load_dataset")
     def test_get_item_no_train_on_input(self, mock_load_dataset):
         mock_load_dataset.return_value = self.get_samples()
-        prompt_lengths = (15, 13)
+        prompt_lengths = (16, 14)
         expected_labels = [
-            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [1, 5, 4, 4, 3, 6, 2, 4],
-            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + [1, 6, 4, 4, 3, 6, 2, 4],
+            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
+            + [1, 5, 4, 4, 3, 6, 2, 4, -1],
+            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
+            + [1, 6, 4, 4, 3, 6, 2, 4, -1],
        ]
 
         dataset = InstructDataset(
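
The arithmetic behind the first expected prompt is BOS + 15 word tokens for the templated prompt, then the 8-token response, then EOS; a standalone check reusing the values from the diff:

```python
expected = [0, 12, 4, 2, 3, 2, 12, 10, 6, 4, 2, 3, 2, 6, 10, 9,  # prompt: BOS + 15 word tokens
            1, 5, 4, 4, 3, 6, 2, 4,                              # response word tokens
            -1]                                                  # EOS
assert len(expected) == 16 + 8 + 1  # prompt_lengths[0] == 16 after the BOS change
```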

torchtune/datasets/_alpaca.py

Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,7 @@ def alpaca_dataset(
     tokenizer: Tokenizer,
     train_on_input: bool = True,
     use_clean: bool = False,
+    max_seq_len: int = 512,
 ) -> InstructDataset:
     """
     Support for the Alpaca dataset and its variants from Hugging Face Datasets.
@@ -39,6 +40,10 @@ def alpaca_dataset(
         tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
         use_clean (bool): Whether to use the cleaned version of the dataset or not. Default is False.
+        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
+            Default is 512, as set by Stanford Alpaca (https://github.com/tatsu-lab/stanford_alpaca?tab=readme-ov-file#fine-tuning),
+            but we recommend setting this to the highest you can fit in memory and is supported by the model.
+            For example, llama2-7B supports up to 4096 for sequence length.
 
     Returns:
         InstructDataset: dataset configured with Alpaca source data and template
@@ -56,5 +61,6 @@ def alpaca_dataset(
         source="yahma/alpaca-cleaned" if use_clean else "tatsu-lab/alpaca",
         template=AlpacaInstructTemplate(),
         train_on_input=train_on_input,
+        max_seq_len=max_seq_len,
         split="train",
     )
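
A usage sketch for the new parameter (the tokenizer construction is assumed to happen elsewhere; only `max_seq_len` is new in this commit):

```python
from torchtune.datasets import alpaca_dataset

# `tokenizer` is any torchtune Tokenizer built elsewhere (e.g. a llama2 tokenizer).
ds = alpaca_dataset(
    tokenizer=tokenizer,
    train_on_input=True,
    use_clean=False,
    max_seq_len=4096,  # raise from the 512 default when memory and the model allow
)
input_ids, labels = ds[0]
assert len(input_ids) == len(labels) <= 4096
```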

torchtune/datasets/_instruct.py

Lines changed: 21 additions & 2 deletions

@@ -11,7 +11,7 @@
 
 from torchtune.config._utils import _get_template
 
-from torchtune.data import PromptTemplate, tokenize_prompt_and_response
+from torchtune.data import PromptTemplate, tokenize_prompt_and_response, truncate
 from torchtune.modules import Tokenizer
 
 
@@ -43,6 +43,9 @@ class InstructDataset(Dataset):
         column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template
             to the column/key names in the sample. If None, assume these are identical.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
+            Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
+            and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.
     """
 
@@ -54,6 +57,7 @@ def __init__(
         transform: Optional[Callable] = None,
         column_map: Optional[Dict[str, str]] = None,
         train_on_input: bool = False,
+        max_seq_len: Optional[int] = None,
         **load_dataset_kwargs: Dict[str, Any],
     ) -> None:
         self._tokenizer = tokenizer
@@ -62,6 +66,7 @@ def __init__(
         self._transform = transform
         self._column_map = column_map
         self.train_on_input = train_on_input
+        self.max_seq_len = max_seq_len
 
     def __len__(self):
         return len(self._data)
@@ -80,20 +85,30 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
             else "output"
         )
 
-        return tokenize_prompt_and_response(
+        prompt_tokens, label_tokens = tokenize_prompt_and_response(
             tokenizer=self._tokenizer,
             prompt=prompt,
             response=transformed_sample[key_output],
             train_on_input=self.train_on_input,
         )
 
+        if self.max_seq_len is not None:
+            prompt_tokens, label_tokens = truncate(
+                self._tokenizer, prompt_tokens, label_tokens, self.max_seq_len
+            )
+
+        assert len(prompt_tokens) == len(label_tokens)
+
+        return prompt_tokens, label_tokens
+
 
 def instruct_dataset(
     tokenizer: Tokenizer,
     source: str,
     template: str,
     column_map: Optional[Dict[str, str]] = None,
     train_on_input: bool = False,
+    max_seq_len: Optional[int] = None,
     **load_dataset_kwargs: Dict[str, Any],
 ) -> InstructDataset:
     """
@@ -110,6 +125,9 @@ def instruct_dataset(
         column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template
             to the column/key names in the sample. If None, assume these are identical.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
+            Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
+            and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.
 
     Returns:
@@ -121,5 +139,6 @@ def instruct_dataset(
         template=_get_template(template),
         column_map=column_map,
         train_on_input=train_on_input,
+        max_seq_len=max_seq_len,
         **load_dataset_kwargs,
     )
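
The `truncate` helper is imported from `torchtune.data` but its implementation is not part of this diff. A minimal sketch consistent with the call site and the equal-length assertion above could look like the following; the actual torchtune implementation may differ, particularly in how it re-terminates clipped sequences:

```python
from typing import List, Tuple

CROSS_ENTROPY_IGNORE_IDX = -100  # assumed: PyTorch's default cross-entropy ignore_index

def truncate(
    tokenizer,
    prompt_tokens: List[int],
    label_tokens: List[int],
    max_seq_len: int,
) -> Tuple[List[int], List[int]]:
    # Clip both sequences to the cap, preserving the equal-length invariant
    # asserted in _prepare_sample.
    prompt_tokens = prompt_tokens[:max_seq_len]
    label_tokens = label_tokens[:max_seq_len]
    # If truncation cut off the EOS, re-terminate so the model still sees a
    # properly ended sequence (an assumption about torchtune's intent here).
    if prompt_tokens[-1] != tokenizer.eos_id:
        prompt_tokens[-1] = tokenizer.eos_id
    if label_tokens[-1] != CROSS_ENTROPY_IGNORE_IDX:
        label_tokens[-1] = tokenizer.eos_id
    return prompt_tokens, label_tokens
```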
