3344import copy
55import math
6- from typing import List , Optional , Tuple
6+ from enum import Enum
7+ from typing import List , Tuple
78
89import cv2
910import numpy as np
1011
11- from ..ch_ppocr_rec .utils import TextRecOutput
12+ from ..ch_ppocr_rec .typings import TextRecOutput , WordInfo , WordType
13+ from ..utils .utils import quads_to_rect_bbox
14+
15+
class Direction(Enum):
    """Orientation labels for a text box, as returned by ``get_box_direction``."""

    HORIZONTAL = "horizontal_direct"  # horizontal
    VERTICAL = "vertical_direct"  # vertical
1219
1320
1421class CalRecBoxes :
1522 """计算识别文字的汉字单字和英文单词的坐标框。
1623 代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection"""
1724
18- def __init__ (self ):
19- pass
20-
2125 def __call__ (
22- self ,
23- imgs : Optional [List [np .ndarray ]],
24- dt_boxes : Optional [List [np .ndarray ]],
25- rec_res : TextRecOutput ,
26+ self , imgs : List [np .ndarray ], dt_boxes : List [np .ndarray ], rec_res : TextRecOutput
2627 ) -> TextRecOutput :
2728 word_results = []
2829 for idx , (img , box ) in enumerate (zip (imgs , dt_boxes )):
29- direction = self .get_box_direction (box )
30-
31- rec_txt = rec_res .txts [idx ]
32- rec_word_info = rec_res .word_results [idx ]
30+ if rec_res .txts is None :
31+ continue
3332
3433 h , w = img .shape [:2 ]
3534 img_box = np .array ([[0 , 0 ], [w , 0 ], [w , h ], [0 , h ]])
3635 word_box_content_list , word_box_list , conf_list = self .cal_ocr_word_box (
37- rec_txt , img_box , rec_word_info
36+ rec_res . txts [ idx ] , img_box , rec_res . word_results [ idx ]
3837 )
3938 word_box_list = self .adjust_box_overlap (copy .deepcopy (word_box_list ))
39+ direction = self .get_box_direction (box )
4040 word_box_list = self .reverse_rotate_crop_image (
4141 copy .deepcopy (box ), word_box_list , direction
4242 )
@@ -48,97 +48,128 @@ def __call__(
4848 return rec_res
4949
@staticmethod
def get_box_direction(box: np.ndarray) -> Direction:
    """Classify a four-point box as horizontal or vertical.

    Args:
        box: Indexable of four points; edges are measured between points
            0-1 (top), 1-2 (right), 2-3 (bottom) and 3-0 (left).

    Returns:
        ``Direction.VERTICAL`` when the width is effectively zero or when
        height / width, rounded to two decimals, is at least 1.5;
        otherwise ``Direction.HORIZONTAL``.
    """

    def side(start: int, end: int) -> float:
        # Euclidean length of the edge joining two corners
        return float(np.linalg.norm(box[start] - box[end]))

    # Width and height are each the longer of the two opposite edges
    width = max(side(0, 1), side(2, 3))
    height = max(side(1, 2), side(3, 0))

    # A degenerate (zero-width) box cannot be divided by; treat as vertical
    if width < 1e-6:
        return Direction.VERTICAL

    if round(height / width, 2) >= 1.5:
        return Direction.VERTICAL
    return Direction.HORIZONTAL
6868
def cal_ocr_word_box(
    self, rec_txt: str, bbox: np.ndarray, word_info: WordInfo
) -> Tuple[List[str], List[List[List[float]]], List[float]]:
    """Split one recognised text line into per-word / per-character boxes.

    Word-level boxes are produced only when every entry of
    ``word_info.word_types`` is ``WordType.EN_NUM``; in every other case
    (all Chinese, or mixed Chinese/English) one box is produced per column.

    Args:
        rec_txt: Recognised text of the line; only its length is used, as
            the fallback divisor for the average character width.
        bbox: Quad of the line; passed to ``quads_to_rect_bbox`` with a
            leading axis added. The result is used as ``(x0, y0, x1, y1)``.
        word_info: Supplies ``words``, ``word_cols``, ``word_types``,
            ``line_txt_len`` and ``confs`` for the line.

    Returns:
        ``(word_contents, word_boxes, confs)``. All three are empty lists
        when the line has no text or no columns.
    """
    # An empty line has nothing to box, and would otherwise divide by zero
    # (line_txt_len below, len(rec_txt) in the fallback width).
    if not rec_txt or word_info.line_txt_len == 0:
        return [], [], []

    bbox_points = quads_to_rect_bbox(bbox[None, ...])
    avg_col_width = (bbox_points[2] - bbox_points[0]) / word_info.line_txt_len

    is_all_en_num = all(v is WordType.EN_NUM for v in word_info.word_types)

    line_cols, char_widths, word_contents = [], [], []
    for word, word_col in zip(word_info.words, word_info.word_cols):
        if is_all_en_num:
            # Keep each word's columns grouped so it yields a single box
            line_cols.append(word_col)
            word_contents.append("".join(word))
        else:
            # Flatten: every column becomes its own box
            line_cols.extend(word_col)
            word_contents.extend(word)

        # A single column has no span to measure a character width from
        if len(word_col) == 1:
            continue

        avg_width = self.calc_avg_char_width(word_col, avg_col_width)
        char_widths.append(avg_width)

    avg_char_width = self.calc_all_char_avg_width(
        char_widths, bbox_points[0], bbox_points[2], len(rec_txt)
    )

    if is_all_en_num:
        word_boxes = self.calc_en_num_box(
            line_cols, avg_char_width, avg_col_width, bbox_points
        )
    else:
        word_boxes = self.calc_box(
            line_cols, avg_char_width, avg_col_width, bbox_points
        )
    return word_contents, word_boxes, word_info.confs
113+
def calc_en_num_box(
    self,
    line_cols: List[List[int]],
    avg_char_width: float,
    avg_col_width: float,
    bbox_points: Tuple[float, float, float, float],
) -> List[List[List[float]]]:
    """Build one rectangular box per word from grouped column indices.

    Args:
        line_cols: One list of column indices per word.
        avg_char_width: Character width forwarded to ``calc_box``.
        avg_col_width: Column width forwarded to ``calc_box``.
        bbox_points: Line rectangle forwarded to ``calc_box``.

    Returns:
        One four-point box ``[[x0, y0], [x1, y0], [x1, y1], [x0, y1]]`` per
        entry of ``line_cols``, in the same order.
    """
    word_boxes = []
    for word_cols in line_cols:
        # Per-character cells of this word
        char_cells = self.calc_box(
            word_cols, avg_char_width, avg_col_width, bbox_points
        )
        # Collapse the cells into the rectangle reported by quads_to_rect_bbox
        left, top, right, bottom = quads_to_rect_bbox(np.array(char_cells))
        word_boxes.append(
            [[left, top], [right, top], [right, bottom], [left, bottom]]
        )
    return word_boxes
129+
130+ @staticmethod
131+ def calc_box (
132+ line_cols : List [int ],
133+ avg_char_width : float ,
134+ avg_col_width : float ,
135+ bbox_points : Tuple [float , float , float , float ],
136+ ) -> List [List [List [float ]]]:
137+ x0 , y0 , x1 , y1 = bbox_points
138+
139+ results = []
140+ for col_idx in line_cols :
141+ # 将中心点定位在列的中间位置
142+ center_x = (col_idx + 0.5 ) * avg_col_width
143+
144+ # 计算字符单元格的左右边界
145+ char_x0 = max (int (center_x - avg_char_width / 2 ), 0 ) + x0
146+ char_x1 = min (int (center_x + avg_char_width / 2 ), x1 - x0 ) + x0
147+ cell = [
148+ [char_x0 , y0 ],
149+ [char_x1 , y0 ],
150+ [char_x1 , y1 ],
151+ [char_x0 , y1 ],
152+ ]
153+ results .append (cell )
154+ return sorted (results , key = lambda x : x [0 ][0 ])
155+
156+ @staticmethod
157+ def calc_avg_char_width (word_col : List [int ], each_col_width : float ) -> float :
158+ char_total_length = (word_col [- 1 ] - word_col [0 ]) * each_col_width
159+ return char_total_length / (len (word_col ) - 1 )
160+
161+ @staticmethod
162+ def calc_all_char_avg_width (
163+ width_list : List [float ], bbox_x0 : float , bbox_x1 : float , txt_len : int
164+ ) -> float :
165+ if len (width_list ) > 0 :
166+ return sum (width_list ) / len (width_list )
167+ return (bbox_x1 - bbox_x0 ) / txt_len
137168
138169 @staticmethod
139170 def adjust_box_overlap (
140- word_box_list : List [List [List [int ]]],
141- ) -> List [List [List [int ]]]:
171+ word_box_list : List [List [List [float ]]],
172+ ) -> List [List [List [float ]]]:
142173 # 调整bbox有重叠的地方
143174 for i in range (len (word_box_list ) - 1 ):
144175 cur , nxt = word_box_list [i ], word_box_list [i + 1 ]
@@ -153,8 +184,8 @@ def adjust_box_overlap(
153184 def reverse_rotate_crop_image (
154185 self ,
155186 bbox_points : np .ndarray ,
156- word_points_list : List [List [List [int ]]],
157- direction : str = "w" ,
187+ word_points_list : List [List [List [float ]]],
188+ direction : Direction ,
158189 ) -> List [List [List [int ]]]:
159190 """
160191 get_rotate_crop_image的逆操作
@@ -163,8 +194,6 @@ def reverse_rotate_crop_image(
163194 bbox_points为part_img中对应在原图的bbox, 四个点,左上,右上,右下,左下
164195 part_points为在part_img中的点[(x, y), (x, y)]
165196 """
166- bbox_points = np .float32 (bbox_points )
167-
168197 left = int (np .min (bbox_points [:, 0 ]))
169198 top = int (np .min (bbox_points [:, 1 ]))
170199 bbox_points [:, 0 ] = bbox_points [:, 0 ] - left
@@ -189,13 +218,13 @@ def reverse_rotate_crop_image(
189218 new_word_points = []
190219 for point in word_points :
191220 new_point = point
192- if direction == "h" :
221+ if direction == Direction . VERTICAL :
193222 new_point = self .s_rotate (
194223 math .radians (- 90 ), new_point [0 ], new_point [1 ], 0 , 0
195224 )
196225 new_point [0 ] = new_point [0 ] + img_crop_width
197226
198- p = np .float32 (new_point + [1 ])
227+ p = np .array (new_point + [1 ])
199228 x , y , z = np .dot (IM , p )
200229 new_point = [x / z , y / z ]
201230
@@ -225,18 +254,18 @@ def s_rotate(angle, valuex, valuey, pointx, pointy):
225254 return [sRotatex , sRotatey ]
226255
227256 @staticmethod
228- def order_points (box : List [List [int ]]) -> List [List [int ]]:
257+ def order_points (ori_box : List [List [int ]]) -> List [List [int ]]:
229258 """矩形框顺序排列"""
230259
231260 def convert_to_1x2 (p ):
232261 if p .shape == (2 ,):
233262 return p .reshape ((1 , 2 ))
234- elif p .shape == (1 , 2 ):
263+
264+ if p .shape == (1 , 2 ):
235265 return p
236- else :
237- return p [:1 , :]
266+ return p [:1 , :]
238267
239- box = np .array (box ).reshape ((- 1 , 2 ))
268+ box = np .array (ori_box ).reshape ((- 1 , 2 ))
240269 center_x , center_y = np .mean (box [:, 0 ]), np .mean (box [:, 1 ])
241270 if np .any (box [:, 0 ] == center_x ) and np .any (
242271 box [:, 1 ] == center_y
0 commit comments