Skip to content

Commit 5776a55

Browse files
authored
feat(return_word_box): support word box of English and number text (#423)
* chore: update files * chore: update files * chore: update files * chore: update files * chore: update files * chore: update files * chore: format code * chore: format code * chore: format code
1 parent f3f6f93 commit 5776a55

File tree

9 files changed

+314
-231
lines changed

9 files changed

+314
-231
lines changed

python/demo.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: [email protected]
4+
from pathlib import Path
5+
46
from rapidocr import RapidOCR
57

6-
engine = RapidOCR(params={"Global.with_torch": True})
8+
engine = RapidOCR()
9+
10+
# img_url = "https://img1.baidu.com/it/u=3619974146,1266987475&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=516"
11+
12+
# img_paths = ["tmp/en.jpg", "tmp/ch.jpg", "tmp/ch_en.jpg"]
713

8-
img_url = "https://img1.baidu.com/it/u=3619974146,1266987475&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=516"
9-
result = engine(img_url)
10-
print(result)
14+
# for img_path in img_paths:
15+
# result = engine(img_path, return_word_box=True)
16+
# # print(result)
1117

12-
result.vis("vis_result.jpg")
18+
# result.vis(f"tmp/vis_{Path(img_path).stem}.jpg")
19+
img_path = "tmp/ch.jpg"
20+
result = engine(img_path, return_word_box=True)
21+
result.vis(f"tmp/vis_{Path(img_path).stem}.jpg")

python/rapidocr/cal_rec_boxes/main.py

Lines changed: 133 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,40 @@
33
# @Contact: [email protected]
44
import copy
55
import math
6-
from typing import List, Optional, Tuple
6+
from enum import Enum
7+
from typing import List, Tuple
78

89
import cv2
910
import numpy as np
1011

11-
from ..ch_ppocr_rec.utils import TextRecOutput
12+
from ..ch_ppocr_rec.typings import TextRecOutput, WordInfo, WordType
13+
from ..utils.utils import quads_to_rect_bbox
14+
15+
16+
class Direction(Enum):
17+
HORIZONTAL = "horizontal_direct" # 水平
18+
VERTICAL = "vertical_direct" # 垂直
1219

1320

1421
class CalRecBoxes:
1522
"""计算识别文字的汉字单字和英文单词的坐标框。
1623
代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection"""
1724

18-
def __init__(self):
19-
pass
20-
2125
def __call__(
22-
self,
23-
imgs: Optional[List[np.ndarray]],
24-
dt_boxes: Optional[List[np.ndarray]],
25-
rec_res: TextRecOutput,
26+
self, imgs: List[np.ndarray], dt_boxes: List[np.ndarray], rec_res: TextRecOutput
2627
) -> TextRecOutput:
2728
word_results = []
2829
for idx, (img, box) in enumerate(zip(imgs, dt_boxes)):
29-
direction = self.get_box_direction(box)
30-
31-
rec_txt = rec_res.txts[idx]
32-
rec_word_info = rec_res.word_results[idx]
30+
if rec_res.txts is None:
31+
continue
3332

3433
h, w = img.shape[:2]
3534
img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
3635
word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box(
37-
rec_txt, img_box, rec_word_info
36+
rec_res.txts[idx], img_box, rec_res.word_results[idx]
3837
)
3938
word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list))
39+
direction = self.get_box_direction(box)
4040
word_box_list = self.reverse_rotate_crop_image(
4141
copy.deepcopy(box), word_box_list, direction
4242
)
@@ -48,97 +48,128 @@ def __call__(
4848
return rec_res
4949

5050
@staticmethod
51-
def get_box_direction(box: np.ndarray) -> str:
52-
direction = "w"
53-
img_crop_width = int(
54-
max(
55-
np.linalg.norm(box[0] - box[1]),
56-
np.linalg.norm(box[2] - box[3]),
57-
)
58-
)
59-
img_crop_height = int(
60-
max(
61-
np.linalg.norm(box[0] - box[3]),
62-
np.linalg.norm(box[1] - box[2]),
63-
)
64-
)
65-
if img_crop_height * 1.0 / img_crop_width >= 1.5:
66-
direction = "h"
67-
return direction
51+
def get_box_direction(box: np.ndarray) -> Direction:
52+
edge_lengths = [
53+
float(np.linalg.norm(box[0] - box[1])), # 上边
54+
float(np.linalg.norm(box[1] - box[2])), # 右边
55+
float(np.linalg.norm(box[2] - box[3])), # 下边
56+
float(np.linalg.norm(box[3] - box[0])), # 左边
57+
]
58+
59+
# 宽和高取对边的最大距离
60+
width = max(edge_lengths[0], edge_lengths[2])
61+
height = max(edge_lengths[1], edge_lengths[3])
62+
63+
if width < 1e-6:
64+
return Direction.VERTICAL
65+
66+
aspect_ratio = round(height / width, 2)
67+
return Direction.VERTICAL if aspect_ratio >= 1.5 else Direction.HORIZONTAL
6868

69-
@staticmethod
7069
def cal_ocr_word_box(
71-
rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]]
72-
) -> Tuple[List[str], List[List[int]], List[float]]:
70+
self, rec_txt: str, bbox: np.ndarray, word_info: WordInfo
71+
) -> Tuple[List[str], List[List[List[float]]], List[float]]:
7372
"""Calculate the detection frame for each word based on the results of recognition and detection of ocr
7473
汉字坐标是单字的
7574
英语坐标是单词级别的
75+
三种情况:
76+
1. 全是汉字
77+
2. 全是英文
78+
3. 中英混合
7679
"""
80+
bbox_points = quads_to_rect_bbox(bbox[None, ...])
81+
avg_col_width = (bbox_points[2] - bbox_points[0]) / word_info.line_txt_len
7782

78-
col_num, word_list, word_col_list, state_list, conf_list = rec_word_info
79-
box = box.tolist()
80-
bbox_x_start = box[0][0]
81-
bbox_x_end = box[1][0]
82-
bbox_y_start = box[0][1]
83-
bbox_y_end = box[2][1]
84-
85-
cell_width = (bbox_x_end - bbox_x_start) / col_num
86-
word_box_list = []
87-
word_box_content_list = []
88-
cn_width_list = []
89-
en_width_list = []
90-
cn_col_list = []
91-
en_col_list = []
92-
93-
def cal_char_width(width_list, word_col_):
94-
if len(word_col_) == 1:
95-
return
96-
char_total_length = (word_col_[-1] - word_col_[0]) * cell_width
97-
char_width = char_total_length / (len(word_col_) - 1)
98-
width_list.append(char_width)
99-
100-
def cal_box(col_list, width_list, word_box_list_):
101-
if len(col_list) == 0:
102-
return
103-
if len(width_list) != 0:
104-
avg_char_width = np.mean(width_list)
105-
else:
106-
avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_txt)
107-
108-
for center_idx in col_list:
109-
center_x = (center_idx + 0.5) * cell_width
110-
cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
111-
cell_x_end = (
112-
min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
113-
+ bbox_x_start
114-
)
115-
cell = [
116-
[cell_x_start, bbox_y_start],
117-
[cell_x_end, bbox_y_start],
118-
[cell_x_end, bbox_y_end],
119-
[cell_x_start, bbox_y_end],
120-
]
121-
word_box_list_.append(cell)
122-
123-
for word, word_col, state in zip(word_list, word_col_list, state_list):
124-
if state == "cn":
125-
cal_char_width(cn_width_list, word_col)
126-
cn_col_list += word_col
127-
word_box_content_list += word
83+
is_all_en_num = all(v is WordType.EN_NUM for v in word_info.word_types)
84+
85+
line_cols, char_widths, word_contents = [], [], []
86+
for word, word_col in zip(word_info.words, word_info.word_cols):
87+
if is_all_en_num:
88+
line_cols.append(word_col)
89+
word_contents.append("".join(word))
12890
else:
129-
cal_char_width(en_width_list, word_col)
130-
en_col_list += word_col
131-
word_box_content_list += word
91+
line_cols.extend(word_col)
92+
word_contents.extend(word)
93+
94+
if len(word_col) == 1:
95+
continue
13296

133-
cal_box(cn_col_list, cn_width_list, word_box_list)
134-
cal_box(en_col_list, en_width_list, word_box_list)
135-
sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0])
136-
return word_box_content_list, sorted_word_box_list, conf_list
97+
avg_width = self.calc_avg_char_width(word_col, avg_col_width)
98+
char_widths.append(avg_width)
99+
100+
avg_char_width = self.calc_all_char_avg_width(
101+
char_widths, bbox_points[0], bbox_points[2], len(rec_txt)
102+
)
103+
104+
if is_all_en_num:
105+
word_boxes = self.calc_en_num_box(
106+
line_cols, avg_char_width, avg_col_width, bbox_points
107+
)
108+
else:
109+
word_boxes = self.calc_box(
110+
line_cols, avg_char_width, avg_col_width, bbox_points
111+
)
112+
return word_contents, word_boxes, word_info.confs
113+
114+
def calc_en_num_box(
115+
self,
116+
line_cols: List[List[int]],
117+
avg_char_width: float,
118+
avg_col_width: float,
119+
bbox_points: Tuple[float, float, float, float],
120+
) -> List[List[List[float]]]:
121+
results = []
122+
for one_col in line_cols:
123+
cur_word_cell = self.calc_box(
124+
one_col, avg_char_width, avg_col_width, bbox_points
125+
)
126+
x0, y0, x1, y1 = quads_to_rect_bbox(np.array(cur_word_cell))
127+
results.append([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
128+
return results
129+
130+
@staticmethod
131+
def calc_box(
132+
line_cols: List[int],
133+
avg_char_width: float,
134+
avg_col_width: float,
135+
bbox_points: Tuple[float, float, float, float],
136+
) -> List[List[List[float]]]:
137+
x0, y0, x1, y1 = bbox_points
138+
139+
results = []
140+
for col_idx in line_cols:
141+
# 将中心点定位在列的中间位置
142+
center_x = (col_idx + 0.5) * avg_col_width
143+
144+
# 计算字符单元格的左右边界
145+
char_x0 = max(int(center_x - avg_char_width / 2), 0) + x0
146+
char_x1 = min(int(center_x + avg_char_width / 2), x1 - x0) + x0
147+
cell = [
148+
[char_x0, y0],
149+
[char_x1, y0],
150+
[char_x1, y1],
151+
[char_x0, y1],
152+
]
153+
results.append(cell)
154+
return sorted(results, key=lambda x: x[0][0])
155+
156+
@staticmethod
157+
def calc_avg_char_width(word_col: List[int], each_col_width: float) -> float:
158+
char_total_length = (word_col[-1] - word_col[0]) * each_col_width
159+
return char_total_length / (len(word_col) - 1)
160+
161+
@staticmethod
162+
def calc_all_char_avg_width(
163+
width_list: List[float], bbox_x0: float, bbox_x1: float, txt_len: int
164+
) -> float:
165+
if len(width_list) > 0:
166+
return sum(width_list) / len(width_list)
167+
return (bbox_x1 - bbox_x0) / txt_len
137168

138169
@staticmethod
139170
def adjust_box_overlap(
140-
word_box_list: List[List[List[int]]],
141-
) -> List[List[List[int]]]:
171+
word_box_list: List[List[List[float]]],
172+
) -> List[List[List[float]]]:
142173
# 调整bbox有重叠的地方
143174
for i in range(len(word_box_list) - 1):
144175
cur, nxt = word_box_list[i], word_box_list[i + 1]
@@ -153,8 +184,8 @@ def adjust_box_overlap(
153184
def reverse_rotate_crop_image(
154185
self,
155186
bbox_points: np.ndarray,
156-
word_points_list: List[List[List[int]]],
157-
direction: str = "w",
187+
word_points_list: List[List[List[float]]],
188+
direction: Direction,
158189
) -> List[List[List[int]]]:
159190
"""
160191
get_rotate_crop_image的逆操作
@@ -163,8 +194,6 @@ def reverse_rotate_crop_image(
163194
bbox_points为part_img中对应在原图的bbox, 四个点,左上,右上,右下,左下
164195
part_points为在part_img中的点[(x, y), (x, y)]
165196
"""
166-
bbox_points = np.float32(bbox_points)
167-
168197
left = int(np.min(bbox_points[:, 0]))
169198
top = int(np.min(bbox_points[:, 1]))
170199
bbox_points[:, 0] = bbox_points[:, 0] - left
@@ -189,13 +218,13 @@ def reverse_rotate_crop_image(
189218
new_word_points = []
190219
for point in word_points:
191220
new_point = point
192-
if direction == "h":
221+
if direction == Direction.VERTICAL:
193222
new_point = self.s_rotate(
194223
math.radians(-90), new_point[0], new_point[1], 0, 0
195224
)
196225
new_point[0] = new_point[0] + img_crop_width
197226

198-
p = np.float32(new_point + [1])
227+
p = np.array(new_point + [1])
199228
x, y, z = np.dot(IM, p)
200229
new_point = [x / z, y / z]
201230

@@ -225,18 +254,18 @@ def s_rotate(angle, valuex, valuey, pointx, pointy):
225254
return [sRotatex, sRotatey]
226255

227256
@staticmethod
228-
def order_points(box: List[List[int]]) -> List[List[int]]:
257+
def order_points(ori_box: List[List[int]]) -> List[List[int]]:
229258
"""矩形框顺序排列"""
230259

231260
def convert_to_1x2(p):
232261
if p.shape == (2,):
233262
return p.reshape((1, 2))
234-
elif p.shape == (1, 2):
263+
264+
if p.shape == (1, 2):
235265
return p
236-
else:
237-
return p[:1, :]
266+
return p[:1, :]
238267

239-
box = np.array(box).reshape((-1, 2))
268+
box = np.array(ori_box).reshape((-1, 2))
240269
center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1])
241270
if np.any(box[:, 0] == center_x) and np.any(
242271
box[:, 1] == center_y

python/rapidocr/ch_ppocr_rec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# @Author: SWHL
33
# @Contact: [email protected]
44
from .main import TextRecognizer
5-
from .utils import TextRecInput, TextRecOutput
5+
from .typings import TextRecInput, TextRecOutput

python/rapidocr/ch_ppocr_rec/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from ..utils import Logger
2525
from ..utils.download_file import DownloadFile, DownloadFileInput
26-
from .utils import CTCLabelDecode, TextRecInput, TextRecOutput
26+
from .utils import CTCLabelDecode
27+
from .typings import TextRecInput, TextRecOutput
2728

2829
DEFAULT_DICT_PATH = Path(__file__).parent.parent / "models" / "ppocr_keys_v1.txt"
2930
DEFAULT_DICT_URL = "https://www.modelscope.cn/models/RapidAI/RapidOCR/resolve/v2.0.7/paddle/PP-OCRv4/rec/ch_PP-OCRv4_rec_infer/ppocr_keys_v1.txt"

0 commit comments

Comments
 (0)