Description
Please ask your question

```python
import sys
import os
import logging
import threading
import time

# Set the environment variable before importing paddleocr to force offline mode
os.environ['PADDLE_PDX_MODEL_DOWNLOAD_SOURCE'] = 'LOCAL'

from PySide6.QtWidgets import (QApplication, QMainWindow, QVBoxLayout, QHBoxLayout,
                               QPushButton, QWidget, QLabel, QTextEdit, QFileDialog,
                               QDialog, QGroupBox, QCheckBox, QRadioButton, QSpinBox,
                               QDoubleSpinBox, QFormLayout, QComboBox)
from PySide6.QtGui import QPixmap, QPainter, QPen, QGuiApplication, QScreen, QColor, QFont, QImage
from PySide6.QtCore import Qt, QRect, QTimer, Signal, Slot

# Import pynput for the global hotkey
try:
    from pynput import keyboard
except ImportError:
    print("错误: pynput 库未安装。请运行 'pip install pynput' 来安装。")
    sys.exit(1)

from paddleocr import PaddleOCR
from PIL import Image, ImageQt
import numpy as np

# Suppress unnecessary logging
logging.disable(logging.WARNING)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


# --- Brand-new professional screenshot window ---
class ScreenshotWindow(QWidget):
    screenshot_taken = Signal(QPixmap)

    def __init__(self):
        super().__init__()
        self.setWindowFlags(Qt.WindowType.FramelessWindowHint | Qt.WindowType.WindowStaysOnTopHint | Qt.Tool)
        self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
        self.setCursor(Qt.CursorShape.CrossCursor)
        self.screen = QApplication.primaryScreen()
        self.setGeometry(self.screen.geometry())
        self.start_point = None
        self.end_point = None
        self.selection_rect = QRect()

    def paintEvent(self, event):
        painter = QPainter(self)
        painter.fillRect(self.rect(), QColor(0, 0, 0, 100))  # semi-transparent black mask
        if not self.selection_rect.isNull():
            painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_Clear)
            painter.fillRect(self.selection_rect, Qt.GlobalColor.transparent)
            painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_SourceOver)
            pen = QPen(QColor(0, 150, 255), 2)
            painter.setPen(pen)
            painter.drawRect(self.selection_rect)

    def mousePressEvent(self, event):
        self.start_point = event.position().toPoint()
        self.selection_rect = QRect(self.start_point, self.start_point)
        self.update()

    def mouseMoveEvent(self, event):
        if self.start_point:
            self.end_point = event.position().toPoint()
            self.selection_rect = QRect(self.start_point, self.end_point).normalized()
            self.update()

    def mouseReleaseEvent(self, event):
        if self.start_point and self.end_point and self.selection_rect.width() > 5 and self.selection_rect.height() > 5:
            self.hide()
            screenshot = self.screen.grabWindow(
                0,
                self.selection_rect.x(),
                self.selection_rect.y(),
                self.selection_rect.width(),
                self.selection_rect.height()
            )
            self.screenshot_taken.emit(screenshot)
        else:
            self.screenshot_taken.emit(QPixmap())
        self.close()

    def keyPressEvent(self, event):
        if event.key() == Qt.Key_Escape:
            self.screenshot_taken.emit(QPixmap())
            self.close()


# --- Main application window ---
class OCRApp(QMainWindow):
    trigger_screenshot_signal = Signal()

    def __init__(self):
        super().__init__()
        self.setWindowTitle("高精度OCR工具 (最终版)")
        self.setGeometry(100, 100, 1200, 800)
        self.ocr_engine = None
        self.current_image_input = None
        self.screenshot_window = None
        self.init_ui()
        self.setWindowTitle("高精度OCR工具 (正在后台加载模型...)")
        QTimer.singleShot(100, self.initialize_ocr_engine)
        self.trigger_screenshot_signal.connect(self.take_screenshot)

    def initialize_ocr_engine(self):
        self.btn_run_ocr.setText("正在加载模型...")
        self.btn_run_ocr.setEnabled(False)
        QApplication.processEvents()
        try:
            lang_code = self.lang_combo.currentData()
            print(f"正在加载语言模型: {lang_code}...")
            self.ocr_engine = PaddleOCR(
                use_textline_orientation=True,
                lang=lang_code,
                device='gpu:0',
                ocr_version='PP-OCRv5',
                text_detection_model_name='PP-OCRv5_server_det',
                text_recognition_model_name='PP-OCRv5_server_rec'
            )
            print(f"语言模型 {lang_code} 初始化成功!")
            self.setWindowTitle("高精度OCR工具 (模型加载完成)")
            self.btn_run_ocr.setText("开始识别")
            self.btn_run_ocr.setEnabled(True)
        except Exception as e:
            self.result_text.setText(f"引擎初始化失败:\n{e}\n\n请确保已正确安装paddlepaddle-gpu版本。")
            self.setWindowTitle("高精度OCR工具 (引擎初始化失败!)")
            logging.exception("Engine Init Error")

    def init_ui(self):
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QHBoxLayout(central_widget)
        left_panel = QWidget()
        left_layout = QVBoxLayout(left_panel)
        left_panel.setFixedWidth(320)
        self.btn_load = QPushButton("1. 加载本地图片")
        self.btn_fullscreen = QPushButton("2. 截取全屏")
        self.btn_select_area = QPushButton("3. 框选识别区域 (Ctrl+Alt+X)")
        self.btn_run_ocr = QPushButton("开始识别 (正在加载模型...)")
        self.btn_run_ocr.setEnabled(False)
        config_groupbox = QGroupBox("高级配置")
        form_layout = QFormLayout()
        self.lang_combo = QComboBox()
        self.lang_combo.addItem("自动选择 (中/英/繁/日)", "ch")
        self.lang_combo.addItem("繁体中文 (chinese_cht)", "chinese_cht")
        self.lang_combo.addItem("英文 (en)", "en")
        self.lang_combo.addItem("日文 (japan)", "japan")
        self.lang_combo.addItem("韩文 (korean)", "korean")
        self.lang_combo.currentIndexChanged.connect(self.initialize_ocr_engine)
        form_layout.addRow("选择语言:", self.lang_combo)
        self.check_doc_orientation = QCheckBox()
        self.check_doc_orientation.setChecked(False)
        form_layout.addRow("启用文档方向分类:", self.check_doc_orientation)
        self.check_doc_unwarping = QCheckBox()
        self.check_doc_unwarping.setChecked(False)
        form_layout.addRow("启用文档扭曲矫正:", self.check_doc_unwarping)
        self.check_textline_orientation = QCheckBox()
        self.check_textline_orientation.setChecked(True)
        form_layout.addRow("启用文本行方向分类:", self.check_textline_orientation)
        self.radio_limit_type_min = QRadioButton("短边")
        self.radio_limit_type_max = QRadioButton("长边")
        self.radio_limit_type_min.setChecked(True)
        limit_type_layout = QHBoxLayout()
        limit_type_layout.addWidget(self.radio_limit_type_min)
        limit_type_layout.addWidget(self.radio_limit_type_max)
        form_layout.addRow("图像边长限制类型:", limit_type_layout)
        self.spin_limit_side_len = QSpinBox()
        self.spin_limit_side_len.setRange(32, 4096)
        self.spin_limit_side_len.setValue(960)
        form_layout.addRow("图像边长限制:", self.spin_limit_side_len)
        self.spin_det_thresh = QDoubleSpinBox()
        self.spin_det_thresh.setRange(0.0, 1.0)
        self.spin_det_thresh.setSingleStep(0.05)
        self.spin_det_thresh.setValue(0.3)
        form_layout.addRow("文本检测像素阈值:", self.spin_det_thresh)
        self.spin_box_thresh = QDoubleSpinBox()
        self.spin_box_thresh.setRange(0.0, 1.0)
        self.spin_box_thresh.setSingleStep(0.05)
        self.spin_box_thresh.setValue(0.6)
        form_layout.addRow("文本检测框阈值:", self.spin_box_thresh)
        self.spin_unclip_ratio = QDoubleSpinBox()
        self.spin_unclip_ratio.setRange(1.0, 10.0)
        self.spin_unclip_ratio.setSingleStep(0.1)
        self.spin_unclip_ratio.setValue(2.0)
        form_layout.addRow("文本检测扩张系数:", self.spin_unclip_ratio)
        self.spin_rec_score_thresh = QDoubleSpinBox()
        self.spin_rec_score_thresh.setRange(0.0, 1.0)
        self.spin_rec_score_thresh.setSingleStep(0.05)
        self.spin_rec_score_thresh.setValue(0.0)
        form_layout.addRow("文本识别阈值:", self.spin_rec_score_thresh)
        config_groupbox.setLayout(form_layout)
        self.result_text = QTextEdit()
        self.result_text.setPlaceholderText("识别结果将显示在这里...")
        font = QFont("SimSun", 10)
        self.result_text.setFont(font)
        left_layout.addWidget(self.btn_load)
        left_layout.addWidget(self.btn_fullscreen)
        left_layout.addWidget(self.btn_select_area)
        left_layout.addWidget(config_groupbox)
        left_layout.addWidget(self.btn_run_ocr)
        left_layout.addWidget(QLabel("识别结果 (保留排版):"))
        left_layout.addWidget(self.result_text)
        self.image_label = QLabel("请加载图片或截图")
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("border: 2px dashed #aaa;")
        self.image_label.setScaledContents(True)
        main_layout.addWidget(left_panel)
        main_layout.addWidget(self.image_label, 1)
        self.btn_load.clicked.connect(self.load_image)
        self.btn_fullscreen.clicked.connect(self.capture_fullscreen)
        self.btn_select_area.clicked.connect(self.take_screenshot)
        self.btn_run_ocr.clicked.connect(self.run_ocr)

    @Slot()
    def take_screenshot(self):
        self.hide()
        time.sleep(0.2)
        self.screenshot_window = ScreenshotWindow()
        self.screenshot_window.screenshot_taken.connect(self.process_screenshot)
        self.screenshot_window.show()

    @Slot(QPixmap)
    def process_screenshot(self, pixmap: QPixmap):
        self.show()
        if not pixmap.isNull():
            self.current_image_input = pixmap.toImage()
            self.display_image(pixmap)
            self.result_text.setText("截图成功!点击“开始识别”进行处理。")
            self.run_ocr()
        else:
            self.result_text.setText("截图操作已取消。")

    def get_image_as_numpy(self, image_input):
        if image_input is None:
            return None
        image_to_process = None
        if isinstance(image_input, str):
            image_to_process = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            image_to_process = image_input
        elif isinstance(image_input, QImage):
            image_to_process = ImageQt.fromqimage(image_input)
        elif isinstance(image_input, QPixmap):
            image_to_process = ImageQt.fromqpixmap(image_input)
        if image_to_process:
            return np.array(image_to_process.convert('RGB'))
        return None

    def reconstruct_layout(self, data_dict):
        """
        Smart layout reconstruction:
        decide between horizontal and vertical text from the box coordinates
        and rebuild the original layout.
        """
        if not (data_dict and 'res' in data_dict and data_dict['res'].get('rec_polys')):
            return ""
        boxes = data_dict['res']['rec_polys']
        texts = data_dict['res']['rec_texts']
        if not texts:
            return ""
        items = []
        for box, text in zip(boxes, texts):
            x_coords = [p[0] for p in box]
            y_coords = [p[1] for p in box]
            width = max(x_coords) - min(x_coords)
            height = max(y_coords) - min(y_coords)
            items.append({'text': text, 'x': min(x_coords), 'y': min(y_coords), 'w': width, 'h': height})
        if not items:
            return ""
        # Use the average width/height ratio to decide horizontal vs. vertical text
        avg_wh_ratio = sum(item['w'] / item['h'] for item in items) / len(items)
        is_vertical = avg_wh_ratio < 1.0
        if not is_vertical:
            # --- Horizontal layout ---
            items.sort(key=lambda item: (item['y'], item['x']))
            lines = []
            if items:
                current_line = [items[0]]
                avg_height = sum(item['h'] for item in items) / len(items)
                for item in items[1:]:
                    if abs(item['y'] - current_line[-1]['y']) > avg_height * 0.7:
                        lines.append(current_line)
                        current_line = [item]
                    else:
                        current_line.append(item)
                lines.append(current_line)
            final_text = ""
            for line in lines:
                line.sort(key=lambda item: item['x'])
                final_text += " ".join([item['text'] for item in line]) + "\n"
            return final_text
        else:
            # --- Vertical layout ---
            items.sort(key=lambda item: (item['x'], item['y']))  # sort by x to build columns
            columns = []
            if items:
                current_column = [items[0]]
                avg_width = sum(item['w'] for item in items) / len(items)
                for item in items[1:]:
                    # Start a new column when the next box is too far from the current column in x
                    if abs(item['x'] - current_column[-1]['x']) > avg_width * 2.0:
                        columns.append(current_column)
                        current_column = [item]
                    else:
                        current_column.append(item)
                columns.append(current_column)
            # Order the columns from right to left
            columns.sort(key=lambda col: -col[0]['x'])
            final_text = ""
            max_rows = max(len(col) for col in columns) if columns else 0
            for i in range(max_rows):
                row_texts = []
                for col in columns:
                    row_texts.append(col[i]['text'] if i < len(col) else " ")
                final_text += " ".join(row_texts) + "\n"
            return final_text

    def run_ocr(self):
        img_np = self.get_image_as_numpy(self.current_image_input)
        if img_np is None:
            self.result_text.setText("错误:图像数据为空,请重新加载或截图!")
            return
        self.result_text.setText("正在使用高精度模型识别中...")
        QApplication.processEvents()
        try:
            predict_params = {
                'use_doc_orientation_classify': self.check_doc_orientation.isChecked(),
                'use_doc_unwarping': self.check_doc_unwarping.isChecked(),
                'use_textline_orientation': self.check_textline_orientation.isChecked(),
                'text_det_limit_type': 'min' if self.radio_limit_type_min.isChecked() else 'max',
                'text_det_limit_side_len': self.spin_limit_side_len.value(),
                'text_det_thresh': self.spin_det_thresh.value(),
                'text_det_box_thresh': self.spin_box_thresh.value(),
                'text_det_unclip_ratio': self.spin_unclip_ratio.value(),
                'text_rec_score_thresh': self.spin_rec_score_thresh.value()
            }
            results = self.ocr_engine.predict(img_np, **predict_params)
            if not results or not results[0]:
                self.result_text.setText("未识别到任何内容。")
                return
            result_obj = results[0]
            data_dict = result_obj.json
            formatted_text = self.reconstruct_layout(data_dict)
            self.result_text.setText(formatted_text)
            vis_image_dict = result_obj.img
            vis_pil_image = vis_image_dict.get('ocr_res_img')
            if vis_pil_image:
                self.display_image(vis_pil_image.toqpixmap())
        except Exception as e:
            self.result_text.setText(f"识别出错:\n{e}")
            logging.exception("OCR Error")

    def load_image(self, file_path=None):
        if not file_path:
            file_path, _ = QFileDialog.getOpenFileName(self, "选择图片文件", "", "图片文件 (*.png *.jpg *.jpeg *.bmp)")
        if file_path:
            self.current_image_input = file_path
            self.display_image(QPixmap(file_path))
            self.result_text.clear()

    def capture_fullscreen(self):
        self.hide()
        QTimer.singleShot(200, self._internal_capture_fullscreen)

    def _internal_capture_fullscreen(self):
        screen = QGuiApplication.primaryScreen()
        screenshot = screen.grabWindow(0)
        self.show()
        self.current_image_input = screenshot.toImage()
        self.display_image(screenshot)
        self.result_text.clear()
        self.run_ocr()

    def display_image(self, pixmap_obj):
        self.image_label.setPixmap(pixmap_obj)


# --- Global hotkey listener ---
class HotkeyListener(threading.Thread):
    def __init__(self, app_window):
        super().__init__()
        self.daemon = True
        self.app_window = app_window
        self.hotkeys = {
            '<ctrl>+<alt>+x': self.on_activate_screenshot
        }

    def on_activate_screenshot(self):
        print("快捷键 (Ctrl+Alt+X) 已触发!")
        self.app_window.trigger_screenshot_signal.emit()

    def run(self):
        with keyboard.GlobalHotKeys(self.hotkeys) as listener:
            listener.join()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = OCRApp()
    window.show()
    hotkey_thread = HotkeyListener(window)
    hotkey_thread.start()
    sys.exit(app.exec())
```
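
To check whether the offline model loading and `predict()` call behave the same outside the GUI, a stripped-down version of the pipeline can be run on a single image. This is only a minimal sketch: it reuses the exact constructor arguments from the script above, and `test.png` is a placeholder path for whatever test image is used.

```python
# Minimal, GUI-free sketch of the same OCR pipeline.
# Assumptions: the same local PP-OCRv5 server models are available offline,
# and 'test.png' is a placeholder for an actual test image.
import os

os.environ['PADDLE_PDX_MODEL_DOWNLOAD_SOURCE'] = 'LOCAL'  # same offline switch as in the full script

import numpy as np
from PIL import Image
from paddleocr import PaddleOCR

ocr = PaddleOCR(
    use_textline_orientation=True,
    lang='ch',
    device='gpu:0',
    ocr_version='PP-OCRv5',
    text_detection_model_name='PP-OCRv5_server_det',
    text_recognition_model_name='PP-OCRv5_server_rec',
)

# Load the image the same way get_image_as_numpy() does in the GUI script
img_np = np.array(Image.open('test.png').convert('RGB'))

results = ocr.predict(img_np)
for res in results:
    # res.json mirrors what reconstruct_layout() consumes: rec_polys / rec_texts
    print(res.json['res'].get('rec_texts'))
```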
