
from torchtune.config._utils import _get_template

-from torchtune.data import PromptTemplate, tokenize_prompt_and_response
+from torchtune.data import PromptTemplate, tokenize_prompt_and_response, truncate
from torchtune.modules import Tokenizer


@@ -43,6 +43,9 @@ class InstructDataset(Dataset):
        column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template
            to the column/key names in the sample. If None, assume these are identical.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
46+ max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
47+ Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
48+ and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.
    """

@@ -54,6 +57,7 @@ def __init__(
        transform: Optional[Callable] = None,
        column_map: Optional[Dict[str, str]] = None,
        train_on_input: bool = False,
+        max_seq_len: Optional[int] = None,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        self._tokenizer = tokenizer
@@ -62,6 +66,7 @@ def __init__(
        self._transform = transform
        self._column_map = column_map
        self.train_on_input = train_on_input
+        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self._data)
@@ -80,20 +85,30 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[in
            else "output"
        )

-        return tokenize_prompt_and_response(
+        prompt_tokens, label_tokens = tokenize_prompt_and_response(
            tokenizer=self._tokenizer,
            prompt=prompt,
            response=transformed_sample[key_output],
            train_on_input=self.train_on_input,
        )

+        if self.max_seq_len is not None:
+            prompt_tokens, label_tokens = truncate(
+                self._tokenizer, prompt_tokens, label_tokens, self.max_seq_len
+            )
+
+        assert len(prompt_tokens) == len(label_tokens)
+
+        return prompt_tokens, label_tokens
+

def instruct_dataset(
    tokenizer: Tokenizer,
    source: str,
    template: str,
    column_map: Optional[Dict[str, str]] = None,
    train_on_input: bool = False,
+    max_seq_len: Optional[int] = None,
    **load_dataset_kwargs: Dict[str, Any],
) -> InstructDataset:
    """
@@ -110,6 +125,9 @@ def instruct_dataset(
        column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template
            to the column/key names in the sample. If None, assume these are identical.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
+            Default is None, which disables truncation. We recommend setting this to the largest value that fits
+            in memory and is supported by the model. For example, llama2-7B supports sequence lengths up to 4096.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.

    Returns:
@@ -121,5 +139,6 @@ def instruct_dataset(
        template=_get_template(template),
        column_map=column_map,
        train_on_input=train_on_input,
+        max_seq_len=max_seq_len,
        **load_dataset_kwargs,
    )
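To make the truncation step in `_prepare_sample` concrete, here is a minimal sketch of the kind of clipping a helper like `truncate` performs. This is an assumption for illustration only: the real helper receives the tokenizer itself, and the `eos_id` argument and EOS handling below are placeholders, not torchtune's actual implementation.

```python
from typing import List, Tuple

# Hedged sketch only: torchtune's `truncate` takes the tokenizer; a bare `eos_id`
# stands in for it here, and the EOS handling is an assumed convention.
def truncate_sketch(
    eos_id: int,
    prompt_tokens: List[int],
    label_tokens: List[int],
    max_seq_len: int,
) -> Tuple[List[int], List[int]]:
    # Clip both lists to max_seq_len so inputs and labels stay the same length,
    # matching the assert in _prepare_sample above.
    prompt_tokens = prompt_tokens[:max_seq_len]
    label_tokens = label_tokens[:max_seq_len]
    # If the sequence was cut off, keep an EOS as the final token (assumption).
    if prompt_tokens and prompt_tokens[-1] != eos_id:
        prompt_tokens[-1] = eos_id
        label_tokens[-1] = eos_id
    return prompt_tokens, label_tokens
```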
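And a hedged end-to-end usage sketch of the new `max_seq_len` argument. The module path, dataset id, template name, and `my_tokenizer` are placeholders assumed for illustration, not details taken from this change.

```python
# Sketch only: the import path is assumed, and the dataset id, template name,
# and `my_tokenizer` are placeholders.
from torchtune.datasets import instruct_dataset

ds = instruct_dataset(
    tokenizer=my_tokenizer,             # an already-constructed torchtune Tokenizer
    source="yahma/alpaca-cleaned",      # placeholder Hugging Face dataset id
    template="AlpacaInstructTemplate",  # placeholder name resolved by _get_template
    train_on_input=False,
    max_seq_len=4096,                   # new here: truncate token lists to the model's context length
)

# Each sample is a (input token ids, label token ids) pair of equal length.
tokens, labels = ds[0]
assert len(tokens) == len(labels)
```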