|
2 | 2 | import os
|
3 | 3 | from abc import ABC, abstractmethod
|
4 | 4 | from dataclasses import dataclass, field, fields
|
5 |
| -from typing import List, NamedTuple, Optional, Tuple, Union,Dict |
| 5 | +from typing import List, NamedTuple, Optional, Tuple, Union, Dict |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | from pydantic import BaseModel
|
@@ -112,25 +112,47 @@ def __init__(self, logit_bias: Dict[str, float]) -> None:
|
112 | 112 | super().__init__()
|
113 | 113 | self.logit_bias = logit_bias
|
114 | 114 | self.tokens_to_adjust = {}
|
| 115 | + try: |
| 116 | + self.tokens_to_adjust = self.process_logit_bias(logit_bias) |
| 117 | + except ValueError as e: |
| 118 | + logger.error(e) |
| 119 | + raise |
| 120 | + |
def process_logit_bias(self, logit_bias: Dict[str, float]) -> Dict[int, float]:
    """Convert a ``logit_bias`` mapping into one keyed by integer token ID.

    Every key is passed through ``int()``. Keys that cannot be converted,
    or that convert to a negative number, are collected and reported
    together so the caller sees all offending entries in one error.

    Args:
        logit_bias: Mapping from token ID (any value ``int()`` accepts,
            e.g. a decimal string) to the bias to apply to that token.

    Returns:
        A new dict mapping each integer token ID to its bias value.

    Raises:
        ValueError: If at least one key is not a non-negative integer.
    """
    valid: Dict[int, float] = {}
    invalid = []

    for key, bias in logit_bias.items():
        try:
            token_id = int(key)
        except (ValueError, TypeError):
            invalid.append(key)
            continue

        # A negative ID would later be treated as an index counted from the
        # end of the vocabulary axis and bias an unrelated token.
        if token_id < 0:
            invalid.append(key)
        else:
            valid[token_id] = bias

    if invalid:
        raise ValueError(
            f"Invalid token_ids in logit_bias: {invalid}. "
            "All keys must be non-negative integers."
        )
    return valid
| 138 | + |
def __call__(self, req_id: int, logits: torch.Tensor,
             token_ids: List[List[int]], stream_ptr: Optional[int],
             client_id: Optional[int]) -> None:
    """Add the configured per-token biases to ``logits`` in place.

    Args:
        req_id: Not used by this processor.
        logits: Tensor modified in place. It is indexed as
            ``logits[:, :, ids]``, so the last axis is treated as the
            vocabulary axis.
        token_ids: Not used by this processor.
        stream_ptr: If not ``None``, wrapped in
            ``torch.cuda.ExternalStream`` and made the current stream
            while the bias is applied.
        client_id: Not used by this processor.

    Raises:
        ValueError: If a configured token ID is negative or is not
            smaller than the size of the last axis of ``logits``.
    """
    if not self.tokens_to_adjust:
        return

    vocab_size = logits.size(-1)
    token_ids_list = list(self.tokens_to_adjust.keys())

    # Validate before allocating anything. Negative IDs must be rejected
    # explicitly: tensor indexing would otherwise count them from the end
    # of the vocabulary axis and silently bias the wrong token.
    negative_token_ids = [tid for tid in token_ids_list if tid < 0]
    if negative_token_ids:
        raise ValueError(
            f"Token ID(s) {negative_token_ids} are negative; "
            f"token IDs must be in [0, {vocab_size})"
        )

    invalid_token_ids = [tid for tid in token_ids_list if tid >= vocab_size]
    if invalid_token_ids:
        raise ValueError(
            f"Token ID(s) {invalid_token_ids} exceed vocabulary size (vocab_size={vocab_size})"
        )

    # Build the bias on the same device and in the same dtype as the logits;
    # a default (CPU, float32) tensor cannot be added to logits that live on
    # another device.
    bias_values = torch.tensor(
        list(self.tokens_to_adjust.values()),
        device=logits.device,
        dtype=logits.dtype,
    )

    stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
    with torch.cuda.stream(stream):
        logits[:, :, token_ids_list] += bias_values
|
|
0 commit comments