Skip to content

Commit 8183b42

Browse files
authored
Refactor datasets and tokenizer (#624)
1 parent 0a82ea4 commit 8183b42

23 files changed

+984
-821
lines changed

tests/test_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020
from torch import nn
21-
21+
from torchtune.modules import Tokenizer
2222

2323
skip_if_cuda_not_available = unittest.skipIf(
2424
not torch.cuda.is_available(), "CUDA is not available"
@@ -31,8 +31,11 @@
3131
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
3232
}
3333

34+
# Inherit from tokenizer class to reuse its tokenize_messages method
35+
class DummyTokenizer(Tokenizer):
36+
def __init__(self):
37+
self.encodes_whitespace = False
3438

35-
class DummyTokenizer:
3639
def encode(self, text, add_bos=True, add_eos=True, **kwargs):
3740
words = text.split()
3841
tokens = [len(word) for word in words]

tests/torchtune/config/test_config_utils.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
import pytest
1010
from torchtune.config._utils import (
1111
_get_component_from_path,
12-
_get_template,
1312
_merge_yaml_and_cli_args,
1413
InstantiationError,
1514
)
16-
from torchtune.data import AlpacaInstructTemplate
1715
from torchtune.utils.argparse import TuneRecipeArgumentParser
1816

1917
_CONFIG = {
@@ -109,33 +107,3 @@ def test_merge_yaml_and_cli_args(self, mock_load):
109107
ValueError, match="Command-line overrides must be in the form of key=value"
110108
):
111109
_ = _merge_yaml_and_cli_args(yaml_args, cli_args)
112-
113-
def test_get_template(self):
114-
# Test valid template class
115-
template = _get_template("AlpacaInstructTemplate")
116-
assert isinstance(template, AlpacaInstructTemplate)
117-
118-
# Test invalid template class
119-
with pytest.raises(
120-
ValueError,
121-
match="Must be a PromptTemplate class or a string with placeholders.",
122-
):
123-
_ = _get_template("InvalidTemplate")
124-
125-
# Test valid template strings
126-
valid_templates = [
127-
"Instruction: {instruction}\nInput: {input}",
128-
"Instruction: {instruction}",
129-
"{a}",
130-
]
131-
for template in valid_templates:
132-
assert _get_template(template) == template
133-
134-
# Test invalid template strings
135-
invalid_templates = ["hello", "{}", "a}{b"]
136-
for template in invalid_templates:
137-
with pytest.raises(
138-
ValueError,
139-
match="Must be a PromptTemplate class or a string with placeholders.",
140-
):
141-
_ = _get_template(template)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import pytest
8+
from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat
9+
10+
# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
11+
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
12+
CHAT_SAMPLE = [
13+
Message(
14+
role="system",
15+
content="You are an AI assistant. User will you give you a task. "
16+
"Your goal is to complete the task as faithfully as you can. "
17+
"While performing the task think step-by-step and justify your steps.",
18+
),
19+
Message(
20+
role="user",
21+
content="Please briefly summarize this news article:\n\nAOL.com Video - "
22+
"Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your "
23+
"8-year-old drive your car? How about on an icy road? Well one father in "
24+
"Russia did just that, and recorded the entire thing. To her credit, the "
25+
"child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , "
26+
"caught on camera , child driver , pix11\n\nSummary:",
27+
),
28+
Message(
29+
role="assistant",
30+
content="A father in Russia allowed his 8-year-old child to drive his car "
31+
"on an icy road and recorded the event. The child appeared to be handling the "
32+
"situation well, showcasing their driving skills despite the challenging conditions.",
33+
),
34+
]
35+
36+
37+
def _assert_dialogue_equal(actual, expected):
    """Assert that two dialogues match message-by-message.

    Args:
        actual: sequence of message-like objects exposing ``role`` and
            ``content`` attributes.
        expected: sequence of message-like objects to compare against.

    Raises:
        AssertionError: if the sequences differ in length, or if any pair of
            messages at the same position differs in ``role`` or ``content``.
    """
    # Check lengths first: zip() would otherwise silently stop at the shorter
    # sequence and hide missing or extra messages.
    assert len(actual) == len(
        expected
    ), f"dialogue length mismatch: {len(actual)} != {len(expected)}"
    for idx, (got, want) in enumerate(zip(actual, expected)):
        assert got.role == want.role, f"role mismatch at message {idx}"
        assert got.content == want.content, f"content mismatch at message {idx}"
42+
43+
44+
class TestLlama2ChatFormat:
    """Compare Llama2ChatFormat.format(CHAT_SAMPLE) with a hand-written dialogue."""

    # The expected user turn carries the system prompt between <<SYS>> tags and
    # is wrapped in [INST] ... [/INST]; the assistant turn is left unchanged.
    expected_dialogue = [
        Message(
            role="user",
            content=(
                "[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
                "Your goal is to complete the task as faithfully as you can. While performing "
                "the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
                "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
                "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
                "How about on an icy road? Well one father in Russia did just that, and recorded "
                "the entire thing. To her credit, the child seemed to be doing a great job. "
                "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
                "Summary: [/INST] "
            ),
        ),
        Message(
            role="assistant",
            content=(
                "A father in Russia allowed his 8-year-old child to drive his car on an "
                "icy road and recorded the event. The child appeared to be handling the situation well, "
                "showcasing their driving skills despite the challenging conditions."
            ),
        ),
    ]

    def test_format(self):
        """The formatted sample must equal ``expected_dialogue`` turn by turn."""
        _assert_dialogue_equal(
            Llama2ChatFormat.format(CHAT_SAMPLE), self.expected_dialogue
        )
69+
70+
71+
class TestMistralChatFormat:
    """Check MistralChatFormat.format on a system-free sample and on a rejected one."""

    # The expected user turn is wrapped in [INST] ... [/INST]; the assistant
    # turn is left unchanged.
    expected_dialogue = [
        Message(
            role="user",
            content=(
                "[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
                "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
                "How about on an icy road? Well one father in Russia did just that, and recorded "
                "the entire thing. To her credit, the child seemed to be doing a great job. "
                "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
                "Summary: [/INST] "
            ),
        ),
        Message(
            role="assistant",
            content=(
                "A father in Russia allowed his 8-year-old child to drive his car on an "
                "icy road and recorded the event. The child appeared to be handling the situation well, "
                "showcasing their driving skills despite the challenging conditions."
            ),
        ),
    ]

    def test_format(self):
        """Formatting the sample without its leading system message matches."""
        # CHAT_SAMPLE[0] is the system message; drop it before formatting.
        _assert_dialogue_equal(
            MistralChatFormat.format(CHAT_SAMPLE[1:]), self.expected_dialogue
        )

    def test_format_with_system_prompt_raises(self):
        """Formatting the full sample, system message included, raises ValueError."""
        with pytest.raises(
            ValueError, match="System prompts are not supported in MistralChatFormat"
        ):
            MistralChatFormat.format(CHAT_SAMPLE)
100+
101+
102+
class TestChatMLFormat:
    """Compare ChatMLFormat.format(CHAT_SAMPLE) with a hand-written dialogue."""

    # Every expected turn opens with "<|im_start|>{role}\n" and closes with
    # "<|im_end|>"; all but the final turn also end with a newline.
    expected_dialogue = [
        Message(
            role="system",
            content=(
                "<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
                "Your goal is to complete the task as faithfully as you can. While performing "
                "the task think step-by-step and justify your steps.<|im_end|>\n"
            ),
        ),
        Message(
            role="user",
            content=(
                "<|im_start|>user\nPlease "
                "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
                "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
                "How about on an icy road? Well one father in Russia did just that, and recorded "
                "the entire thing. To her credit, the child seemed to be doing a great job. "
                "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
                "Summary:<|im_end|>\n"
            ),
        ),
        Message(
            role="assistant",
            content=(
                "<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an "
                "icy road and recorded the event. The child appeared to be handling the situation well, "
                "showcasing their driving skills despite the challenging conditions.<|im_end|>"
            ),
        ),
    ]

    def test_format(self):
        """The formatted sample must equal ``expected_dialogue`` turn by turn."""
        _assert_dialogue_equal(
            ChatMLFormat.format(CHAT_SAMPLE), self.expected_dialogue
        )

tests/torchtune/data/test_data_utils.py

Lines changed: 10 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,84 +4,21 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from tests.test_utils import DummyTokenizer
8-
from torchtune.data import tokenize_prompt_and_response, truncate
9-
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
10-
11-
12-
def test_tokenize_prompt_and_response():
13-
tokenizer = DummyTokenizer()
14-
prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
15-
response = "I always know what I'm doing, do you?"
16-
prompt_length = 12
17-
expected_tokenized_prompt = [
18-
0,
19-
12,
20-
4,
21-
2,
22-
2,
23-
12,
24-
6,
25-
4,
26-
2,
27-
2,
28-
6,
29-
9,
30-
1,
31-
6,
32-
4,
33-
4,
34-
3,
35-
6,
36-
2,
37-
4,
38-
-1,
39-
]
40-
expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
41-
1,
42-
6,
43-
4,
44-
4,
45-
3,
46-
6,
47-
2,
48-
4,
49-
-1,
50-
]
51-
52-
tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
53-
tokenizer, prompt, response
54-
)
55-
assert tokenized_prompt == expected_tokenized_prompt
56-
assert tokenized_label == expected_tokenized_label
57-
58-
tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
59-
tokenizer, prompt, response, train_on_input=True
60-
)
61-
assert tokenized_prompt == expected_tokenized_prompt
62-
assert tokenized_label == expected_tokenized_prompt
7+
from torchtune.data import truncate
638

649

6510
def test_truncate():
66-
prompt_tokens = [1, 2, 3, 4, -1]
67-
label_tokens = [1, 2, 3, 4, -1]
11+
tokens = [1, 2, 3, 4, -1]
6812

6913
# Test no truncation
70-
truncated_prompt_tokens, truncated_label_tokens = truncate(
71-
tokenizer=DummyTokenizer(),
72-
prompt_tokens=prompt_tokens,
73-
label_tokens=label_tokens,
14+
truncated_tokens = truncate(
15+
tokens=tokens,
7416
max_seq_len=5,
17+
eos_id=-1,
7518
)
76-
assert truncated_prompt_tokens == prompt_tokens
77-
assert truncated_label_tokens == label_tokens
19+
assert truncated_tokens == tokens
7820

79-
# Test truncated
80-
truncated_prompt_tokens, truncated_label_tokens = truncate(
81-
tokenizer=DummyTokenizer(),
82-
prompt_tokens=prompt_tokens,
83-
label_tokens=label_tokens,
84-
max_seq_len=4,
85-
)
86-
assert truncated_prompt_tokens == [1, 2, 3, -1]
87-
assert truncated_label_tokens == [1, 2, 3, -1]
21+
masks = [True, True, False, True, False]
22+
# Test truncated mask
23+
truncated_masks = truncate(tokens=masks, max_seq_len=4, eos_id=False)
24+
assert truncated_masks == [True, True, False, False]

0 commit comments

Comments
 (0)