 from torch import nn
 from transformers import BatchFeature, PretrainedConfig
 from transformers.models.cohere2_vision import Cohere2VisionConfig
+from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import (  # noqa: E501
+    get_optimal_tiled_canvas)
 from transformers.models.cohere2_vision.processing_cohere2_vision import (
     Cohere2VisionProcessor)

@@ -150,14 +152,46 @@ def get_image_size_with_most_features(self) -> ImageSize:
         max_patches = image_processor.max_patches
         return ImageSize(height=height * max_patches, width=width)

-    def get_num_patches(self, image_width: int, image_height: int) -> int:
+    def get_num_patches(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Cohere2VisionProcessor],
+    ) -> int:
         """
         Calculate the number of image patches for a given image.
         Uses the HF processor to determine the actual number of patches.
         """
-        return self.get_hf_processor(
-        ).image_processor.get_number_of_image_patches(image_height,
-                                                      image_width, {})
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_processor = processor.image_processor
+
+        # The current implementation of get_number_of_image_patches
+        # is incorrect, so we patch it here.
+        # return image_processor.get_number_of_image_patches(image_height,
+        #                                                    image_width, {})
+
+        min_patches = image_processor.min_patches
+        max_patches = image_processor.max_patches
+        patch_size = image_processor.size
+        crop_to_patches = image_processor.crop_to_patches
+
+        if not crop_to_patches:
+            return 1
+
+        num_columns, num_rows = get_optimal_tiled_canvas(
+            (image_height, image_width),
+            (patch_size["height"], patch_size["width"]),
+            min_patches,
+            max_patches,
+        )
+        num_patches = num_columns * num_rows
+        if num_patches > 1:
+            num_patches += 1  # Thumbnail image
+
+        return num_patches


 class Cohere2VisionDummyInputsBuilder(
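
For readers skimming the diff: the sketch below mirrors the patched tile-counting logic as a standalone function. It reuses the same `get_optimal_tiled_canvas` call shape as the hunk above; the default tile size and patch limits are illustrative assumptions, not values read from the model config.

```python
# Standalone sketch of the patched patch-counting logic (illustrative defaults).
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import (  # noqa: E501
    get_optimal_tiled_canvas)


def count_patches(image_height: int,
                  image_width: int,
                  tile_height: int = 512,  # assumed tile size, not the real config
                  tile_width: int = 512,
                  min_patches: int = 1,    # assumed limits, not the real config
                  max_patches: int = 12) -> int:
    # Pick the tiling grid that best fits the image within the patch budget.
    num_columns, num_rows = get_optimal_tiled_canvas(
        (image_height, image_width),
        (tile_height, tile_width),
        min_patches,
        max_patches,
    )
    num_patches = num_columns * num_rows
    if num_patches > 1:
        num_patches += 1  # extra global thumbnail tile, as in the diff
    return num_patches
```

Whenever more than one tile is selected, the count includes one extra thumbnail image, matching the `num_patches += 1` branch in the patched method.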
@@ -208,6 +242,8 @@ def _call_hf_processor(
         # Ensure num_patches is available for proper tensor splitting
         if "num_patches" not in processed_outputs and (
                 images := mm_data.get("images")) is not None:
+            hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
             # Fallback calculation if HF processor didn't provide num_patches
             parsed_images = self._get_data_parser().parse_mm_data({
                 "image":
@@ -217,8 +253,9 @@ def _call_hf_processor(
             num_patches = [
                 self.info.get_num_patches(
                     image_width=parsed_images.get_image_size(i).width,
-                    image_height=parsed_images.get_image_size(i).height)
-                for i in range(len(parsed_images))
+                    image_height=parsed_images.get_image_size(i).height,
+                    processor=hf_processor,
+                ) for i in range(len(parsed_images))
             ]
             processed_outputs["num_patches"] = torch.tensor(num_patches)

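
A minimal sketch of the fallback path added in this hunk: when the processor output lacks `num_patches`, each image size is run through the patch counter and the per-image counts are stacked into a tensor. `image_sizes` and `count_patches` below are stand-ins for `parsed_images` and `self.info.get_num_patches`, not names from the diff.

```python
import torch


def fallback_num_patches(image_sizes, count_patches):
    # image_sizes: list of (height, width); count_patches: callable stand-in
    # for self.info.get_num_patches(image_width=..., image_height=..., ...).
    counts = [
        count_patches(image_width=width, image_height=height)
        for height, width in image_sizes
    ]
    return torch.tensor(counts)
```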
@@ -245,25 +282,25 @@ def _get_prompt_updates(
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_token = hf_processor.image_token
+        img_tokens_per_tile = int(hf_processor.patch_size**2)
         img_line_break_token = hf_processor.img_line_break_token
         boi_token = hf_processor.boi_token
         eoi_token = hf_processor.eoi_token

         def get_replacement(item_idx: int):
-            images: ImageProcessorItems = mm_items.get("image",
-                                                       ImageProcessorItems)
+            images = mm_items.get_items("image", ImageProcessorItems)
             image_size: ImageSize = images.get_image_size(item_idx)

-            num_patches = self.info.get_num_patches(image_size.height,
-                                                    image_size.width)
-            img_tokens_per_tile = int(hf_processor.patch_size**2)
-            single_tile_tokens = image_token * img_tokens_per_tile + \
-                img_line_break_token
-            img_string = f"{boi_token}\
-                {single_tile_tokens * num_patches}\
-                {eoi_token}"
+            num_patches = self.info.get_num_patches(
+                image_width=image_size.width,
+                image_height=image_size.height,
+                processor=hf_processor,
+            )
+            patch_tokens = (image_token * img_tokens_per_tile +
+                            img_line_break_token)
+            repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"

-            return PromptUpdateDetails.select_text(img_string, image_token)
+            return PromptUpdateDetails.select_text(repl, image_token)

         return [
             PromptReplacement(
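
Finally, a hedged sketch of the prompt string assembled by `get_replacement` above: per tile, the image placeholder token repeated `patch_size**2` times followed by a line-break token, with that block repeated `num_patches` times and wrapped in begin/end-of-image tokens. The literal token strings below are placeholders; the real values come from the processor attributes used in the diff.

```python
def build_image_prompt(num_patches: int,
                       img_tokens_per_tile: int,
                       boi_token: str = "<BOI>",            # placeholder strings;
                       image_token: str = "<IMG>",          # real values come from
                       img_line_break_token: str = "<BR>",  # the HF processor
                       eoi_token: str = "<EOI>") -> str:
    # One tile = the image token repeated per-tile count, then a line break token.
    patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
    # Repeat for every tile (including the thumbnail) and wrap in image markers.
    return f"{boi_token}{patch_tokens * num_patches}{eoi_token}"


# e.g. build_image_prompt(3, 4) -> "<BOI>" + three blocks of four "<IMG>" tokens
# each ending in "<BR>", followed by "<EOI>".
```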