44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- import functools
8- from typing import Callable , List , Optional , Tuple
7+ from typing import Optional
98
109import torch
11- import torch .nn .functional as F
12- from torch import nn , Tensor
1310
14- from torchtune .modules import Tokenizer , TransformerDecoder
11+ from torchtune .modules import TransformerDecoder
1512
1613
def multinomial_sample_one(probs):
    """Draw a single token id per row from the distribution *probs*.

    Uses the exponential-race trick: dividing each probability by an
    i.i.d. Exponential(1) noise sample and taking the argmax is
    distributionally equivalent to one multinomial draw from ``probs``.

    Args:
        probs: tensor of (non-negative) sampling weights; sampling is
            performed over the last dimension.

    Returns:
        Integer tensor of sampled indices with ``keepdim`` semantics
        (last dimension kept with size 1), dtype ``torch.int``.
    """
    # One Exponential(1) draw per candidate token, same shape/device as probs.
    noise = torch.empty_like(probs).exponential_(1)
    winner = torch.argmax(probs / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
2017
18+
2119def sample (
22- logits : torch .Tensor ,
23- temperature : float = 1.0 ,
24- top_k : Optional [int ] = None
20+ logits : torch .Tensor , temperature : float = 1.0 , top_k : Optional [int ] = None
2521) -> torch .Tensor :
2622 # scale the logits based on temperature
2723 logits = logits / max (temperature , 1e-5 )
@@ -42,6 +38,7 @@ def sample(
4238 token = multinomial_sample_one (probs )
4339 return token
4440
41+
4542def generate_next_token (
4643 model : TransformerDecoder ,
4744 input_pos : torch .Tensor ,
@@ -82,10 +79,17 @@ def generate(
8279 max_generated_tokens (int): number of tokens to be generated. This is the max
8380 since we can stop early based on whether the eos token is respected or not
8481 temperature (float): value to scale the predicted logits by. Default is 1.0
85- topk (Optional[int]): If specified, we prune the sampling to only token ids within
82+ top_k (Optional[int]): If specified, we prune the sampling to only token ids within
8683 the top_k probabilities. Default is None
8784 eos_id (Optional[int]): If specified, generation is stopped when the eos token is
88- generated
85+ generated. Default is None
86+
87+ Returns:
88+ List: list of generated tokens
89+
90+ Raises:
91+ ValueError: if max_seq_len supported by the model is smaller than the number of tokens
92+ requested
8993 """
9094
9195 prompt_length = prompt .size (0 )
@@ -121,7 +125,7 @@ def generate(
121125 input_pos = input_pos ,
122126 x = token .view (1 , - 1 ),
123127 temperature = temperature ,
124- top_k = top_k
128+ top_k = top_k ,
125129 ).clone ()
126130
127131 generated_tokens .append (token )
0 commit comments