Skip to content

Commit 21cbe80

Browse files
committed
pad
1 parent 00d5e8a commit 21cbe80

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

machine-learning/immich_ml/models/ocr/detection.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import cv2
44
import numpy as np
55
from numpy.typing import NDArray
6-
from PIL import Image
6+
from PIL import Image, ImageOps
77
from rapidocr.ch_ppocr_det.utils import DBPostProcess
88
from rapidocr.inference_engine.base import FileInfo, InferSession
99
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
@@ -78,18 +78,15 @@ def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
7878

7979
# adapted from RapidOCR
8080
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
81-
if img.height < img.width:
82-
ratio = float(self.max_resolution) / img.height
81+
aspect_ratio = img.width / img.height
82+
if aspect_ratio > 1.25:
83+
target_dims = (self.max_resolution * 2, self.max_resolution)
84+
elif aspect_ratio < 0.75:
85+
target_dims = (self.max_resolution, self.max_resolution * 2)
8386
else:
84-
ratio = float(self.max_resolution) / img.width
85-
86-
resize_h = int(img.height * ratio)
87-
resize_w = int(img.width * ratio)
88-
89-
resize_h = int(round(resize_h / 32) * 32)
90-
resize_w = int(round(resize_w / 32) * 32)
91-
resized_img = img.resize((int(resize_w), int(resize_h)), resample=Image.Resampling.LANCZOS)
87+
target_dims = (self.max_resolution, self.max_resolution)
9288

89+
resized_img = ImageOps.pad(img, target_dims, color=(0, 0, 0), method=Image.Resampling.LANCZOS)
9390
img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR) # type: ignore
9491
img_np -= self.mean
9592
img_np *= self.std_inv
@@ -116,8 +113,8 @@ def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
116113
return sorted_boxes
117114

118115
def configure(self, **kwargs: Any) -> None:
119-
if (max_resolution := kwargs.get("maxResolution")) is not None:
120-
self.max_resolution = max_resolution
116+
if (max_resolution := kwargs.get("maxResolution")) is not None and max_resolution != self.max_resolution:
117+
self.max_resolution = int(round(max_resolution / 32) * 32)
121118
if (min_score := kwargs.get("minScore")) is not None:
122119
self.postprocess.box_thresh = min_score
123120
if (score_mode := kwargs.get("scoreMode")) is not None:

0 commit comments

Comments
 (0)