Commit f9ec55d

bug fix (the train_on_input masking condition was inverted) and some cleanup (clearer label masking)

1 parent c29e8cb

4 files changed: +11 -5 lines

tests/torchtune/datasets/test_slimorca_dataset.py

Lines changed: 5 additions & 1 deletion
@@ -48,7 +48,11 @@ def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
                 ]
             }
         ]
-        ds = slimorca_dataset(tokenizer=tokenizer, max_seq_len=max_seq_len)
+        ds = slimorca_dataset(
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len,
+            train_on_input=(max_seq_len == 128),
+        )
         input, label = ds[0]
         assert len(input) <= max_seq_len
         assert len(label) <= max_seq_len
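
Rather than adding a second parametrize axis, the test derives `train_on_input` from the values `max_seq_len` already takes, so both masking branches get exercised across the existing runs. A standalone sketch of that pattern (the axis values here are assumed for illustration, not taken from the fixture):

import pytest

# Assumed parametrize values, mirroring the test's existing max_seq_len axis.
AXIS = [128, 4096]

def test_flag_covers_both_branches():
    # Deriving train_on_input from max_seq_len (instead of adding a second
    # parametrize dimension) still covers both True and False across runs.
    flags = {max_seq_len == 128 for max_seq_len in AXIS}
    assert flags == {True, False}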

torchtune/data/_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -49,6 +49,6 @@ def sharegpt_to_llama2_messages(
     for message in conversations:
         role = role_map[message["from"]]
         content = message["value"]
-        masked = (role != "assistant") and train_on_input
+        masked = (role != "assistant") and (not train_on_input)
         messages.append(Message(role=role, content=content, masked=masked))
     return messages
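
This is the actual bug fix: before, non-assistant turns were masked only when `train_on_input` was True, the opposite of the intent. Prompt tokens should be excluded from the loss unless the caller explicitly asks to train on them. A minimal standalone sketch of the corrected predicate (`Message` and the role map from the diff are elided):

def is_masked(role: str, train_on_input: bool) -> bool:
    # Mask every non-assistant turn unless we explicitly train on the input.
    return (role != "assistant") and (not train_on_input)

# Assistant turns always contribute to the loss.
assert not is_masked("assistant", train_on_input=False)
# Prompt/user turns are masked by default...
assert is_masked("user", train_on_input=False)
# ...and unmasked when training on the input.
assert not is_masked("user", train_on_input=True)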

torchtune/datasets/_chat.py

Lines changed: 2 additions & 2 deletions
@@ -81,8 +81,8 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
         tokens, mask = self._tokenizer.tokenize_messages(
             messages, max_seq_len=self.max_seq_len
         )
-        labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
-
+        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
+        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
         assert len(tokens) == len(labels)

         return tokens, labels
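
Note that the deleted expression was already behaviorally equivalent: `np.where(np.logical_not(mask), tokens, IGNORE)` and `np.where(mask, IGNORE, tokens)` select the same elements, so this is the "cleanup" half of the commit, not a behavior change. A minimal sketch of the masking, assuming `CROSS_ENTROPY_IGNORE_IDX` is -100 (PyTorch's default `ignore_index`) and toy token/mask values:

import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # stand-in for torchtune's constant

# Toy sample: True marks prompt tokens to exclude from the loss.
tokens = [5, 6, 7, 8, 9]
mask = [True, True, False, False, False]

# Wherever mask == True, take the ignore index; otherwise keep the token.
labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
assert labels == [-100, -100, 7, 8, 9]

# The pre-cleanup form produces identical labels.
old = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
assert old == labels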

torchtune/datasets/_instruct.py

Lines changed: 3 additions & 1 deletion
@@ -93,7 +93,9 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
         tokens, mask = self._tokenizer.tokenize_messages(
             messages, max_seq_len=self.max_seq_len
         )
-        labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
+
+        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
+        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
         assert len(tokens) == len(labels)

         return tokens, labels
