
Commit 3c65c65

krammnic authored and mori360 committed
Fix eos_token problem in all required models (meta-pytorch#1806)
1 parent 4d0096a commit 3c65c65

File tree

8 files changed: +889 -7 lines changed


tests/torchtune/models/gemma/test_gemma_tokenizer.py

Lines changed: 229 additions & 0 deletions
@@ -242,3 +242,232 @@ def test_tokenize_messages(self, tokenizer):
        expected_mask = [True] * 75 + [False] * 125
        assert expected_tokens == tokens
        assert expected_mask == mask

    def test_tokenize_messages_drop_eos(self, tokenizer):
        messages = [
            Message(
                role="user",
                content="Below is an instruction that describes a task. Write a response "
                "that appropriately completes the request.\n\n### Instruction:\nGenerate "
                "a realistic dating profile bio.\n\n### Response:\n",
                masked=True,
            ),
            Message(
                role="assistant",
                content="I'm an outgoing and friendly person who loves spending time with "
                "friends and family. I'm also a big-time foodie and love trying out new "
                "restaurants and different cuisines. I'm a big fan of the arts and enjoy "
                "going to museums and galleries. I'm looking for someone who shares my "
                "interest in exploring new places, as well as someone who appreciates a "
                "good conversation over coffee.",
            ),
        ]
        tokens, mask = tokenizer.tokenize_messages(messages, add_eos=False)
        expected_tokens = [
            1, 323, 418, 202, 31, 128, 15, 120, 47, 88, 584, 23, 1665, 182, 9, 434, 295, 85, 4, 780,
            47, 636, 9, 1094, 213, 23, 9, 69, 69, 164, 1153, 299, 35, 961, 132, 237, 7, 5, 761, 4,
            12, 0, 313, 120, 47, 88, 584, 166, 493, 171, 54, 299, 9, 906, 244, 19, 186, 767, 303, 671,
            92, 209, 24, 190, 52, 38, 4, 12, 0, 1243, 7, 69, 135, 213, 166, 6, 21, 45, 128, 71,
            58, 38, 14, 10, 652, 35, 462, 101, 1306, 7, 341, 171, 20, 14, 127, 26, 652, 7, 10, 1268,
            4, 6, 21, 45, 591, 9, 566, 22, 994, 913, 38, 20, 52, 24, 10, 1306, 734, 14, 71, 365,
            1382, 7, 10, 801, 105, 88, 244, 985, 7, 4, 6, 21, 45, 9, 566, 126, 180, 11, 5, 1137,
            7, 10, 1089, 151, 8, 1156, 213, 342, 7, 10, 384, 104, 54, 470, 4, 6, 21, 45, 287, 14,
            33, 125, 135, 24, 101, 512, 66, 7, 28, 822, 15, 542, 69, 59, 110, 14, 365, 229, 7, 3,
            36, 267, 36, 125, 135, 24, 101, 1503, 182, 9, 222, 1661, 191, 332, 92, 92, 24, 24, 4, 2,
        ]
        # Drop the eos token.
        expected_tokens = expected_tokens[:-1]
        # One entry shorter than with eos.
        expected_mask = [True] * 75 + [False] * 124
        assert expected_tokens == tokens
        assert expected_mask == mask
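
The test above pins down the add_eos=False path of tokenize_messages: the output should match the with-eos output minus its final token, and the mask loses one entry accordingly. Below is a minimal sketch of that invariant, assuming the same tokenizer fixture and the Message class from torchtune.data used by the test; the helper name check_drop_eos and the short message contents are illustrative stand-ins, not part of this commit.

    from torchtune.data import Message


    def check_drop_eos(tokenizer):
        # Illustrative messages; the real test uses the longer conversation above.
        messages = [
            Message(role="user", content="Write a short bio.\n", masked=True),
            Message(role="assistant", content="I love food and museums."),
        ]
        # Tokenize the same conversation with and without the trailing eos token.
        with_eos, mask_with = tokenizer.tokenize_messages(messages, add_eos=True)
        without_eos, mask_without = tokenizer.tokenize_messages(messages, add_eos=False)
        # add_eos=False should only drop the final eos token and its mask entry,
        # which is exactly what the expected values in the test encode.
        assert without_eos == with_eos[:-1]
        assert mask_without == mask_with[:-1]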

0 commit comments
