Skip to content

Commit 97f536b

Browse files
committed
Address PR comments
1 parent d4851e0 commit 97f536b

File tree

4 files changed

+11
-8
lines changed

4 files changed

+11
-8
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22
style = 'google'
33
check-return-types = 'False'
44
exclude = 'tests/torchtune/models/llama2/scripts/'
5+
6+
[tool.pytest.ini_options]
7+
addopts = ["--showlocals"] # show local variables in tracebacks

tests/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import unittest
99
import uuid
10+
from pathlib import Path
1011
from typing import Any, Union
1112

1213
import torch
@@ -19,6 +20,10 @@
1920
)
2021

2122

23+
def get_assets_path():
24+
return Path(__file__).parent / "assets"
25+
26+
2227
def init_weights_with_constant(model: nn.Module, constant: float = 1.0) -> None:
2328
for p in model.parameters():
2429
nn.init.constant_(p, constant)

tests/torchtune/datasets/test_slimorca_dataset.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,21 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import random
7-
from pathlib import Path
87

98
import pytest
109

1110
from torchtune import datasets
1211
from torchtune.modules.tokenizer import Tokenizer
1312

14-
ASSETS = Path(__file__).parent.parent.parent / "assets"
13+
from tests.test_utils import get_assets_path
1514

1615

1716
class TestSlimOrcaDataset:
1817
@pytest.fixture
1918
def tokenizer(self):
2019
# m.model is a SentencePiece model trained using the following command:
2120
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
22-
return Tokenizer.from_file(str(ASSETS / "m.model"))
23-
24-
def test_slim_orca_dataset(self, tokenizer):
25-
dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
26-
assert len(dataset) == 363_491
21+
return Tokenizer.from_file(str(get_assets_path() / "m.model"))
2722

2823
def test_prompt_label_generation(self, tokenizer):
2924
dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)

torchtune/datasets/slimorca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class SlimOrcaDataset(Dataset):
4747
**kwargs: Additional keyword arguments to pass to the SlimOrca Dataset.
4848
4949
Keyword Arguments:
50-
max_token_length (int): Maximum number of tokens in the returned input and label token id lists. This value needs to be at least 4 though it is generally set it to max sequence length accepted by the model. Default is 1024.
50+
max_token_length (int): Maximum number of tokens in the returned input and label token id lists. This value needs to be at least 4, though it is generally set to the maximum sequence length accepted by the model. Default is 1024.
5151
5252
Raises:
5353
ValueError: If `max_token_length` is less than 4.

0 commit comments

Comments
 (0)