 # Not ideal to import this type here but it's needed for the transform function
 from torchtune.modules import Tokenizer
 
-_CROSS_ENTROPY_IGNORE_IDX = -100
+
+class Llama2ChatFormatConstants:
+    CROSS_ENTROPY_IGNORE_IDX = -100
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 
 class SlimOrcaDataset(Dataset):
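
For reference, here is a minimal sketch (not part of the diff) of what these tags render to when composed into a Llama2 Chat Format prompt; the system and user strings are made up:

```python
# Illustration only: compose a Llama2 Chat Format prompt from the new constants.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system, user = "You are a helpful assistant.", "What is 2 + 2?"
prompt = f"{B_INST} {B_SYS}{system}{E_SYS}{user} {E_INST}"
print(prompt)
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# What is 2 + 2? [/INST]
```
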
@@ -22,10 +26,11 @@ class SlimOrcaDataset(Dataset):
 
     The data is formatted to adhere to Llama2 Chat Format.
     This format is required if the base model is Llama2 Chat Model.
+    The base Llama2 Model doesn't prescribe a particular format.
 
     The returned data is a tuple of input token id list and label token id
     list. If `max_token_length` keyword argument is provided, the returned
-    input token id list is ensured (by truncation if necssary) to be within
+    input token id list is ensured (by truncation if necessary) to be within
     that length.
 
     Args:
@@ -35,7 +40,7 @@ class SlimOrcaDataset(Dataset):
             max sequence length accepted by the model.
 
     Keyword Arguments:
-        max_token_length (int): Maximum number of tokens in the returned.
+        max_token_length (int): Maximum number of tokens in the returned input and label token id lists.
             Default is 1024.
 
     Data input format:
@@ -51,15 +56,16 @@ class SlimOrcaDataset(Dataset):
         their funeral." } ]
 
     Example:
-        >>> slimorca_ds = SlimOrcaDataset(tokenizer=tokenizer)
-        >>> for batch in Dataloader(slimorca_ds, batch_size=8):
-            print(f"Batch size: {len(batch)}")
-        Batch size: 8
+        >>> ds = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=10)
+        >>> for input, label in ds:
+            print(input)
+            print(label)
+
+        Sample Output:
+        [1, 351, 82, 391, 221, 220, 193, 12, 471, ..., 2]
+        [-100, -100, -100, -100, -100, -100, -100, -100, 471, ..., 2]
     """
5968
60- B_INST , E_INST = "[INST]" , "[/INST]"
61- B_SYS , E_SYS = "<<SYS>>\n " , "\n <</SYS>>\n \n "
62-
6369 def __init__ (self , tokenizer : Tokenizer , ** kwargs ) -> None :
6470 self ._data = load_dataset ("Open-Orca/SlimOrca-Dedup" , split = "train" )
6571 self ._tokenizer = tokenizer
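
Since `__getitem__` returns variable-length Python lists, batching with a `DataLoader` needs a collate function that pads inputs and labels to a common length. A hypothetical sketch, not part of this diff (the function name and pad id of 0 are assumptions):

```python
import torch

def pad_collate(batch, pad_id=0, ignore_idx=-100):
    # Pad inputs with the tokenizer's pad id (assumed 0 here) and labels
    # with the cross entropy ignore index so padding adds nothing to the loss.
    max_len = max(len(input_ids) for input_ids, _ in batch)
    inputs = torch.tensor([ids + [pad_id] * (max_len - len(ids)) for ids, _ in batch])
    labels = torch.tensor([ids + [ignore_idx] * (max_len - len(ids)) for _, ids in batch])
    return inputs, labels

# loader = torch.utils.data.DataLoader(ds, batch_size=8, collate_fn=pad_collate)
```
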
@@ -72,18 +78,25 @@ def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
     def __len__(self):
         return len(self._data)
 
-    def prompt_with_system(self, content: str) -> str:
-        return f"{self.B_INST} {self.B_SYS}{content}{self.E_SYS} {self.E_INST}"
-
-    def prompt_without_system(self, content: str) -> str:
-        return f"{self.B_INST} {content} {self.E_INST}"
-
     def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
         data = self._data[index]["conversations"]
-        prompt, label = self.generate_prompt_label(data)
-        return self.generate_tokens(prompt, label)
+        prompt, label = self._generate_prompt_label(data)
+        return self._generate_tokens(prompt, label)
+
+    def _generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
+        """
+        Given a prompt string and label string, generate input and label token id lists.
+
+        The tokenizer is used to tokenize both strings.
+        The prompt token list is truncated to `max_token_length` - 2
+        (so that there is at least one label token, as EOS takes one token).
+
+        The label token list is truncated to `max_token_length` - len(prompt_token_list).
+
+        Finally, the input token list is the concatenation of the prompt and label token lists.
 
-    def generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
+        The label token list is padded with the cross entropy ignore idx value to match the length of the input token list.
+        """
         prompt_tokens = self._tokenizer.encode(prompt, add_bos=True, add_eos=False)
         # Truncate to max token length - 2 (so that there is at least one label token)
         prompt_tokens = prompt_tokens[: self._max_token_length - 2]
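
To make the truncation arithmetic in the docstring concrete, here is a standalone sketch with made-up token ids and `max_token_length=10`, matching the docstring's sample output (the real method uses the tokenizer; this only traces the lengths):

```python
IGNORE_IDX = -100
max_token_length = 10

prompt_tokens = [1, 351, 82, 391, 221, 220, 193, 12, 471, 405]  # 10 ids incl. BOS
label_tokens = [471, 2]                                         # response + EOS

# Keep at most max_token_length - 2 prompt tokens -> 8 remain here.
prompt_tokens = prompt_tokens[: max_token_length - 2]
# The label gets whatever room is left -> 10 - 8 = 2 tokens remain.
label_tokens = label_tokens[: max_token_length - len(prompt_tokens)]

input_ids = prompt_tokens + label_tokens
labels = [IGNORE_IDX] * len(prompt_tokens) + label_tokens
assert len(input_ids) == len(labels) == max_token_length
```
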
@@ -99,12 +112,16 @@ def generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
 
         input = prompt_tokens + label_tokens
         label = [
-            _CROSS_ENTROPY_IGNORE_IDX for _ in range(len(prompt_tokens))
+            Llama2ChatFormatConstants.CROSS_ENTROPY_IGNORE_IDX
+            for _ in range(len(prompt_tokens))
         ] + label_tokens
-        assert len(input) == len(label)
         return input, label
 
-    def generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
+    def _generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
+        """
+        Construct prompt and label strings adhering to Llama2 Chat Format.
+        This method supports only a single back-and-forth conversation per sample (as that is sufficient for the SlimOrca dataset).
+        """
         agent_text_dict = {}
         # agents can be {system, human, gpt}
         for conversation in data:
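
Why -100: it is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so the prompt positions filled with `CROSS_ENTROPY_IGNORE_IDX` in the hunk above are simply skipped when computing the loss. A minimal check:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 32)                 # 4 positions, vocab size 32
targets = torch.tensor([-100, -100, 7, 2])  # first two positions masked

loss_fn = nn.CrossEntropyLoss()             # ignore_index defaults to -100
loss = loss_fn(logits, targets)             # averaged over the 2 unmasked positions
```
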
@@ -113,10 +130,10 @@ def generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
             agent_text_dict[agent] = text
 
         # Llama2 Chat Format - https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L284
-        if len(agent_text_dict["system"]) > 0:
-            prompt = f"{self.B_INST} {self.B_SYS}{agent_text_dict['system']}{self.E_SYS}{agent_text_dict['human']} {self.E_INST}"
+        if "system" in agent_text_dict:
+            prompt = f"{Llama2ChatFormatConstants.B_INST} {Llama2ChatFormatConstants.B_SYS}{agent_text_dict['system']}{Llama2ChatFormatConstants.E_SYS}{agent_text_dict['human']} {Llama2ChatFormatConstants.E_INST}"  # noqa: B950
         else:
-            prompt = f"{self.B_INST} {agent_text_dict['human']} {self.E_INST}"
+            prompt = f"{Llama2ChatFormatConstants.B_INST} {agent_text_dict['human']} {Llama2ChatFormatConstants.E_INST}"
 
         response = f" {agent_text_dict['gpt']} "
         return prompt, response
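
Putting the pieces together, a standalone trace (not from the diff) of what `_generate_prompt_label` produces for a made-up SlimOrca-style sample, including the branch taken when no system message is present:

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

data = [
    {"from": "system", "value": "You are an assistant."},
    {"from": "human", "value": "How are rainbows formed?"},
    {"from": "gpt", "value": "Sunlight refracts through water droplets."},
]
agent_text_dict = {c["from"]: c["value"] for c in data}

if "system" in agent_text_dict:
    prompt = f"{B_INST} {B_SYS}{agent_text_dict['system']}{E_SYS}{agent_text_dict['human']} {E_INST}"
else:  # samples without a system message use the bare instruction wrapper
    prompt = f"{B_INST} {agent_text_dict['human']} {E_INST}"
response = f" {agent_text_dict['gpt']} "
# prompt   -> "[INST] <<SYS>>\nYou are an assistant.\n<</SYS>>\n\nHow are rainbows formed? [/INST]"
# response -> " Sunlight refracts through water droplets. "
```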