queue.py
# -*- coding: utf-8 -*-
"""
Redis queue manager - a Redis-backed persistent publish queue.
"""
import asyncio
import time
import uuid
import ujson
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Callable, Awaitable

import redis.asyncio as aioredis
from pydantic import BaseModel, Field

from app.config.settings import global_settings
from app.providers.logger import get_logger

logger = get_logger()


class TaskStatus(Enum):
    """Task status."""
    PENDING = "pending"
    QUEUED = "queued"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TaskType(Enum):
    """Task type."""
    IMAGE = "image"
    VIDEO = "video"


@dataclass
class PublishStrategy:
    """Publish strategy configuration."""
    min_interval: int = 30    # Minimum publish interval (seconds)
    max_concurrent: int = 1   # Maximum concurrent workers
    retry_count: int = 3      # Number of retries
    retry_delay: int = 60     # Retry delay (seconds)
    daily_limit: int = 50     # Daily publish limit
    hourly_limit: int = 10    # Hourly publish limit


class PublishTask(BaseModel):
    """Publish task."""
    task_id: str = Field(..., description="Task ID")
    platform: str = Field(..., description="Target platform")
    task_type: TaskType = Field(..., description="Task type")
    payload: Dict[str, Any] = Field(..., description="Publish payload")
    status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
    priority: int = Field(default=0, description="Priority; higher numbers mean higher priority")
    created_at: float = Field(default_factory=time.time, description="Creation time")
    queued_at: Optional[float] = Field(None, description="Enqueue time")
    started_at: Optional[float] = Field(None, description="Start time")
    completed_at: Optional[float] = Field(None, description="Completion time")
    retry_count: int = Field(default=0, description="Retry count")
    progress: int = Field(default=0, description="Progress percentage")
    message: str = Field(default="", description="Status message")
    error_detail: Optional[str] = Field(None, description="Error detail")
    result: Optional[Dict[str, Any]] = Field(None, description="Execution result")

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Override model_dump to ensure enum members are serialized as their values."""
        data = super().model_dump(**kwargs)
        # Convert enum members to their underlying values
        if isinstance(data.get('status'), TaskStatus):
            data['status'] = data['status'].value
        if isinstance(data.get('task_type'), TaskType):
            data['task_type'] = data['task_type'].value
        return data

    class Config:
        use_enum_values = True


def config_to_strategy(config) -> PublishStrategy:
    """Convert a config object into a strategy."""
    return PublishStrategy(
        min_interval=config.min_interval,
        max_concurrent=config.max_concurrent,
        retry_count=config.retry_count,
        retry_delay=config.retry_delay,
        daily_limit=config.daily_limit,
        hourly_limit=config.hourly_limit
    )


class RedisQueuerManager:
    """Manager for per-platform Redis queue clients."""
    _instances: Dict[str, aioredis.Redis] = {}

    @classmethod
    def get_queue_redis(cls, platform: str) -> aioredis.Redis:
        """Get the Redis client for the given platform."""
        if platform not in cls._instances:
            pool = aioredis.ConnectionPool(
                username=global_settings.redis.user,
                host=global_settings.redis.host,
                port=global_settings.redis.port,
                password=global_settings.redis.password,
                db=global_settings.redis.db + 1,  # Use a separate db to avoid key collisions
                decode_responses=True,
                max_connections=100,
            )
            redis_client = aioredis.Redis(connection_pool=pool)
            cls._instances[platform] = redis_client
        return cls._instances[platform]

    @classmethod
    async def close_all(cls):
        """Close all Redis clients."""
        for redis_instance in cls._instances.values():
            await redis_instance.close()
        cls._instances.clear()


class PlatformQueuer:
    """Queue processor for a single platform."""

    def __init__(self, platform: str, strategy: PublishStrategy):
        self.platform = platform
        self.strategy = strategy
        self.redis = RedisQueuerManager.get_queue_redis(platform)
        # Redis key prefix
        self.key_prefix = f"publish_queue:{platform}"
        self.queue_key = f"{self.key_prefix}:queue"
        self.pending_key = f"{self.key_prefix}:pending"
        self.processing_key = f"{self.key_prefix}:processing"
        self.tasks_key = f"{self.key_prefix}:tasks"
        self.stats_key = f"{self.key_prefix}:stats"
        self.history_key = f"{self.key_prefix}:history"
        # Runtime state
        self.is_running = False
        self.worker_tasks: List[asyncio.Task] = []
        self.executor: Optional[Callable] = None

    def set_executor(self, executor: Callable[[PublishTask], Awaitable[Dict[str, Any]]]):
        """Set the task executor."""
        self.executor = executor

    async def start(self) -> None:
        """Start the queue processor."""
        if self.is_running:
            return
        # Recover tasks that were in flight
        await self._recover_processing_tasks()
        self.is_running = True
        # Start worker loops
        for i in range(self.strategy.max_concurrent):
            task = asyncio.create_task(self._worker_loop(f"worker_{i}"))
            self.worker_tasks.append(task)
        logger.info(f"[Queuer] {self.platform} queue started, workers: {self.strategy.max_concurrent}")

    async def stop(self) -> None:
        """Stop the queue processor."""
        if not self.is_running:
            return
        self.is_running = False
        # Cancel all worker tasks
        for task in self.worker_tasks:
            task.cancel()
        if self.worker_tasks:
            await asyncio.gather(*self.worker_tasks, return_exceptions=True)
        self.worker_tasks.clear()
        logger.info(f"[Queuer] {self.platform} queue stopped")

    async def add_task(self, task: PublishTask) -> None:
        """Add a task to the publish queue."""
        # Check rate limits
        if not await self._check_rate_limits():
            raise RuntimeError("Publish rate limit exceeded")
        # Persist task data
        task.status = TaskStatus.QUEUED
        task.queued_at = time.time()
        task_data = task.model_dump()
        task_json = ujson.dumps(task_data)
        await self.redis.hset(self.tasks_key, task.task_id, task_json)
        # Add to the priority queue
        score = task.priority * 1000000 + (2147483647 - int(task.created_at))
        await self.redis.zadd(self.queue_key, {task.task_id: score})
        logger.info(f"[Queuer] {self.platform} task enqueued: {task.task_id}")

    async def add_pending(self, task: PublishTask) -> None:
        """Add a task to the pending-review list (not the publish queue)."""
        task.status = TaskStatus.PENDING
        task.queued_at = None
        task_data = task.model_dump()
        task_json = ujson.dumps(task_data)
        await self.redis.hset(self.tasks_key, task.task_id, task_json)
        # Sort by creation time, newest first
        await self.redis.zadd(self.pending_key, {task.task_id: task.created_at})
        logger.info(f"[Queuer] {self.platform} task awaiting review: {task.task_id}")

    async def get_task_status(self, task_id: str) -> Optional[PublishTask]:
        """Get task status."""
        task_data = await self.redis.hget(self.tasks_key, task_id)
        if not task_data:
            return None
        task_dict = ujson.loads(task_data)
        return PublishTask.model_validate(task_dict)

    async def get_stats(self) -> Dict[str, Any]:
        """Get queue statistics."""
        queue_size = await self.redis.zcard(self.queue_key)
        processing_count = await self.redis.scard(self.processing_key)
        pending_count = await self.redis.zcard(self.pending_key)
        # Publish history stats
        now = time.time()
        hour_ago = now - 3600
        day_ago = now - 86400
        hourly_count = await self.redis.zcount(self.history_key, hour_ago, now)
        daily_count = await self.redis.zcount(self.history_key, day_ago, now)
        # Last publish time
        last_publish = await self.redis.get(f"{self.stats_key}:last_publish")
        last_publish_time = float(last_publish) if last_publish else 0
        return {
            "platform": self.platform,
            "queue_size": queue_size,
            "processing_count": processing_count,
            "pending_count": pending_count,
            "daily_published": daily_count,
            "hourly_published": hourly_count,
            "last_publish_time": last_publish_time,
            "is_running": self.is_running,
            "worker_count": len(self.worker_tasks)
        }

    async def _worker_loop(self, worker_id: str) -> None:
        """Worker loop."""
        logger.info(f"[Queuer] {self.platform} worker {worker_id} started")
        while self.is_running:
            try:
                # Check the publish interval
                can_publish = await self._can_publish_now()
                if not can_publish:
                    logger.info(f"[Queuer] {self.platform} {worker_id} waiting for publish interval")
                    await asyncio.sleep(1)
                    continue
                # Fetch a task from the queue
                task_id = await self._pop_task()
                if not task_id:
                    # logger.info(f"[Queuer] {self.platform} {worker_id} queue empty, waiting...")
                    await asyncio.sleep(1)
                    continue
                logger.info(f"[Queuer] {self.platform} {worker_id} picked up task: {task_id}")
                # Process the task
                await self._process_task(task_id, worker_id)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.error(f"[Queuer] {self.platform} worker {worker_id} error: {exc}")
                await asyncio.sleep(5)
        logger.info(f"[Queuer] {self.platform} worker {worker_id} stopped")

    async def _pop_task(self) -> Optional[str]:
        """Atomically pop a task and move it to the processing set."""
        lua_script = """
        local task_id = redis.call('ZPOPMIN', KEYS[1])
        if next(task_id) then
            redis.call('SADD', KEYS[2], task_id[1])
            return task_id[1]
        end
        return nil
        """
        result = await self.redis.eval(lua_script, 2, self.queue_key, self.processing_key)
        return result

    async def _process_task(self, task_id: str, worker_id: str) -> None:
        """Process a single task."""
        task: Optional[PublishTask] = None
        try:
            # Fetch task data
            task_data = await self.redis.hget(self.tasks_key, task_id)
            if not task_data:
                await self.redis.srem(self.processing_key, task_id)
                return
            task_dict = ujson.loads(task_data)
            task = PublishTask.model_validate(task_dict)
            # Mark as processing
            task.status = TaskStatus.PROCESSING
            task.started_at = time.time()
            task.message = f"Worker {worker_id} started processing"
            await self._update_task(task)
            logger.info(f"[Queuer] {self.platform} {worker_id} processing task: {task_id}")
            # Run the publish executor
            if self.executor:
                result = await self.executor(task)
            else:
                raise RuntimeError("No task executor configured")
            # Mark as successful
            task.status = TaskStatus.SUCCESS
            task.completed_at = time.time()
            task.progress = 100
            task.result = result
            task.message = "Published successfully"
            await self._update_task(task)
            # Record publish history
            await self._record_publish(time.time())
            logger.info(f"[Queuer] {self.platform} {worker_id} task completed: {task_id}")
        except Exception as exc:
            if task is None:
                # Failed before the task could be loaded or validated; nothing to retry
                logger.error(f"[Queuer] {self.platform} failed to load task: {task_id}, error: {exc}")
                return
            # Processing failed; decide whether to retry
            task.retry_count += 1
            if task.retry_count <= self.strategy.retry_count:
                # Re-enqueue for retry
                task.status = TaskStatus.PENDING
                task.error_detail = str(exc)
                task.message = f"Retry attempt {task.retry_count}"
                await self._update_task(task)
                # Delay the retry
                retry_time = time.time() + self.strategy.retry_delay
                await self.redis.zadd(self.queue_key, {task_id: retry_time})
                logger.warning(f"[Queuer] {self.platform} task retry: {task_id}, attempt: {task.retry_count}")
            else:
                # Mark as failed
                task.status = TaskStatus.FAILED
                task.completed_at = time.time()
                task.error_detail = str(exc)
                task.message = f"Publish failed: {str(exc)}"
                await self._update_task(task)
                logger.error(f"[Queuer] {self.platform} task failed: {task_id}, error: {exc}")
        finally:
            # Remove from the processing set
            await self.redis.srem(self.processing_key, task_id)

    async def _update_task(self, task: PublishTask) -> None:
        """Persist updated task data."""
        task_data = task.model_dump()
        task_json = ujson.dumps(task_data)
        await self.redis.hset(self.tasks_key, task.task_id, task_json)

    async def list_pending(self, limit: int = 50, offset: int = 0) -> List[PublishTask]:
        """List pending-review tasks (newest first)."""
        # zrange is ascending by default; use zrevrange to get the newest first
        ids = await self.redis.zrevrange(self.pending_key, offset, offset + limit - 1)
        if not ids:
            return []
        if len(ids) == 1:
            data = [await self.redis.hget(self.tasks_key, ids[0])]
        else:
            data = await self.redis.hmget(self.tasks_key, *ids)
        tasks: List[PublishTask] = []
        for raw in data:
            if not raw:
                continue
            try:
                tasks.append(PublishTask.model_validate(ujson.loads(raw)))
            except Exception:
                continue
        return tasks

    async def approve(self, task_id: str) -> None:
        """Approve: move a task from pending review to the publish queue."""
        # Remove from the pending set first
        await self.redis.zrem(self.pending_key, task_id)
        # Load the task and enqueue it
        task_data = await self.redis.hget(self.tasks_key, task_id)
        if not task_data:
            raise ValueError("Task does not exist")
        task = PublishTask.model_validate(ujson.loads(task_data))
        await self.add_task(task)

    async def reject(self, task_id: str, reason: str | None = None) -> None:
        """Reject: mark the task as cancelled and remove it from pending review."""
        task_data = await self.redis.hget(self.tasks_key, task_id)
        if not task_data:
            return
        task = PublishTask.model_validate(ujson.loads(task_data))
        task.status = TaskStatus.CANCELLED
        task.completed_at = time.time()
        task.message = reason or "Rejected in review"
        await self._update_task(task)
        await self.redis.zrem(self.pending_key, task_id)

    async def list_tasks(self, limit: int = 50) -> List[PublishTask]:
        """List all tasks for this platform (newest first, limited)."""
        all_map = await self.redis.hgetall(self.tasks_key)
        items: List[PublishTask] = []
        for _, raw in all_map.items():
            try:
                items.append(PublishTask.model_validate(ujson.loads(raw)))
            except Exception:
                continue
        items.sort(key=lambda x: x.created_at or 0, reverse=True)
        return items[:limit]

    async def update_pending(self, task_id: str, changes: Dict[str, Any]) -> PublishTask:
        """Update the payload or extra fields of a pending-review task."""
        # Must be a task in the pending-review set
        in_pending = await self.redis.zscore(self.pending_key, task_id)
        if in_pending is None:
            raise ValueError("Task is not in the pending-review list")
        task_data = await self.redis.hget(self.tasks_key, task_id)
        if not task_data:
            raise ValueError("Task does not exist")
        task = PublishTask.model_validate(ujson.loads(task_data))
        # status is stored as a plain string because of use_enum_values
        if task.status != TaskStatus.PENDING.value:
            raise ValueError("Task is not pending review and cannot be edited")
        # Merge changes into the payload
        payload = task.payload or {}
        if not isinstance(payload, dict):
            payload = {}
        for k, v in (changes or {}).items():
            payload[k] = v
        task.payload = payload
        task.message = "Pending review (edited)"
        await self._update_task(task)
        return task

    async def update_queued(self, task_id: str, changes: Dict[str, Any]) -> PublishTask:
        """Update the payload of a queued task."""
        # Check that the task is in the queue
        in_queue = await self.redis.zscore(self.queue_key, task_id)
        if in_queue is None:
            raise ValueError("Task is not in the publish queue")
        task_data = await self.redis.hget(self.tasks_key, task_id)
        if not task_data:
            raise ValueError("Task does not exist")
        task = PublishTask.model_validate(ujson.loads(task_data))
        if task.status != TaskStatus.QUEUED.value:
            raise ValueError(f"Task is not queued and cannot be edited. Current status: {task.status}")
        # Merge changes into the payload
        payload = task.payload or {}
        if not isinstance(payload, dict):
            payload = {}
        for k, v in (changes or {}).items():
            payload[k] = v
        task.payload = payload
        task.message = "Queued (edited)"
        await self._update_task(task)
        logger.info(f"[Queuer] {self.platform} queued task updated: {task_id}")
        return task

    async def _record_publish(self, timestamp: float) -> None:
        """Record a publish event."""
        # Update the last publish time
        await self.redis.set(f"{self.stats_key}:last_publish", timestamp)
        # Add to publish history
        await self.redis.zadd(self.history_key, {str(uuid.uuid4()): timestamp})
        # Prune records older than one day
        day_ago = timestamp - 86400
        await self.redis.zremrangebyscore(self.history_key, 0, day_ago)

    async def _can_publish_now(self) -> bool:
        """Check the publish interval."""
        last_publish = await self.redis.get(f"{self.stats_key}:last_publish")
        if not last_publish:
            return True
        elapsed = time.time() - float(last_publish)
        return elapsed >= self.strategy.min_interval

    async def _check_rate_limits(self) -> bool:
        """Check rate limits."""
        now = time.time()
        # Check hourly limit
        hour_ago = now - 3600
        hourly_count = await self.redis.zcount(self.history_key, hour_ago, now)
        if hourly_count >= self.strategy.hourly_limit:
            return False
        # Check daily limit
        day_ago = now - 86400
        daily_count = await self.redis.zcount(self.history_key, day_ago, now)
        if daily_count >= self.strategy.daily_limit:
            return False
        return True

    async def _recover_processing_tasks(self) -> None:
        """Recover tasks that were left in the processing set."""
        processing_tasks = await self.redis.smembers(self.processing_key)
        if processing_tasks:
            logger.info(f"[Queuer] {self.platform} recovering {len(processing_tasks)} in-flight tasks")
            # Re-enqueue
            for task_id in processing_tasks:
                score = int(time.time())
                await self.redis.zadd(self.queue_key, {task_id: score})
            # Clear the processing set
            await self.redis.delete(self.processing_key)


class PublishQueue:
    """Publish queue manager - main entry point."""

    def __init__(self):
        self.platform_queuers: Dict[str, PlatformQueuer] = {}
        # Load strategies from configuration
        publish_config = global_settings.publish
        self.default_strategies: Dict[str, PublishStrategy] = {
            "xhs": config_to_strategy(publish_config.xhs),
        }

    def register_platform(
        self,
        platform: str,
        executor: Callable[[PublishTask], Awaitable[Dict[str, Any]]],
        strategy: Optional[PublishStrategy] = None
    ) -> None:
        """Register a platform publisher."""
        if strategy is None:
            strategy = self.default_strategies.get(platform, PublishStrategy())
        queuer = PlatformQueuer(platform, strategy)
        queuer.set_executor(executor)
        self.platform_queuers[platform] = queuer
        logger.info(f"[Queuer] registered platform: {platform}")

    async def start_all(self) -> None:
        """Start all queues."""
        for queuer in self.platform_queuers.values():
            await queuer.start()
        logger.info("[Queuer] all publish queues started")

    async def stop_all(self) -> None:
        """Stop all queues."""
        for queuer in self.platform_queuers.values():
            await queuer.stop()
        await RedisQueuerManager.close_all()
        logger.info("[Queuer] all publish queues stopped")

    async def submit_task(self, task: PublishTask) -> None:
        """Submit a publish task."""
        if task.platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {task.platform}")
        queuer = self.platform_queuers[task.platform]
        await queuer.add_task(task)

    async def submit_task_pending(self, task: PublishTask) -> None:
        """Submit a task for review (not enqueued immediately)."""
        if task.platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {task.platform}")
        queuer = self.platform_queuers[task.platform]
        await queuer.add_pending(task)

    async def get_task_status(self, task_id: str, platform: str) -> Optional[PublishTask]:
        """Get task status."""
        if platform not in self.platform_queuers:
            return None
        return await self.platform_queuers[platform].get_task_status(task_id)

    async def update_platform_strategy(self, platform: str, strategy: PublishStrategy) -> None:
        """Update a platform's publish strategy."""
        if platform not in self.platform_queuers:
            raise ValueError(f"Platform {platform} is not registered")
        # Stop the current queue
        queuer = self.platform_queuers[platform]
        executor = queuer.executor
        await queuer.stop()
        # Create a new queue processor
        new_queuer = PlatformQueuer(platform, strategy)
        new_queuer.set_executor(executor)
        self.platform_queuers[platform] = new_queuer
        # Start the new queue
        await new_queuer.start()
        logger.info(f"[Queuer] updated strategy for platform {platform}: {strategy}")

    def get_platform_strategy(self, platform: str) -> Optional[PublishStrategy]:
        """Get a platform's publish strategy."""
        if platform not in self.platform_queuers:
            return None
        return self.platform_queuers[platform].strategy

    async def get_all_stats(self) -> Dict[str, Any]:
        """Get statistics for all queues."""
        stats = {
            "total_platforms": len(self.platform_queuers),
            "platforms": {}
        }
        for platform, queuer in self.platform_queuers.items():
            platform_stats = await queuer.get_stats()
            # Include strategy info
            strategy = queuer.strategy
            platform_stats["strategy"] = {
                "min_interval": strategy.min_interval,
                "max_concurrent": strategy.max_concurrent,
                "retry_count": strategy.retry_count,
                "retry_delay": strategy.retry_delay,
                "daily_limit": strategy.daily_limit,
                "hourly_limit": strategy.hourly_limit
            }
            stats["platforms"][platform] = platform_stats
        return stats

    # ----- Review-related convenience methods -----
    async def list_pending_tasks(self, platform: str, limit: int = 50, offset: int = 0) -> List[PublishTask]:
        if platform not in self.platform_queuers:
            return []
        return await self.platform_queuers[platform].list_pending(limit=limit, offset=offset)

    async def approve_task(self, platform: str, task_id: str) -> None:
        if platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {platform}")
        await self.platform_queuers[platform].approve(task_id)

    async def reject_task(self, platform: str, task_id: str, reason: Optional[str] = None) -> None:
        if platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {platform}")
        await self.platform_queuers[platform].reject(task_id, reason)

    async def list_tasks(self, platform: str, limit: int = 50) -> List[PublishTask]:
        if platform not in self.platform_queuers:
            return []
        return await self.platform_queuers[platform].list_tasks(limit=limit)

    async def update_pending_task(self, platform: str, task_id: str, changes: Dict[str, Any]) -> PublishTask:
        if platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {platform}")
        return await self.platform_queuers[platform].update_pending(task_id, changes)

    async def update_queued_task(self, platform: str, task_id: str, changes: Dict[str, Any]) -> PublishTask:
        """Update a queued task."""
        if platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {platform}")
        return await self.platform_queuers[platform].update_queued(task_id, changes)

    async def delete_task(self, platform: str, task_id: str) -> None:
        """Delete a task (only completed or failed tasks can be deleted)."""
        if platform not in self.platform_queuers:
            raise ValueError(f"Unsupported platform: {platform}")
        queuer = self.platform_queuers[platform]
        # Fetch task status
        task = await queuer.get_task_status(task_id)
        if not task:
            raise ValueError("Task does not exist")
        # Only successful or failed tasks may be deleted
        if task.status not in [TaskStatus.SUCCESS.value, TaskStatus.FAILED.value]:
            raise ValueError(f"Only completed or failed tasks can be deleted; current status: {task.status}")
        # Remove task data from Redis
        await queuer.redis.hdel(queuer.tasks_key, task_id)
        # Make sure it is removed from every set
        await queuer.redis.zrem(queuer.queue_key, task_id)
        await queuer.redis.zrem(queuer.pending_key, task_id)
        await queuer.redis.srem(queuer.processing_key, task_id)
        logger.info(f"[Queuer] {platform} task deleted: {task_id}")
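
Usage sketch (not part of queue.py): a minimal example of registering a platform and submitting a task through PublishQueue. The xhs_publish executor and the app.queue import path are assumptions for illustration; the "xhs" strategy is loaded from global_settings.publish.xhs by PublishQueue.__init__ as shown above, and a reachable Redis matching global_settings.redis is assumed.

# Minimal usage sketch. Assumptions: the `xhs_publish` executor below is
# hypothetical, and the import path may differ in your project layout.
import asyncio
import uuid

from app.queue import PublishQueue, PublishTask, TaskType  # adjust to your package path


async def xhs_publish(task: PublishTask) -> dict:
    # Hypothetical executor: perform the actual platform publish here and
    # return a result dict, which the queue stores on task.result.
    return {"note_id": "demo", "title": task.payload.get("title")}


async def main() -> None:
    queue = PublishQueue()
    # Uses the strategy from global_settings.publish.xhs unless one is passed explicitly
    queue.register_platform("xhs", xhs_publish)
    await queue.start_all()

    task = PublishTask(
        task_id=str(uuid.uuid4()),
        platform="xhs",
        task_type=TaskType.IMAGE,
        payload={"title": "hello", "images": ["a.jpg"]},
    )
    # Enqueue directly; use submit_task_pending() + approve_task() instead
    # if the task should go through manual review first.
    await queue.submit_task(task)

    status = await queue.get_task_status(task.task_id, "xhs")
    print(status.status if status else "not found")

    await queue.stop_all()


if __name__ == "__main__":
    asyncio.run(main())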
