Skip to content

PP-OCRv5_server_rec 本地识别竖排就不行 网页STUDIO就行呢? #75283

@zaichao

Description

@zaichao

请提出你的问题 Please ask your question

import sys
import os
import logging
import threading
import time

在导入 paddleocr 之前设置环境变量,强制使用离线模式

os.environ['PADDLE_PDX_MODEL_DOWNLOAD_SOURCE'] = 'LOCAL'

from PySide6.QtWidgets import (QApplication, QMainWindow, QVBoxLayout, QHBoxLayout,
QPushButton, QWidget, QLabel, QTextEdit, QFileDialog,
QDialog, QGroupBox, QCheckBox, QRadioButton, QSpinBox,
QDoubleSpinBox, QFormLayout, QComboBox)
from PySide6.QtGui import QPixmap, QPainter, QPen, QGuiApplication, QScreen, QColor, QFont, QImage
from PySide6.QtCore import Qt, QRect, QTimer, Signal, Slot

导入 pynput 用于全局热键

try:
from pynput import keyboard
except ImportError:
print("错误: pynput 库未安装。请运行 'pip install pynput' 来安装。")
sys.exit(1)

from paddleocr import PaddleOCR
from PIL import Image, ImageQt
import numpy as np

屏蔽不必要的日志

logging.disable(logging.WARNING)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

--- 全新的专业截图窗口 ---

class ScreenshotWindow(QWidget):
screenshot_taken = Signal(QPixmap)

def __init__(self):
    super().__init__()
    self.setWindowFlags(Qt.WindowType.FramelessWindowHint | Qt.WindowType.WindowStaysOnTopHint | Qt.Tool)
    self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
    self.setCursor(Qt.CursorShape.CrossCursor)

    self.screen = QApplication.primaryScreen()
    self.setGeometry(self.screen.geometry())

    self.start_point = None
    self.end_point = None
    self.selection_rect = QRect()

def paintEvent(self, event):
    painter = QPainter(self)
    painter.fillRect(self.rect(), QColor(0, 0, 0, 100))  # 半透明黑色遮罩

    if not self.selection_rect.isNull():
        painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_Clear)
        painter.fillRect(self.selection_rect, Qt.GlobalColor.transparent)
        painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_SourceOver)
        pen = QPen(QColor(0, 150, 255), 2)
        painter.setPen(pen)
        painter.drawRect(self.selection_rect)

def mousePressEvent(self, event):
    self.start_point = event.position().toPoint()
    self.selection_rect = QRect(self.start_point, self.start_point)
    self.update()

def mouseMoveEvent(self, event):
    if self.start_point:
        self.end_point = event.position().toPoint()
        self.selection_rect = QRect(self.start_point, self.end_point).normalized()
        self.update()

def mouseReleaseEvent(self, event):
    if self.start_point and self.end_point and self.selection_rect.width() > 5 and self.selection_rect.height() > 5:
        self.hide()
        screenshot = self.screen.grabWindow(
            0,
            self.selection_rect.x(),
            self.selection_rect.y(),
            self.selection_rect.width(),
            self.selection_rect.height()
        )
        self.screenshot_taken.emit(screenshot)
    else:
        self.screenshot_taken.emit(QPixmap())
    self.close()

def keyPressEvent(self, event):
    if event.key() == Qt.Key_Escape:
        self.screenshot_taken.emit(QPixmap())
        self.close()

--- 主应用窗口 ---

class OCRApp(QMainWindow):
trigger_screenshot_signal = Signal()

def __init__(self):
    super().__init__()
    self.setWindowTitle("高精度OCR工具 (最终版)")
    self.setGeometry(100, 100, 1200, 800)
    self.ocr_engine = None
    self.current_image_input = None
    self.screenshot_window = None

    self.init_ui()
    self.setWindowTitle("高精度OCR工具 (正在后台加载模型...)")
    QTimer.singleShot(100, self.initialize_ocr_engine)
    self.trigger_screenshot_signal.connect(self.take_screenshot)

def initialize_ocr_engine(self):
    self.btn_run_ocr.setText("正在加载模型...")
    self.btn_run_ocr.setEnabled(False)
    QApplication.processEvents()

    try:
        lang_code = self.lang_combo.currentData()
        print(f"正在加载语言模型: {lang_code}...")

        self.ocr_engine = PaddleOCR(
            use_textline_orientation=True,
            lang=lang_code,
            device='gpu:0',
            ocr_version='PP-OCRv5',
            text_detection_model_name='PP-OCRv5_server_det',
            text_recognition_model_name='PP-OCRv5_server_rec'
        )
        print(f"语言模型 {lang_code} 初始化成功!")
        self.setWindowTitle("高精度OCR工具 (模型加载完成)")
        self.btn_run_ocr.setText("开始识别")
        self.btn_run_ocr.setEnabled(True)
    except Exception as e:
        self.result_text.setText(f"引擎初始化失败:\n{e}\n\n请确保已正确安装paddlepaddle-gpu版本。")
        self.setWindowTitle("高精度OCR工具 (引擎初始化失败!)")
        logging.exception("Engine Init Error")

def init_ui(self):
    central_widget = QWidget()
    self.setCentralWidget(central_widget)
    main_layout = QHBoxLayout(central_widget)

    left_panel = QWidget()
    left_layout = QVBoxLayout(left_panel)
    left_panel.setFixedWidth(320)

    self.btn_load = QPushButton("1. 加载本地图片")
    self.btn_fullscreen = QPushButton("2. 截取全屏")
    self.btn_select_area = QPushButton("3. 框选识别区域 (Ctrl+Alt+X)")
    self.btn_run_ocr = QPushButton("开始识别 (正在加载模型...)")
    self.btn_run_ocr.setEnabled(False)

    config_groupbox = QGroupBox("高级配置")
    form_layout = QFormLayout()

    self.lang_combo = QComboBox()
    self.lang_combo.addItem("自动选择 (中/英/繁/日)", "ch")
    self.lang_combo.addItem("繁体中文 (chinese_cht)", "chinese_cht")
    self.lang_combo.addItem("英文 (en)", "en")
    self.lang_combo.addItem("日文 (japan)", "japan")
    self.lang_combo.addItem("韩文 (korean)", "korean")
    self.lang_combo.currentIndexChanged.connect(self.initialize_ocr_engine)
    form_layout.addRow("选择语言:", self.lang_combo)

    self.check_doc_orientation = QCheckBox()
    self.check_doc_orientation.setChecked(False)
    form_layout.addRow("启用文档方向分类:", self.check_doc_orientation)

    self.check_doc_unwarping = QCheckBox()
    self.check_doc_unwarping.setChecked(False)
    form_layout.addRow("启用文档扭曲矫正:", self.check_doc_unwarping)

    self.check_textline_orientation = QCheckBox()
    self.check_textline_orientation.setChecked(True)
    form_layout.addRow("启用文本行方向分类:", self.check_textline_orientation)

    self.radio_limit_type_min = QRadioButton("短边")
    self.radio_limit_type_max = QRadioButton("长边")
    self.radio_limit_type_min.setChecked(True)
    limit_type_layout = QHBoxLayout()
    limit_type_layout.addWidget(self.radio_limit_type_min)
    limit_type_layout.addWidget(self.radio_limit_type_max)
    form_layout.addRow("图像边长限制类型:", limit_type_layout)

    self.spin_limit_side_len = QSpinBox()
    self.spin_limit_side_len.setRange(32, 4096)
    self.spin_limit_side_len.setValue(960)
    form_layout.addRow("图像边长限制:", self.spin_limit_side_len)

    self.spin_det_thresh = QDoubleSpinBox()
    self.spin_det_thresh.setRange(0.0, 1.0)
    self.spin_det_thresh.setSingleStep(0.05)
    self.spin_det_thresh.setValue(0.3)
    form_layout.addRow("文本检测像素阈值:", self.spin_det_thresh)

    self.spin_box_thresh = QDoubleSpinBox()
    self.spin_box_thresh.setRange(0.0, 1.0)
    self.spin_box_thresh.setSingleStep(0.05)
    self.spin_box_thresh.setValue(0.6)
    form_layout.addRow("文本检测框阈值:", self.spin_box_thresh)

    self.spin_unclip_ratio = QDoubleSpinBox()
    self.spin_unclip_ratio.setRange(1.0, 10.0)
    self.spin_unclip_ratio.setSingleStep(0.1)
    self.spin_unclip_ratio.setValue(2.0)
    form_layout.addRow("文本检测扩张系数:", self.spin_unclip_ratio)

    self.spin_rec_score_thresh = QDoubleSpinBox()
    self.spin_rec_score_thresh.setRange(0.0, 1.0)
    self.spin_rec_score_thresh.setSingleStep(0.05)
    self.spin_rec_score_thresh.setValue(0.0)
    form_layout.addRow("文本识别阈值:", self.spin_rec_score_thresh)

    config_groupbox.setLayout(form_layout)

    self.result_text = QTextEdit()
    self.result_text.setPlaceholderText("识别结果将显示在这里...")
    font = QFont("SimSun", 10)
    self.result_text.setFont(font)

    left_layout.addWidget(self.btn_load)
    left_layout.addWidget(self.btn_fullscreen)
    left_layout.addWidget(self.btn_select_area)
    left_layout.addWidget(config_groupbox)
    left_layout.addWidget(self.btn_run_ocr)
    left_layout.addWidget(QLabel("识别结果 (保留排版):"))
    left_layout.addWidget(self.result_text)

    self.image_label = QLabel("请加载图片或截图")
    self.image_label.setAlignment(Qt.AlignCenter)
    self.image_label.setStyleSheet("border: 2px dashed #aaa;")
    self.image_label.setScaledContents(True)

    main_layout.addWidget(left_panel)
    main_layout.addWidget(self.image_label, 1)

    self.btn_load.clicked.connect(self.load_image)
    self.btn_fullscreen.clicked.connect(self.capture_fullscreen)
    self.btn_select_area.clicked.connect(self.take_screenshot)
    self.btn_run_ocr.clicked.connect(self.run_ocr)

@Slot()
def take_screenshot(self):
    self.hide()
    time.sleep(0.2)

    self.screenshot_window = ScreenshotWindow()
    self.screenshot_window.screenshot_taken.connect(self.process_screenshot)
    self.screenshot_window.show()

@Slot(QPixmap)
def process_screenshot(self, pixmap: QPixmap):
    self.show()
    if not pixmap.isNull():
        self.current_image_input = pixmap.toImage()
        self.display_image(pixmap)
        self.result_text.setText("截图成功!点击“开始识别”进行处理。")
        self.run_ocr()
    else:
        self.result_text.setText("截图操作已取消。")

def get_image_as_numpy(self, image_input):
    if image_input is None: return None

    image_to_process = None
    if isinstance(image_input, str):
        image_to_process = Image.open(image_input)
    elif isinstance(image_input, Image.Image):
        image_to_process = image_input
    elif isinstance(image_input, QImage):
        image_to_process = ImageQt.fromqimage(image_input)
    elif isinstance(image_input, QPixmap):
        image_to_process = ImageQt.fromqpixmap(image_input)

    if image_to_process:
        return np.array(image_to_process.convert('RGB'))
    return None

def reconstruct_layout(self, data_dict):
    """
    【【【 全新智能排版算法 】】】
    根据文本框坐标智能判断横排/竖排并重构原始排版
    """
    if not (data_dict and 'res' in data_dict and data_dict['res'].get('rec_polys')):
        return ""

    boxes = data_dict['res']['rec_polys']
    texts = data_dict['res']['rec_texts']

    if not texts:
        return ""

    items = []
    for box, text in zip(boxes, texts):
        x_coords = [p[0] for p in box]
        y_coords = [p[1] for p in box]
        width = max(x_coords) - min(x_coords)
        height = max(y_coords) - min(y_coords)
        items.append({'text': text, 'x': min(x_coords), 'y': min(y_coords), 'w': width, 'h': height})

    if not items:
        return ""

    # 通过平均宽高比判断是横排还是竖排
    avg_wh_ratio = sum(item['w'] / item['h'] for item in items) / len(items)
    is_vertical = avg_wh_ratio < 1.0

    if not is_vertical:
        # --- 横排逻辑 ---
        items.sort(key=lambda item: (item['y'], item['x']))
        lines = []
        if items:
            current_line = [items[0]]
            avg_height = sum(item['h'] for item in items) / len(items)
            for item in items[1:]:
                if abs(item['y'] - current_line[-1]['y']) > avg_height * 0.7:
                    lines.append(current_line)
                    current_line = [item]
                else:
                    current_line.append(item)
            lines.append(current_line)

        final_text = ""
        for line in lines:
            line.sort(key=lambda item: item['x'])
            final_text += " ".join([item['text'] for item in line]) + "\n"
        return final_text
    else:
        # --- 竖排逻辑 ---
        items.sort(key=lambda item: (item['x'], item['y']))  # 按x坐标排序,得到列
        columns = []
        if items:
            current_column = [items[0]]
            avg_width = sum(item['w'] for item in items) / len(items)
            for item in items[1:]:
                # 如果下一个框的x坐标离当前列太远,则开启新列
                if abs(item['x'] - current_column[-1]['x']) > avg_width * 2.0:
                    columns.append(current_column)
                    current_column = [item]
                else:
                    current_column.append(item)
            columns.append(current_column)

        # 从右到左排列各列
        columns.sort(key=lambda col: -col[0]['x'])

        final_text = ""
        max_rows = max(len(col) for col in columns) if columns else 0
        for i in range(max_rows):
            row_texts = []
            for col in columns:
                row_texts.append(col[i]['text'] if i < len(col) else "  ")
            final_text += " ".join(row_texts) + "\n"
        return final_text

def run_ocr(self):
    img_np = self.get_image_as_numpy(self.current_image_input)
    if img_np is None:
        self.result_text.setText("错误:图像数据为空,请重新加载或截图!")
        return

    self.result_text.setText("正在使用高精度模型识别中...")
    QApplication.processEvents()

    try:
        predict_params = {
            'use_doc_orientation_classify': self.check_doc_orientation.isChecked(),
            'use_doc_unwarping': self.check_doc_unwarping.isChecked(),
            'use_textline_orientation': self.check_textline_orientation.isChecked(),
            'text_det_limit_type': 'min' if self.radio_limit_type_min.isChecked() else 'max',
            'text_det_limit_side_len': self.spin_limit_side_len.value(),
            'text_det_thresh': self.spin_det_thresh.value(),
            'text_det_box_thresh': self.spin_box_thresh.value(),
            'text_det_unclip_ratio': self.spin_unclip_ratio.value(),
            'text_rec_score_thresh': self.spin_rec_score_thresh.value()
        }
        results = self.ocr_engine.predict(img_np, **predict_params)

        if not results or not results[0]:
            self.result_text.setText("未识别到任何内容。")
            return

        result_obj = results[0]
        data_dict = result_obj.json

        formatted_text = self.reconstruct_layout(data_dict)
        self.result_text.setText(formatted_text)

        vis_image_dict = result_obj.img
        vis_pil_image = vis_image_dict.get('ocr_res_img')

        if vis_pil_image:
            self.display_image(vis_pil_image.toqpixmap())

    except Exception as e:
        self.result_text.setText(f"识别出错:\n{e}")
        logging.exception("OCR Error")

def load_image(self, file_path=None):
    if not file_path:
        file_path, _ = QFileDialog.getOpenFileName(self, "选择图片文件", "", "图片文件 (*.png *.jpg *.jpeg *.bmp)")
    if file_path:
        self.current_image_input = file_path
        self.display_image(QPixmap(file_path))
        self.result_text.clear()

def capture_fullscreen(self):
    self.hide()
    QTimer.singleShot(200, self._internal_capture_fullscreen)

def _internal_capture_fullscreen(self):
    screen = QGuiApplication.primaryScreen()
    screenshot = screen.grabWindow(0)
    self.show()
    self.current_image_input = screenshot.toImage()
    self.display_image(screenshot)
    self.result_text.clear()
    self.run_ocr()

def display_image(self, pixmap_obj):
    self.image_label.setPixmap(pixmap_obj)

--- 全局热键监听 ---

class HotkeyListener(threading.Thread):
def init(self, app_window):
super().init()
self.daemon = True
self.app_window = app_window
self.hotkeys = {
'++x': self.on_activate_screenshot
}

def on_activate_screenshot(self):
    print("快捷键 (Ctrl+Alt+X) 已触发!")
    self.app_window.trigger_screenshot_signal.emit()

def run(self):
    with keyboard.GlobalHotKeys(self.hotkeys) as listener:
        listener.join()

if name == "main":
app = QApplication(sys.argv)
window = OCRApp()
window.show()

hotkey_thread = HotkeyListener(window)
hotkey_thread.start()

sys.exit(app.exec())
Image

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions