
Commit 24e9ff0

Author: Kartikay Khandelwal
Commit message: lint
1 parent def4340

File tree: 1 file changed (+15, −11 lines)

torchtune/utils/_generation.py

Lines changed: 15 additions & 11 deletions
@@ -4,24 +4,20 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import functools
-from typing import Callable, List, Optional, Tuple
+from typing import Optional
 
 import torch
-import torch.nn.functional as F
-from torch import nn, Tensor
 
-from torchtune.modules import Tokenizer, TransformerDecoder
+from torchtune.modules import TransformerDecoder
 
 
 def multinomial_sample_one(probs):
     q = torch.empty_like(probs).exponential_(1)
     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
+
 def sample(
-    logits: torch.Tensor,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None
+    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
 ) -> torch.Tensor:
     # scale the logits based on temperature
     logits = logits / max(temperature, 1e-5)
@@ -42,6 +38,7 @@ def sample(
     token = multinomial_sample_one(probs)
     return token
 
+
 def generate_next_token(
     model: TransformerDecoder,
     input_pos: torch.Tensor,
@@ -82,10 +79,17 @@ def generate(
         max_generated_tokens (int): number of tokens to be generated. This is the max
             since we can stop early based on whether the eos token is respected or not
         temperature (float): value to scale the predicted logits by. Default is 1.0
-        topk (Optional[int]): If specified, we prune the sampling to only token ids within
+        top_k (Optional[int]): If specified, we prune the sampling to only token ids within
            the top_k probabilities. Default is None
        eos_id (Optional[int]): If specified, generation is stopped when the eos token is
-            generated
+            generated. Default is None
+
+    Returns:
+        List: list of generated tokens
+
+    Raises:
+        ValueError: if max_seq_len supported by the model is smaller than the number of tokens
+            requested
     """
 
     prompt_length = prompt.size(0)
@@ -121,7 +125,7 @@
             input_pos=input_pos,
             x=token.view(1, -1),
             temperature=temperature,
-            top_k=top_k
+            top_k=top_k,
         ).clone()
 
         generated_tokens.append(token)
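For context on the touched code: `sample` scales the logits by `temperature`, optionally prunes them to the `top_k` most likely token ids, and draws one token via `multinomial_sample_one`, which samples from the categorical distribution using the exponential-race (Gumbel-max style) trick of dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax. A minimal usage sketch against the post-commit signature; the tensor shape and parameter values below are illustrative assumptions, not part of this commit:

import torch

from torchtune.utils._generation import sample

# Illustrative logits for a single position over a hypothetical
# 8-token vocabulary (shape and values are assumptions).
logits = torch.randn(1, 8)

# temperature < 1.0 sharpens the distribution; top_k=3 keeps only the
# three highest-probability token ids before sampling.
token = sample(logits, temperature=0.7, top_k=3)

print(token)  # integer tensor holding the sampled token id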
