# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
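
"""Offline document parsing with ByteDance/Dolphin served by vLLM.

The script makes two passes over a document image: a layout-analysis pass that
returns region bounding boxes in reading order, and a recognition pass that
reads the text (or parses the table) inside each detected region.
"""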

import argparse
import copy
import os
from dataclasses import dataclass

import cv2
import numpy as np
import regex as re
from PIL import Image
from transformers import DonutProcessor

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.multimodal.utils import fetch_image


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
@dataclass
class ImageDimensions:
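    """Sizes (in pixels) of the original image and of its square-padded version."""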
    original_w: int
    original_h: int
    padded_w: int
    padded_h: int


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def map_to_original_coordinates(
    x1, y1, x2, y2, dims: ImageDimensions
) -> tuple[int, int, int, int]:
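    """Map a box from padded-image coordinates back to the original image.

    The padding centers the original image, so the padding offsets are
    subtracted and the result is clamped to the original image bounds.
    """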
    try:
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2
        orig_x1 = max(0, x1 - left)
        orig_y1 = max(0, y1 - top)
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)
        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
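    """Expand each box edge outward by up to ``max_pixels`` pixels.

    An edge keeps moving while the Otsu-binarized pixels along it show many
    foreground/background transitions (score above ``threshold``), which helps
    avoid cutting through content.
    """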
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)

        def check_edge(img, current_box, i, is_vertical):
            edge = current_box[i]
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
            )
            if is_vertical:
                line = binary[current_box[1] : current_box[3] + 1, edge]
            else:
                line = binary[edge, current_box[0] : current_box[2] + 1]
            transitions = np.abs(np.diff(line))
            return np.sum(transitions) / len(transitions)

        edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
        current_box = copy.deepcopy(box)
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(image, current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(image, current_box, i, is_vertical)
                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)
                if score <= threshold:
                    break
        new_boxes.append(best_box)
    return new_boxes


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
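    """Convert normalized layout coordinates into pixel boxes.

    Scales ``coords`` to the padded image, refines the box with
    ``adjust_box_edges``, pushes it below ``previous_box`` when the two overlap,
    and also returns the box mapped back to original-image coordinates.
    """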
    try:
        x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
        x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = new_boxes[0]
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                y1 = prev_y2
                y1 = min(y1, dims.padded_h - 1)
                if y2 <= y1:
                    y2 = min(y1 + 1, dims.padded_h)
        new_previous_box = [x1, y1, x2, y2]
        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
            x1, y1, x2, y2, dims
        )
        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        orig_x1, orig_y1, orig_x2, orig_y2 = (
            0,
            0,
            min(100, dims.original_w),
            min(100, dims.original_h),
        )
        return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
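    """Pad a PIL image to a black-bordered square and return it as a BGR array.

    Also returns the ImageDimensions needed to map padded coordinates back to
    the original image.
    """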
    try:
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image_cv.shape[:2]
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left
        padded_image = cv2.copyMakeBorder(
            image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        padded_h, padded_w = padded_image.shape[:2]
        dimensions = ImageDimensions(
            original_w=original_w,
            original_h=original_h,
            padded_w=padded_w,
            padded_h=padded_h,
        )
        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        h, w = image.height, image.width
        dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def parse_layout_string(bbox_str):
    """Parse layout string using regular expressions"""
    pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
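    # Each match is one "[x1, y1, x2, y2] label" entry; the coordinates are
    # fractions of the padded square image (see process_coordinates).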
    matches = re.finditer(pattern, bbox_str)

    parsed_results = []
    for match in matches:
        coords = [float(match.group(i)) for i in range(1, 5)]
        label = match.group(5).strip()
        parsed_results.append((coords, label))

    return parsed_results


model_id = "ByteDance/Dolphin"

# The input image size for Dolphin is 896 x 896,
# and the patch_size is 4 x 4.
# Therefore, the initial number of patches is:
# Height: 896 / 4 = 224 patches
# Width: 896 / 4 = 224 patches

# The Dolphin model uses a staged downsampling approach,
# defined by the "depths": [2, 2, 14, 2] configuration.
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
# which halves the feature map's dimensions (dividing both height and width by 2).
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.

# Because vLLM needs to fill the image features with an encoder_prompt,
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
encoder_prompt = "".join(["0"] * 783)
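
# Sanity check of the arithmetic above (this assert is not part of the original
# example, just a guard on the assumption): 896 // 4 = 224 patches per side,
# halved once before each of stages 2-4, minus one token added by the tokenizer.
assert len(encoder_prompt) == (896 // 4 // 2 // 2 // 2) ** 2 - 1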
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=2048,
)

processor = DonutProcessor.from_pretrained(model_id)
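# Dolphin reuses the Donut encoder-decoder architecture; overriding the reported
# architecture lets vLLM serve it via its DonutForConditionalGeneration support.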
llm = LLM(
    model=model_id,
    dtype="float16",
    max_num_seqs=8,
    hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--image_path", type=str, default=None, help="Path to a local image file."
)
args = parser.parse_args()

if args.image_path:
    if not os.path.exists(args.image_path):
        raise FileNotFoundError(f"Error: File not found at {args.image_path}")
    image = Image.open(args.image_path).convert("RGB")
else:
    image = fetch_image(
        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
    )


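# Stage 1: layout analysis. The encoder prompt carries the image placeholder
# tokens; the decoder prompt carries the task instruction.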
prompt = "Parse the reading order of this document. "
decoder_prompt = f"<s>{prompt}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
    prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
        "input_ids"
    ]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
    decoder_prompt=decoder_prompt_tokens,
)
layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
layout_result_str = layout_outputs[0].outputs[0].text
print(f"Layout analysis output:\n{layout_result_str}")

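# Stage 2: crop every non-figure region from the padded image and queue an OCR
# (or table-parsing) request for it, preserving the reading order.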
padded_image, dims = prepare_image(image)
layout_results = parse_layout_string(layout_result_str)
text_table_elements = []
previous_box = None
reading_order = 0
for bbox_coords, label in layout_results:
    if label == "fig":
        continue
    try:
        x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
            process_coordinates(bbox_coords, padded_image, dims, previous_box)
        )
        cropped = padded_image[y1:y2, x1:x2]
        if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
            pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
            prompt_ocr = (
                "Parse the table in the image. "
                if label == "tab"
                else "Read text in the image. "
            )
            text_table_elements.append(
                {
                    "crop": pil_crop,
                    "prompt": prompt_ocr,
                    "reading_order": reading_order,
                }
            )
            reading_order += 1
    except Exception as e:
        print(f"Error processing bbox (label: {label}): {str(e)}")
        continue

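# Stage 3: run all element-level prompts through the model as a single batch.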
if text_table_elements:
    batch_prompts = []
    for elem in text_table_elements:
        decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
        decoder_prompt_tokens = TokensPrompt(
            prompt_token_ids=processor.tokenizer(
                decoder_prompt_str, add_special_tokens=False
            )["input_ids"]
        )
        enc_dec_prompt = ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
            ),
            decoder_prompt=decoder_prompt_tokens,
        )
        batch_prompts.append(enc_dec_prompt)
    batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params)
    for i, output in enumerate(batch_outputs):
        text_table_elements[i]["text"] = output.outputs[0].text.strip()

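# Print the recognized content of each element in reading order.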
print("------" * 8)
text_table_elements.sort(key=lambda x: x["reading_order"])
for elem in text_table_elements:
    print(elem.get("text", ""))