"""
下载管理模块
实现异步图片下载、超时控制、重试机制
"""
import asyncio
import aiohttp
import logging
from typing import Optional, Dict, Any
from pathlib import Path
from urllib.parse import urlparse
import time
logger = logging.getLogger(__name__)
class DownloadError(Exception):
"""下载错误异常"""
pass
class DownloadManager:
"""异步下载管理器"""
def __init__(
self,
timeout: int = 30,
max_retries: int = 3,
retry_delay: float = 1.0,
max_file_size: int = 50 * 1024 * 1024 # 50MB
):
"""
初始化下载管理器
Args:
timeout: 下载超时时间(秒)
max_retries: 最大重试次数
retry_delay: 重试延迟时间(秒)
max_file_size: 最大文件大小(字节)
"""
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.max_file_size = max_file_size
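    # Hypothetical configuration sketch (values chosen purely for illustration):
    #
    #     manager = DownloadManager(
    #         timeout=60,
    #         max_retries=5,
    #         retry_delay=0.5,
    #         max_file_size=100 * 1024 * 1024,
    #     )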
async def download_image(
self,
url: str,
save_path: Path,
headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""
异步下载图片
Args:
url: 图片URL
save_path: 保存路径
headers: 请求头
Returns:
下载结果信息
Raises:
DownloadError: 下载失败时抛出
"""
if headers is None:
headers = {
'User-Agent': 'Seedream-MCP/1.0',
'Accept': 'image/*'
}
start_time = time.time()
last_error = None
for attempt in range(self.max_retries + 1):
try:
logger.info(f"开始下载图片 (尝试 {attempt + 1}/{self.max_retries + 1}): {url}")
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.get(url, headers=headers) as response:
                        # Check the response status
                        if response.status != 200:
                            raise DownloadError(f"HTTP error: {response.status}")
                        # Check the content type
                        content_type = response.headers.get('content-type', '')
                        if not content_type.startswith('image/'):
                            logger.warning(f"Content type may not be an image: {content_type}")
                        # Check the declared file size
                        content_length = response.headers.get('content-length')
                        if content_length and int(content_length) > self.max_file_size:
                            raise DownloadError(f"File too large: {content_length} bytes")
                        # Ensure the target directory exists
                        save_path.parent.mkdir(parents=True, exist_ok=True)
                        # Stream the response body to disk
total_size = 0
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
total_size += len(chunk)
if total_size > self.max_file_size:
raise DownloadError(f"文件过大: {total_size} 字节")
f.write(chunk)
download_time = time.time() - start_time
result = {
'success': True,
'file_path': str(save_path),
'file_size': total_size,
'download_time': download_time,
'content_type': content_type,
'attempts': attempt + 1
}
logger.info(f"图片下载成功: {save_path} ({total_size} 字节, {download_time:.2f}秒)")
return result
            except asyncio.TimeoutError as e:
                last_error = DownloadError(f"Download timed out: {e}")
                logger.warning(f"Download timed out (attempt {attempt + 1}): {url}")
            except aiohttp.ClientError as e:
                last_error = DownloadError(f"Network error: {e}")
                logger.warning(f"Network error (attempt {attempt + 1}): {e}")
            except OSError as e:
                last_error = DownloadError(f"Filesystem error: {e}")
                logger.warning(f"Filesystem error (attempt {attempt + 1}): {e}")
            except Exception as e:
                last_error = DownloadError(f"Unexpected error: {e}")
                logger.warning(f"Download failed (attempt {attempt + 1}): {e}")
            # If this was not the last attempt, wait before retrying
            if attempt < self.max_retries:
                await asyncio.sleep(self.retry_delay * (attempt + 1))
        # All retries failed
        logger.error(f"Image download failed after {self.max_retries} retries: {url}")
        raise last_error or DownloadError("Download failed")
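    # Hypothetical usage sketch for download_image; the URL and save path are
    # placeholders, and the call assumes a running asyncio event loop:
    #
    #     manager = DownloadManager(timeout=10, max_retries=2)
    #     result = await manager.download_image(
    #         "https://example.com/sample.png",
    #         Path("downloads/sample.png"),
    #     )
    #     print(result["file_size"], result["attempts"])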
async def download_multiple_images(
self,
urls_and_paths: list[tuple[str, Path]],
headers: Optional[Dict[str, str]] = None,
max_concurrent: int = 5
) -> list[Dict[str, Any]]:
"""
并发下载多个图片
Args:
urls_and_paths: URL和保存路径的元组列表
headers: 请求头
max_concurrent: 最大并发数
Returns:
下载结果列表
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def download_with_semaphore(url: str, path: Path) -> Dict[str, Any]:
async with semaphore:
try:
return await self.download_image(url, path, headers)
except Exception as e:
logger.error(f"下载失败: {url} -> {e}")
return {
'success': False,
'url': url,
'file_path': str(path),
'error': str(e)
}
tasks = [
download_with_semaphore(url, path)
for url, path in urls_and_paths
]
results = await asyncio.gather(*tasks, return_exceptions=True)
        # Normalize results that came back as exceptions
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
url, path = urls_and_paths[i]
processed_results.append({
'success': False,
'url': url,
'file_path': str(path),
'error': str(result)
})
else:
processed_results.append(result)
return processed_results
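    # Hypothetical usage sketch for download_multiple_images; the URLs and
    # paths are placeholders. Failed items are returned as dicts with
    # 'success': False instead of raising:
    #
    #     batch = [
    #         ("https://example.com/a.png", Path("downloads/a.png")),
    #         ("https://example.com/b.png", Path("downloads/b.png")),
    #     ]
    #     results = await manager.download_multiple_images(batch, max_concurrent=2)
    #     succeeded = [r for r in results if r.get("success")]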
    def validate_url(self, url: str) -> bool:
        """
        Validate a URL's format.

        Args:
            url: URL to validate.

        Returns:
            True if the URL has both a scheme and a network location.
        """
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except Exception:
            return False
    def get_file_extension_from_url(self, url: str) -> str:
        """
        Derive a file extension from a URL.

        Args:
            url: Image URL.

        Returns:
            File extension including the leading dot.
        """
        try:
            suffix = Path(urlparse(url).path).suffix.lower()
            if suffix:
                return suffix
            return '.jpg'  # Default extension when the URL path has no suffix
        except Exception:
            return '.jpg'
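

# Minimal end-to-end sketch, assuming the module is run directly. The URL below
# is a placeholder and will only download successfully against a real image
# endpoint; it illustrates the intended call sequence rather than a guaranteed
# working example.
async def _demo() -> None:
    manager = DownloadManager(timeout=15, max_retries=2)
    url = "https://example.com/sample.jpg"  # placeholder URL
    if not manager.validate_url(url):
        raise SystemExit(f"Invalid URL: {url}")
    extension = manager.get_file_extension_from_url(url)
    result = await manager.download_image(url, Path("downloads") / f"sample{extension}")
    logger.info("Demo download result: %s", result)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(_demo())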