"""
进度显示工具模块
提供进度条、状态更新等用户体验优化功能
"""
import time
import threading
from typing import Optional, Dict, Any, Callable
from dataclasses import dataclass
from enum import Enum
import json
class TaskStatus(Enum):
    """Lifecycle state of a tracked task.

    ``ProgressTracker`` assigns PENDING on creation, RUNNING on start and one
    of the three terminal states (COMPLETED / FAILED / CANCELLED) at the end.
    The ``.value`` strings are what ``get_progress_summary`` exports.
    """

    PENDING = "pending"  # registered, not started yet
    RUNNING = "running"  # started, work in progress
    COMPLETED = "completed"  # terminal: finished successfully
    FAILED = "failed"  # terminal: finished with an error
    CANCELLED = "cancelled"  # terminal: stopped before completion
@dataclass
class TaskInfo:
    """Mutable record describing one tracked task.

    ``ProgressTracker`` creates these and updates them in place; its getters
    hand out the stored objects themselves rather than copies.

    Attributes:
        id: Key under which the task is registered.
        name: Human-readable label.
        status: Current lifecycle state.
        progress: Completion percentage (``ProgressTracker.update_progress``
            clamps it to [0, 100]).
        message: Latest status or error message.
        start_time: ``time.time()`` timestamp when the task started, or None.
        end_time: ``time.time()`` timestamp when the task finished, or None.
        metadata: Free-form extra data; ``None`` is replaced by a fresh dict.
    """

    id: str
    name: str
    status: TaskStatus
    progress: float = 0.0
    message: str = ""
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        # dataclasses reject a mutable ``{}`` default, so normalise None here;
        # every instance gets its own dict.
        if self.metadata is None:
            self.metadata = {}

    @property
    def elapsed_time(self) -> float:
        """Return seconds from start to end (or to now if unfinished); 0.0 if never started."""
        # Explicit None checks: 0.0 is a valid timestamp but is falsy.
        if self.start_time is None:
            return 0.0
        end = time.time() if self.end_time is None else self.end_time
        return end - self.start_time

    @property
    def is_finished(self) -> bool:
        """Return True once the task is COMPLETED, FAILED or CANCELLED."""
        return self.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)
class ProgressTracker:
    """Thread-safe registry of ``TaskInfo`` records with one optional callback per task.

    All reads and writes of ``tasks`` / ``callbacks`` happen under ``lock``, a
    non-reentrant ``threading.Lock``. Callbacks are invoked only after that
    lock has been released, so a callback may call back into the tracker
    (e.g. ``get_task``) without deadlocking. Callback invocations are still
    serialized against each other by a separate re-entrant lock.
    """

    def __init__(self) -> None:
        self.tasks: Dict[str, TaskInfo] = {}
        self.callbacks: Dict[str, Callable[[TaskInfo], None]] = {}
        self.lock = threading.Lock()
        # Serializes callback execution without holding ``lock``.
        # Always acquired before ``lock``, never while holding it.
        self._callback_lock = threading.RLock()

    def create_task(self, task_id: str, name: str, metadata: Optional[Dict[str, Any]] = None) -> TaskInfo:
        """Register a new PENDING task.

        Args:
            task_id: Unique key for the task.
            name: Human-readable label.
            metadata: Optional extra data; ``None`` becomes an empty dict.

        Returns:
            The newly created ``TaskInfo`` (the object stored in the tracker).

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        with self.lock:
            if task_id in self.tasks:
                raise ValueError(f"Task {task_id} already exists")
            task = TaskInfo(
                id=task_id,
                name=name,
                status=TaskStatus.PENDING,
                metadata=metadata or {},
            )
            self.tasks[task_id] = task
            return task

    def start_task(self, task_id: str) -> None:
        """Mark the task RUNNING, stamp ``start_time``, then notify its callback.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        with self.lock:
            task = self._require_task(task_id)
            task.status = TaskStatus.RUNNING
            task.start_time = time.time()
        self._notify_callbacks(task_id)

    def update_progress(self, task_id: str, progress: float, message: str = "") -> None:
        """Set the task's progress (clamped to [0, 100]) and optionally its message.

        An empty ``message`` keeps the previously stored message.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        with self.lock:
            task = self._require_task(task_id)
            task.progress = max(0.0, min(100.0, progress))
            if message:
                task.message = message
        self._notify_callbacks(task_id)

    def complete_task(self, task_id: str, message: str = "完成") -> None:
        """Mark the task COMPLETED at 100% and stamp ``end_time``.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        with self.lock:
            task = self._require_task(task_id)
            task.status = TaskStatus.COMPLETED
            task.progress = 100.0
            task.message = message
            task.end_time = time.time()
        self._notify_callbacks(task_id)

    def fail_task(self, task_id: str, error_message: str) -> None:
        """Mark the task FAILED with ``error_message`` and stamp ``end_time``.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        with self.lock:
            task = self._require_task(task_id)
            task.status = TaskStatus.FAILED
            task.message = error_message
            task.end_time = time.time()
        self._notify_callbacks(task_id)

    def cancel_task(self, task_id: str, message: str = "已取消") -> None:
        """Mark the task CANCELLED with ``message`` and stamp ``end_time``.

        Raises:
            ValueError: If ``task_id`` is unknown.
        """
        with self.lock:
            task = self._require_task(task_id)
            task.status = TaskStatus.CANCELLED
            task.message = message
            task.end_time = time.time()
        self._notify_callbacks(task_id)

    def register_callback(self, task_id: str, callback: Callable[[TaskInfo], None]) -> None:
        """Set the callback for ``task_id``, replacing any previous one."""
        with self.lock:
            self.callbacks[task_id] = callback

    def unregister_callback(self, task_id: str) -> None:
        """Remove the callback for ``task_id``; no-op if none is registered."""
        with self.lock:
            self.callbacks.pop(task_id, None)

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the stored task for ``task_id`` or None."""
        with self.lock:
            return self.tasks.get(task_id)

    def get_all_tasks(self) -> Dict[str, TaskInfo]:
        """Return a shallow copy of the id -> task mapping (task objects are shared)."""
        with self.lock:
            return self.tasks.copy()

    def clear_completed_tasks(self, older_than: float = 3600.0) -> int:
        """Drop finished tasks (and their callbacks) that ended more than ``older_than`` seconds ago.

        Returns:
            The number of tasks removed.
        """
        current_time = time.time()
        with self.lock:
            to_remove = [
                task_id
                for task_id, task in self.tasks.items()
                if task.is_finished
                and task.end_time is not None
                and (current_time - task.end_time) > older_than
            ]
            for task_id in to_remove:
                del self.tasks[task_id]
                self.callbacks.pop(task_id, None)
        return len(to_remove)

    def _require_task(self, task_id: str) -> TaskInfo:
        """Return the task for ``task_id`` or raise ValueError; caller must hold ``lock``."""
        task = self.tasks.get(task_id)
        if task is None:
            raise ValueError(f"Task {task_id} not found")
        return task

    def _notify_callbacks(self, task_id: str) -> None:
        """Invoke the callback registered for ``task_id``, if any.

        Must be called WITHOUT holding ``lock``: the callback may re-enter the
        tracker. Errors raised by the callback are reported and swallowed so a
        faulty display callback cannot abort the tracked work.
        """
        with self._callback_lock:
            with self.lock:
                task = self.tasks.get(task_id)
                callback = self.callbacks.get(task_id)
            if task is None or callback is None:
                return
            try:
                callback(task)
            except Exception as e:  # best-effort boundary around user callbacks
                print(f"Progress callback error for task {task_id}: {e}")
class ProgressBar:
    """Render a task as a single-line text progress bar and redraw it in place."""

    def __init__(self, width: int = 50, show_percentage: bool = True, show_time: bool = True):
        self.width = width
        self.show_percentage = show_percentage
        self.show_time = show_time
        # Display width of the line last drawn by print_progress; used to blank
        # out leftovers when the next line is shorter.
        self._last_width = 0

    def render(self, task: TaskInfo) -> str:
        """Return ``"<name> [<bar>] <pct>% - <message> (<time>)"`` for ``task``.

        The percentage, message and time parts are included only when enabled
        / non-empty. While the task is unfinished and has progress, the time
        part is a linear remaining-time estimate; otherwise it is elapsed time.
        """
        # Clamp only for drawing so out-of-range progress never widens the bar;
        # the percentage text still shows the raw value.
        clamped = max(0.0, min(100.0, task.progress))
        filled_length = int(self.width * clamped / 100.0)
        bar = "█" * filled_length + "░" * (self.width - filled_length)

        parts = [f"{task.name}", f"[{bar}]"]
        if self.show_percentage:
            parts.append(f"{task.progress:.1f}%")
        if task.message:
            parts.append(f"- {task.message}")
        # ``is not None``: 0.0 is a valid timestamp.
        if self.show_time and task.start_time is not None:
            elapsed = task.elapsed_time
            if not task.is_finished and clamped > 0:
                # Linear extrapolation of total duration from the fraction done.
                total_estimated = elapsed / (clamped / 100.0)
                remaining = total_estimated - elapsed
                parts.append(f"(剩余: {remaining:.1f}s)")
            else:
                parts.append(f"(用时: {elapsed:.1f}s)")
        return " ".join(parts)

    def print_progress(self, task: TaskInfo) -> None:
        """Redraw the bar on the current terminal line; end the line once the task finishes."""
        line = self.render(task)
        width = self._display_width(line)
        # "\r" only moves the cursor back, so pad with spaces to overwrite any
        # tail left by a longer previous line.
        padding = " " * max(0, self._last_width - width)
        print("\r" + line + padding, end="", flush=True)
        self._last_width = width
        if task.is_finished:
            print()  # newline
            self._last_width = 0

    @staticmethod
    def _display_width(text: str) -> int:
        """Approximate terminal column width of ``text`` (wide/fullwidth chars count as 2)."""
        import unicodedata

        return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in text)
# Module-wide tracker instance; the convenience functions in this module
# (create_task, start_task, update_progress, ...) all delegate to it.
global_tracker: ProgressTracker = ProgressTracker()
def create_task(task_id: str, name: str, metadata: Optional[Dict[str, Any]] = None) -> TaskInfo:
    """Register a new PENDING task on ``global_tracker`` and return it.

    Args:
        task_id: Unique key for the task.
        name: Human-readable label.
        metadata: Optional extra data; ``None`` becomes an empty dict.

    Raises:
        ValueError: If ``task_id`` is already registered.
    """
    return global_tracker.create_task(task_id, name, metadata)
def start_task(task_id: str) -> None:
    """Mark ``task_id`` as RUNNING on ``global_tracker`` (raises ValueError if unknown)."""
    tracker = global_tracker
    tracker.start_task(task_id)
def update_progress(task_id: str, progress: float, message: str = "") -> None:
    """Report new progress for ``task_id`` on ``global_tracker``.

    ``progress`` is a percentage that the tracker clamps to [0, 100]; an empty
    ``message`` leaves the previously stored message unchanged.
    """
    tracker = global_tracker
    tracker.update_progress(task_id, progress, message)
def complete_task(task_id: str, message: str = "完成") -> None:
    """Mark ``task_id`` COMPLETED (progress 100%) on ``global_tracker`` with ``message``."""
    tracker = global_tracker
    tracker.complete_task(task_id, message)
def fail_task(task_id: str, error_message: str) -> None:
    """Mark ``task_id`` FAILED on ``global_tracker``, storing ``error_message`` as its message."""
    tracker = global_tracker
    tracker.fail_task(task_id, error_message)
def progress_callback(task_id: str, callback: Callable[[TaskInfo], None]) -> None:
    """Attach ``callback`` to ``task_id`` on ``global_tracker``, replacing any existing one.

    The tracker calls it with the task object after each state change.
    """
    tracker = global_tracker
    tracker.register_callback(task_id, callback)
class ProgressContext:
    """Context manager that tracks the enclosed block as a task on ``global_tracker``.

    On entry the task is created and started. If ``show_progress`` is set, a
    ``ProgressBar`` is redrawn on every state change. On exit the task is
    completed, or failed with the exception text. Exceptions are never
    suppressed.

    A finished task with the same id is replaced on entry, so the same id
    (e.g. the fixed id used by ``task_progress``) can be run repeatedly. A
    still-active task with that id makes entry raise ValueError.
    """

    def __init__(self, task_id: str, name: str, show_progress: bool = True):
        self.task_id = task_id
        self.name = name
        self.show_progress = show_progress
        self.progress_bar = ProgressBar() if show_progress else None

    def __enter__(self):
        self._discard_finished_task()
        create_task(self.task_id, self.name)
        if self.show_progress:
            # Register before starting so the initial RUNNING state is drawn too.
            progress_callback(self.task_id, self._on_progress_update)
        start_task(self.task_id)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                complete_task(self.task_id)
            else:
                # Some exceptions carry no message; fall back to the type name.
                fail_task(self.task_id, str(exc_val) or exc_type.__name__)
        finally:
            if self.show_progress:
                # Drop only the callback this context registered.
                global_tracker.unregister_callback(self.task_id)
        # Implicitly returns None, so any exception propagates.

    def update(self, progress: float, message: str = "") -> None:
        """Report progress (percentage) and an optional message for this task."""
        update_progress(self.task_id, progress, message)

    def _on_progress_update(self, task: TaskInfo) -> None:
        """Tracker callback: redraw the progress bar."""
        if self.progress_bar:
            self.progress_bar.print_progress(task)

    def _discard_finished_task(self) -> None:
        """Remove a finished task (and its callback) that reuses ``self.task_id``."""
        with global_tracker.lock:
            existing = global_tracker.tasks.get(self.task_id)
            if existing is not None and existing.is_finished:
                del global_tracker.tasks[self.task_id]
                global_tracker.callbacks.pop(self.task_id, None)
def task_progress(
    task_id: str, name: str, show_progress: bool = True
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory: run each call of the decorated function inside a ``ProgressContext``.

    Args:
        task_id: Task id used for every call.
        name: Human-readable task label.
        show_progress: Whether to draw a progress bar.

    Returns:
        A decorator; the wrapped function keeps the original's name and
        docstring and returns whatever the original returns.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)  # preserve __name__, __doc__, __wrapped__
        def wrapper(*args, **kwargs):
            with ProgressContext(task_id, name, show_progress):
                return func(*args, **kwargs)

        return wrapper

    return decorator
def get_progress_summary() -> Dict[str, Any]:
    """Return a JSON-friendly snapshot of every task on ``global_tracker``.

    The result has ``total_tasks``, one count per status (``running``,
    ``completed``, ``failed``, ``pending``, ``cancelled``) and ``tasks``, which
    maps each task id to its name, status value, progress, message, elapsed
    time and finished flag.
    """
    tasks = global_tracker.get_all_tasks()

    # Single pass over the tasks; every TaskStatus gets a bucket, so the
    # per-status counts always add up to total_tasks.
    counts = {status: 0 for status in TaskStatus}
    for task in tasks.values():
        counts[task.status] += 1

    return {
        "total_tasks": len(tasks),
        "running": counts[TaskStatus.RUNNING],
        "completed": counts[TaskStatus.COMPLETED],
        "failed": counts[TaskStatus.FAILED],
        "pending": counts[TaskStatus.PENDING],
        "cancelled": counts[TaskStatus.CANCELLED],
        "tasks": {
            task_id: {
                "name": task.name,
                "status": task.status.value,
                "progress": task.progress,
                "message": task.message,
                "elapsed_time": task.elapsed_time,
                "is_finished": task.is_finished,
            }
            for task_id, task in tasks.items()
        },
    }
def export_progress_json(filepath: str) -> None:
    """Write the current ``get_progress_summary()`` snapshot to ``filepath`` as JSON.

    The file is UTF-8 encoded, non-ASCII text is kept verbatim
    (``ensure_ascii=False``), and the output is indented by 2 spaces.
    """
    payload = get_progress_summary()
    with open(filepath, mode="w", encoding="utf-8") as out:
        json.dump(payload, out, indent=2, ensure_ascii=False)