Skip to main content
Glama
batch_ocr.py (13.7 kB)
#!/usr/bin/env python3
"""Batch OCR processing script with retry, batching, and a summary report.

Features:
- Process every image in a directory
- Automatically retry failed images
- Process in batches to avoid overloading the OCR service
- Generate a detailed processing report
- Resume support (skip images that were already processed)
"""

import sys
import json
import time
import argparse
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from datetime import datetime
import traceback

# Add project root to path so `scripts` and `ocr_mcp_service` are importable.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from scripts.common import setup_script


class BatchOCRProcessor:
    """Batch OCR processor supporting retries and batched processing."""

    def __init__(
        self,
        image_dir: Path,
        output_dir: Optional[Path] = None,
        engine: str = "paddleocr",
        batch_size: int = 2,
        max_retries: int = 3,
        retry_delay: float = 2.0,
        skip_existing: bool = True,
        lang: str = "ch"
    ):
        """Initialize the batch processor.

        Args:
            image_dir: Directory containing the images to process.
            output_dir: Output directory (default: image_dir/ocr_results).
            engine: OCR engine type.
            batch_size: Number of images processed per batch.
            max_retries: Maximum number of retries per image.
            retry_delay: Base delay between retries, in seconds.
            skip_existing: Skip images whose results already exist.
            lang: Language code (only used by paddleocr).
        """
        self.image_dir = Path(image_dir).resolve()
        self.output_dir = output_dir or (self.image_dir / "ocr_results")
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.engine = engine
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.skip_existing = skip_existing
        self.lang = lang

        # Aggregate statistics, serialized to the final JSON report.
        self.stats = {
            "total": 0,
            "success": 0,
            "failed": 0,
            "skipped": 0,
            "retries": 0,
            "start_time": None,
            "end_time": None,
            "errors": []
        }

        # Supported image extensions (lowercase; upper-case variants matched too).
        self.image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}

    def find_images(self) -> List[Path]:
        """Find all image files in the directory.

        Returns:
            Sorted, de-duplicated list of image paths.
        """
        # Collect into a set: on case-insensitive filesystems (macOS/Windows)
        # globbing both "*.jpg" and "*.JPG" returns the same file twice, which
        # would previously make each image be processed and counted twice.
        images = set()
        for ext in self.image_extensions:
            images.update(self.image_dir.glob(f"*{ext}"))
            images.update(self.image_dir.glob(f"*{ext.upper()}"))
        # Sort to keep the processing order deterministic.
        return sorted(images)

    def is_already_processed(self, image_path: Path) -> bool:
        """Return True if a JSON result for this image already exists."""
        if not self.skip_existing:
            return False
        json_file = self.output_dir / f"{image_path.stem}_ocr.json"
        return json_file.exists()

    def process_image(self, image_path: Path) -> Tuple[bool, Optional[Dict], Optional[str]]:
        """Process a single image, retrying on failure.

        Returns:
            (success, result_dict, error_message)
        """
        # Imported lazily so the heavy OCR dependencies load only when needed.
        from ocr_mcp_service.ocr_engine import OCREngineFactory
        from ocr_mcp_service.utils import validate_image

        for attempt in range(self.max_retries + 1):
            try:
                # Validate before handing the file to the engine.
                validate_image(str(image_path))

                engine = OCREngineFactory.get_engine(self.engine)

                # Only paddleocr accepts a language parameter.
                if self.engine == "paddleocr":
                    result = engine.recognize_image(str(image_path), lang=self.lang)
                else:
                    result = engine.recognize_image(str(image_path))

                return True, result.to_dict(), None

            except Exception as e:
                error_msg = str(e)
                error_type = type(e).__name__

                # Heuristic: connection-like failures get a longer back-off.
                is_connection_error = (
                    "Not connected" in error_msg
                    or "Connection" in error_msg
                    or "timeout" in error_msg.lower()
                    or error_type == "TimeoutError"
                )

                if attempt < self.max_retries:
                    # Linear back-off for connection errors, flat delay otherwise.
                    wait_time = self.retry_delay * (attempt + 1) if is_connection_error else self.retry_delay
                    print(f" ⚠️ 尝试 {attempt + 1}/{self.max_retries + 1} 失败: {error_msg}")
                    print(f" ⏳ 等待 {wait_time:.1f} 秒后重试...")
                    time.sleep(wait_time)
                    self.stats["retries"] += 1
                else:
                    # Final attempt failed; report the error to the caller.
                    return False, None, f"{error_type}: {error_msg}"

        # Defensive fallback; unreachable whenever max_retries >= 0.
        return False, None, "Max retries exceeded"

    def save_result(self, image_path: Path, result_dict: Dict):
        """Persist an OCR result as both JSON and plain text."""
        base_name = image_path.stem

        # Full structured result.
        json_file = self.output_dir / f"{base_name}_ocr.json"
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(result_dict, f, ensure_ascii=False, indent=2)

        # Recognized text only, for quick inspection.
        txt_file = self.output_dir / f"{base_name}_ocr.txt"
        with open(txt_file, "w", encoding="utf-8") as f:
            f.write(result_dict.get("text", ""))

    def process_batch(self, images: List[Path]) -> Dict:
        """Process one batch of images.

        Returns:
            Per-batch statistics dict (success/failed/skipped/errors).
        """
        batch_stats = {
            "success": 0,
            "failed": 0,
            "skipped": 0,
            "errors": []
        }

        for image_path in images:
            print(f"\n📷 处理: {image_path.name}")

            # Resume support: skip images that already have results.
            if self.is_already_processed(image_path):
                print(f" ⏭️ 跳过(已处理)")
                batch_stats["skipped"] += 1
                self.stats["skipped"] += 1
                continue

            success, result_dict, error_msg = self.process_image(image_path)

            if success:
                self.save_result(image_path, result_dict)

                text_length = len(result_dict.get("text", ""))
                boxes_count = len(result_dict.get("boxes", []))
                processing_time = result_dict.get("processing_time", 0.0)

                print(f" ✅ 成功: {text_length}字符, {boxes_count}个文本块, {processing_time:.2f}秒")
                batch_stats["success"] += 1
                self.stats["success"] += 1
            else:
                print(f" ❌ 失败: {error_msg}")
                batch_stats["failed"] += 1
                self.stats["failed"] += 1
                self.stats["errors"].append({
                    "image": image_path.name,
                    "error": error_msg
                })

        return batch_stats

    def process_all(self):
        """Process every image in the directory, batch by batch."""
        print("=" * 80)
        print("批量OCR处理")
        print("=" * 80)
        print(f"图片目录: {self.image_dir}")
        print(f"输出目录: {self.output_dir}")
        print(f"引擎: {self.engine}")
        print(f"批次大小: {self.batch_size}")
        print(f"最大重试: {self.max_retries}")
        print(f"跳过已处理: {self.skip_existing}")
        print("=" * 80)

        images = self.find_images()
        self.stats["total"] = len(images)
        self.stats["start_time"] = datetime.now().isoformat()

        if not images:
            print("❌ 未找到图片文件")
            return

        print(f"\n找到 {len(images)} 张图片")

        # Ceiling division: number of batches needed.
        total_batches = (len(images) + self.batch_size - 1) // self.batch_size

        for batch_num in range(total_batches):
            start_idx = batch_num * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(images))
            batch_images = images[start_idx:end_idx]

            print(f"\n{'=' * 80}")
            print(f"批次 {batch_num + 1}/{total_batches} ({len(batch_images)} 张图片)")
            print(f"{'=' * 80}")

            batch_stats = self.process_batch(batch_images)

            print(f"\n批次统计: ✅ {batch_stats['success']} 成功, "
                  f"❌ {batch_stats['failed']} 失败, "
                  f"⏭️ {batch_stats['skipped']} 跳过")

            # Pause between batches to avoid overloading the OCR service.
            if batch_num < total_batches - 1:
                wait_time = 1.0
                print(f"\n⏳ 批次间等待 {wait_time} 秒...")
                time.sleep(wait_time)

        self.stats["end_time"] = datetime.now().isoformat()

        self.generate_report()

    def generate_report(self):
        """Print a summary and save the full report as JSON."""
        print("\n" + "=" * 80)
        print("处理完成")
        print("=" * 80)

        # Total elapsed time (only if both timestamps were recorded).
        if self.stats["start_time"] and self.stats["end_time"]:
            start = datetime.fromisoformat(self.stats["start_time"])
            end = datetime.fromisoformat(self.stats["end_time"])
            duration = (end - start).total_seconds()
            print(f"总耗时: {duration:.1f} 秒")

        print(f"\n统计:")
        print(f" 总计: {self.stats['total']}")
        print(f" ✅ 成功: {self.stats['success']}")
        print(f" ❌ 失败: {self.stats['failed']}")
        print(f" ⏭️ 跳过: {self.stats['skipped']}")
        print(f" 🔄 重试: {self.stats['retries']}")

        if self.stats["errors"]:
            print(f"\n失败列表:")
            for error in self.stats["errors"]:
                print(f" - {error['image']}: {error['error']}")

        # Persist the machine-readable report next to the OCR results.
        report_file = self.output_dir / "batch_report.json"
        with open(report_file, "w", encoding="utf-8") as f:
            json.dump(self.stats, f, ensure_ascii=False, indent=2)

        print(f"\n📄 详细报告已保存: {report_file}")


def main():
    """CLI entry point: parse arguments and run the batch processor."""
    parser = argparse.ArgumentParser(
        description="批量OCR处理脚本,支持重试和分批处理",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 处理当前目录的所有图片,每批2张
  python scripts/batch_ocr.py .

  # 处理指定目录,每批3张,最多重试5次
  python scripts/batch_ocr.py /path/to/images --batch-size 3 --max-retries 5

  # 使用easyocr引擎,不跳过已处理的图片
  python scripts/batch_ocr.py . --engine easyocr --no-skip-existing
"""
    )
    parser.add_argument(
        "image_dir",
        type=str,
        help="图片目录路径"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="输出目录(默认:image_dir/ocr_results)"
    )
    parser.add_argument(
        "--engine",
        choices=["paddleocr", "paddleocr_mcp", "easyocr", "deepseek"],
        default="paddleocr",
        help="OCR引擎类型(默认:paddleocr)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="每批处理的图片数量(默认:2)"
    )
    parser.add_argument(
        "--max-retries",
        type=int,
        default=3,
        help="最大重试次数(默认:3)"
    )
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=2.0,
        help="重试延迟(秒,默认:2.0)"
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="不跳过已处理的图片"
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="ch",
        help="语言代码(默认:ch,仅paddleocr)"
    )

    args = parser.parse_args()

    # Validate the image directory before doing any work.
    image_dir = Path(args.image_dir).resolve()
    if not image_dir.exists():
        print(f"❌ 错误: 目录不存在: {image_dir}")
        sys.exit(1)
    if not image_dir.is_dir():
        print(f"❌ 错误: 不是目录: {image_dir}")
        sys.exit(1)

    processor = BatchOCRProcessor(
        image_dir=image_dir,
        output_dir=Path(args.output_dir).resolve() if args.output_dir else None,
        engine=args.engine,
        batch_size=args.batch_size,
        max_retries=args.max_retries,
        retry_delay=args.retry_delay,
        skip_existing=not args.no_skip_existing,
        lang=args.lang
    )

    # Run; on interruption or error, still emit a (partial) report.
    try:
        processor.process_all()
    except KeyboardInterrupt:
        print("\n\n⚠️ 用户中断")
        processor.generate_report()
        sys.exit(1)
    except Exception as e:
        print(f"\n\n❌ 发生错误: {e}")
        traceback.print_exc()
        processor.generate_report()
        sys.exit(1)


if __name__ == "__main__":
    main()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/qiao-925/ocr-mcp-service'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.