Skip to content

Commit b0074fb

Browse files
committed
Address PR comments
1 parent d7c0682 commit b0074fb

File tree

3 files changed

+92
-32
lines changed

3 files changed

+92
-32
lines changed

tests/torchtune/datasets/test_slimorca_dataset.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import random
67
from pathlib import Path
78

89
import pytest
@@ -22,7 +23,7 @@ def tokenizer(self):
2223

2324
def test_slim_orca_dataset(self, tokenizer):
2425
dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
25-
assert dataset
26+
assert len(dataset) == 363_491
2627

2728
def test_prompt_label_generation(self, tokenizer):
2829
dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
@@ -40,15 +41,35 @@ def test_prompt_label_generation(self, tokenizer):
4041
"value": "lo",
4142
},
4243
]
43-
prompt, label = dataset.generate_prompt_label(sample)
44-
assert prompt == "[INST] <<SYS>>\nhi\n<</SYS>>\n\nmid [/INST]"
44+
prompt, label = dataset._generate_prompt_label(sample)
45+
assert (
46+
prompt
47+
== f"{datasets.Llama2ChatFormatConstants.B_INST} {datasets.Llama2ChatFormatConstants.B_SYS}hi{datasets.Llama2ChatFormatConstants.E_SYS}mid {datasets.Llama2ChatFormatConstants.E_INST}" # noqa: B950
48+
)
49+
assert label == " lo "
50+
51+
sample = [
52+
{
53+
"from": "human",
54+
"value": "mid",
55+
},
56+
{
57+
"from": "gpt",
58+
"value": "lo",
59+
},
60+
]
61+
prompt, label = dataset._generate_prompt_label(sample)
62+
assert (
63+
prompt
64+
== f"{datasets.Llama2ChatFormatConstants.B_INST} mid {datasets.Llama2ChatFormatConstants.E_INST}"
65+
)
4566
assert label == " lo "
4667

4768
def test_token_generation(self, tokenizer):
4869
dataset = datasets.get_dataset(
4970
"slimorca", tokenizer=tokenizer, max_token_length=4096
5071
)
51-
input, label = dataset.generate_tokens("Hello ", "world!")
72+
input, label = dataset._generate_tokens("Hello ", "world!")
5273
assert input == [tokenizer.bos_id, 12, 1803, 1024, 103, tokenizer.eos_id]
5374
assert label == ([-100] * 3 + [1024, 103, tokenizer.eos_id])
5475

@@ -57,7 +78,7 @@ def test_truncated_token_generation(self, tokenizer):
5778
"slimorca", tokenizer=tokenizer, max_token_length=5
5879
)
5980
# 5 is enough for full prompt, but not for label
60-
input, label = dataset.generate_tokens("Hello ", "world!")
81+
input, label = dataset._generate_tokens("Hello ", "world!")
6182
assert input == [tokenizer.bos_id, 12, 1803, 1024, tokenizer.eos_id]
6283
assert label == ([-100] * 3 + [1024, tokenizer.eos_id])
6384

@@ -66,10 +87,32 @@ def test_truncated_token_generation(self, tokenizer):
6687
dataset = datasets.get_dataset(
6788
"slimorca", tokenizer=tokenizer, max_token_length=4
6889
)
69-
input, label = dataset.generate_tokens("Hello ", "world!")
90+
input, label = dataset._generate_tokens("Hello ", "world!")
7091
assert input == [tokenizer.bos_id, 12, 1024, tokenizer.eos_id]
7192
assert label == ([-100] * 2 + [1024, tokenizer.eos_id])
7293

7394
def test_value_error(self, tokenizer):
7495
with pytest.raises(ValueError):
7596
datasets.get_dataset("slimorca", tokenizer=tokenizer, max_token_length=3)
97+
98+
@pytest.mark.parametrize("max_token_length", [128, 512, 1024, 4096])
99+
def test_dataset_get_item(self, tokenizer, max_token_length):
100+
ds = datasets.get_dataset(
101+
"slimorca", tokenizer=tokenizer, max_token_length=max_token_length
102+
)
103+
index = random.randint(0, len(ds))
104+
input, label = ds[index]
105+
assert (
106+
len(input) <= max_token_length
107+
), f"{index} in slimorca fails input token length check"
108+
assert (
109+
len(label) <= max_token_length
110+
), f"{index} in slimorca fails label token length check"
111+
assert len(input) == len(
112+
label
113+
), f"{index} in slimorca fails token lists equality check"
114+
assert input[0] == tokenizer.bos_id, f"{index} in slimorca fails bos check"
115+
assert input[-1] == tokenizer.eos_id, f"{index} in slimorca fails eos check"
116+
assert (
117+
label[-1] == tokenizer.eos_id
118+
), f"{index} in slimorca fails label eos check"

torchtune/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.utils.data import Dataset
88

99
from .alpaca import AlpacaDataset
10-
from .slimorca import SlimOrcaDataset
10+
from .slimorca import Llama2ChatFormatConstants, SlimOrcaDataset # noqa
1111

1212
_DATASET_DICT = {"alpaca": AlpacaDataset, "slimorca": SlimOrcaDataset}
1313

torchtune/datasets/slimorca.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# Not ideal to import this type here but it's needed for the transform function
1313
from torchtune.modules import Tokenizer
1414

15-
_CROSS_ENTROPY_IGNORE_IDX = -100
15+
16+
class Llama2ChatFormatConstants:
17+
CROSS_ENTROPY_IGNORE_IDX = -100
18+
B_INST, E_INST = "[INST]", "[/INST]"
19+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
1620

1721

1822
class SlimOrcaDataset(Dataset):
@@ -22,10 +26,11 @@ class SlimOrcaDataset(Dataset):
2226
2327
The data is formatted to adhere to Llama2 Chat Format.
2428
This format is required if the base model is Llama2 Chat Model.
29+
The base Llama2 Model doesn't prescribe a particular format.
2530
2631
The returned data is a tuple of input token id list and label token id
2732
list. If `max_token_length` keyword argument is provided, the returned
28-
input token id list is ensured (by truncation if necssary) to be within
33+
input token id list is ensured (by truncation if necessary) to be within
2934
that length.
3035
3136
Args:
@@ -35,7 +40,7 @@ class SlimOrcaDataset(Dataset):
3540
max sequence length accepted by the model.
3641
3742
Keyword Arguments:
38-
max_token_length (int): Maximum number of tokens in the returned.
43+
max_token_length (int): Maximum number of tokens in the returned input and label token id lists.
3944
Default is 1024.
4045
4146
Data input format:
@@ -51,15 +56,16 @@ class SlimOrcaDataset(Dataset):
5156
their funeral." } ]
5257
5358
Example:
54-
>>> slimorca_ds = SlimOrcaDataset(tokenizer=tokenizer)
55-
>>> for batch in Dataloader(slimorca_ds, batch_size=8):
56-
print(f"Batch size: {len(batch)}")
57-
Batch size: 8
59+
>>> ds = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=10)
60+
>>> for input, label in ds:
61+
print(input)
62+
print(label)
63+
64+
Sample Output:
65+
[1, 351, 82, 391, 221, 220, 193, 12, 471, ..., 2]
66+
[-100, -100, -100, -100, -100, -100, -100, -100, 471, ..., 2]
5867
"""
5968

60-
B_INST, E_INST = "[INST]", "[/INST]"
61-
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
62-
6369
def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
6470
self._data = load_dataset("Open-Orca/SlimOrca-Dedup", split="train")
6571
self._tokenizer = tokenizer
@@ -72,18 +78,25 @@ def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
7278
def __len__(self):
7379
return len(self._data)
7480

75-
def prompt_with_system(self, content: str) -> str:
76-
return f"{self.B_INST} {self.B_SYS}{content}{self.E_SYS} {self.E_INST}"
77-
78-
def prompt_without_system(self, content: str) -> str:
79-
return f"{self.B_INST} {content} {self.E_INST}"
80-
8181
def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
8282
data = self._data[index]["conversations"]
83-
prompt, label = self.generate_prompt_label(data)
84-
return self.generate_tokens(prompt, label)
83+
prompt, label = self._generate_prompt_label(data)
84+
return self._generate_tokens(prompt, label)
85+
86+
def _generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
87+
"""
88+
Given a prompt string and label string, generate input and label token id lists.
89+
90+
Tokenizer is used to tokenize both the strings.
91+
The prompt token list is truncated to `max_token_length` - 2
92+
(so that there is at least one label token, as EOS takes one token).
93+
94+
The label token list is truncated to `max_token_length` - len(prompt_token_list)
95+
96+
Finally input token list is the concatenation of prompt and label token lists.
8597
86-
def generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
98+
Label token list is padded with cross entropy ignore idx value to match the length of input token list.
99+
"""
87100
prompt_tokens = self._tokenizer.encode(prompt, add_bos=True, add_eos=False)
88101
# Truncate to max token length - 2 (so that there is at least one label token)
89102
prompt_tokens = prompt_tokens[: self._max_token_length - 2]
@@ -99,12 +112,16 @@ def generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]
99112

100113
input = prompt_tokens + label_tokens
101114
label = [
102-
_CROSS_ENTROPY_IGNORE_IDX for _ in range(len(prompt_tokens))
115+
Llama2ChatFormatConstants.CROSS_ENTROPY_IGNORE_IDX
116+
for _ in range(len(prompt_tokens))
103117
] + label_tokens
104-
assert len(input) == len(label)
105118
return input, label
106119

107-
def generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
120+
def _generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
121+
"""
122+
Construct prompt and label strings adhering to Llama2 Chat Format.
123+
This method supports only back-and-forth conversation per sample (as it is sufficient for SlimOrca dataset).
124+
"""
108125
agent_text_dict = {}
109126
# agents can be {system, human, gpt}
110127
for conversation in data:
@@ -113,10 +130,10 @@ def generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
113130
agent_text_dict[agent] = text
114131

115132
# Llama2 Chat Format - https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L284
116-
if len(agent_text_dict["system"]) > 0:
117-
prompt = f"{self.B_INST} {self.B_SYS}{agent_text_dict['system']}{self.E_SYS}{agent_text_dict['human']} {self.E_INST}"
133+
if "system" in agent_text_dict:
134+
prompt = f"{Llama2ChatFormatConstants.B_INST} {Llama2ChatFormatConstants.B_SYS}{agent_text_dict['system']}{Llama2ChatFormatConstants.E_SYS}{agent_text_dict['human']} {Llama2ChatFormatConstants.E_INST}" # noqa: B950
118135
else:
119-
prompt = f"{self.B_INST} {agent_text_dict['human']} {self.E_INST}"
136+
prompt = f"{Llama2ChatFormatConstants.B_INST} {agent_text_dict['human']} {Llama2ChatFormatConstants.E_INST}"
120137

121138
response = f" {agent_text_dict['gpt']} "
122139
return prompt, response

0 commit comments

Comments
 (0)