"""
文件下载工具模块
提供MCP工具接口用于下载网络资源和RustFS文件。
"""
import os
import urllib.parse
from pathlib import Path
from typing import Any, Callable, Dict

import aiofiles
import httpx

from .config import config
from .rustfs_client import RustFSClient
def _safe_basename(path: str, default: str = "downloaded_file") -> str:
    """Return the last segment of *path*, or *default* if it is unusable.

    Both ``/`` and ``\\`` are treated as separators, and ``""``, ``.`` and
    ``..`` are rejected, so a name taken from a URL or object key cannot
    carry directory components into the local path it is joined onto.
    """
    name = path.replace("\\", "/").rsplit("/", 1)[-1]
    return default if name in ("", ".", "..") else name


def _extract_filename_from_url(url: str, default: str = "downloaded_file") -> str:
    """Extract a local filename from the path component of *url*.

    The path is percent-decoded *before* the last segment is taken, so an
    encoded separator such as ``%2F`` cannot inject directory components.
    The query string and fragment are ignored.

    Args:
        url: File URL.
        default: Name returned when the URL path yields no usable filename
            (empty path, trailing slash, ``.`` or ``..``).

    Returns:
        The extracted filename, or *default*.
    """
    path = urllib.parse.unquote(urllib.parse.urlparse(url).path)
    return _safe_basename(path, default)


def create_download_tool() -> Callable[[str, str], Dict[str, Any]]:
    """Build the file-download tool function.

    ``config.timeout`` is read and a single ``RustFSClient`` is constructed
    once, here; both are captured by the returned closure and reused by
    every call to it.
    """
    timeout = config.timeout
    rustfs_client = RustFSClient()

    def download_file(url: str, download_path: str) -> Dict[str, Any]:
        """
        Download a file from a URL to a local path (plain URLs and RustFS URLs).

        Args:
            url: URL of the file to download. HTTP/HTTPS URLs and URLs
                recognised by ``RustFSClient.parse_rustfs_url`` are supported.
            download_path: Local destination. If it is an existing directory,
                or ends with a path separator (the directory is then created),
                the file is saved inside it under a name derived from the URL
                or object key. Otherwise it is used as the full file path and
                missing parent directories are created.

        Returns:
            A dict describing the result, including:
            - success: whether the download succeeded
            - url: the original download URL
            - file_path: full path of the saved file
            - filename: file name
            - size: file size in bytes
            - content_type: MIME type of the file
            - download_method: "http" for HTTP/HTTPS downloads; for RustFS,
              the value reported by the RustFS client (documented as "boto3_s3")
            - bucket: bucket name (RustFS downloads)
            - file_key: object key (RustFS downloads)
            - message: human-readable status message
            For RustFS downloads, fields other than url/file_path/message
            come from ``RustFSClient.download_file``'s result.

        Raises:
            ValueError: the URL or download path is invalid.
            FileNotFoundError: the file does not exist (RustFS).
            RuntimeError: the download failed.
        """
        try:
            # Validate inputs; whitespace-only values count as empty.
            if not isinstance(url, str) or not url.strip():
                raise ValueError("URL不能为空且必须是字符串")
            if not isinstance(download_path, str) or not download_path.strip():
                raise ValueError("下载路径不能为空且必须是字符串")

            url = url.strip()
            raw_path = download_path.strip()
            download_path = os.path.abspath(raw_path)
            # abspath() drops a trailing separator, which would turn a
            # not-yet-existing "dir/" into a *file* named "dir". Honour the
            # caller's intent by creating the directory first.
            if raw_path.endswith(("/", os.sep)):
                os.makedirs(download_path, exist_ok=True)

            rustfs_result = rustfs_client.parse_rustfs_url(url)
            if rustfs_result:
                return _download_from_rustfs(rustfs_result, url, download_path)
            return _download_from_http(url, download_path)
        except (ValueError, FileNotFoundError, RuntimeError):
            # Input errors keep their type; helper RuntimeErrors already carry
            # a specific message, so don't wrap them a second time.
            raise
        except Exception as e:
            raise RuntimeError(f"文件下载失败: {str(e)}") from e

    def _download_from_rustfs(rustfs_result: tuple, original_url: str, download_path: str) -> Dict[str, Any]:
        """Download an object through the RustFS client.

        Args:
            rustfs_result: ``(bucket, file_key)`` pair from ``parse_rustfs_url``.
            original_url: The URL as given by the caller (echoed in the result).
            download_path: Absolute destination directory or file path.

        Returns:
            The client's result dict, extended with url/file_path/message and
            with a ``filename`` fallback if the client did not supply one.

        Raises:
            FileNotFoundError: propagated unchanged from the client.
            RuntimeError: any other client failure.
        """
        bucket, file_key = rustfs_result

        if os.path.isdir(download_path):
            # Use the object's own name; keys ending in "/" get a default name
            # instead of resolving to the directory itself.
            filename = _safe_basename(file_key)
            local_path = os.path.join(download_path, filename)
        else:
            # Same behaviour as the HTTP branch: create missing parents.
            parent_dir = os.path.dirname(download_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            local_path = download_path
            filename = os.path.basename(download_path)

        try:
            result = rustfs_client.download_file(bucket, file_key, local_path)
        except FileNotFoundError:
            raise
        except Exception as e:
            raise RuntimeError(f"RustFS下载失败: {str(e)}") from e

        return {
            "filename": filename,  # fallback; the client's own value wins if present
            **result,
            "url": original_url,
            # Ensure file_path is always present, even if the client omits local_path.
            "file_path": result.get("local_path", local_path),
            "message": f"文件 '{file_key}' 从RustFS存储桶 '{bucket}' 下载成功",
        }

    def _download_from_http(url: str, download_path: str) -> Dict[str, Any]:
        """Download *url* over HTTP/HTTPS, streaming the body to disk.

        Raises:
            ValueError: the URL is not a valid http(s) URL with a host.
            RuntimeError: HTTP error status or any transfer/write failure.
        """
        parsed_url = urllib.parse.urlparse(url)
        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("URL必须是有效的HTTP或HTTPS地址")
        if not parsed_url.netloc:
            raise ValueError("URL格式无效,缺少域名")

        if os.path.isdir(download_path):
            filename = _extract_filename_from_url(url)
            file_path = os.path.join(download_path, filename)
        else:
            parent_dir = os.path.dirname(download_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            file_path = download_path
            filename = os.path.basename(download_path)

        # Stream into a sibling ".part" file and move it into place only once
        # the whole body has arrived. Memory stays bounded for large files and
        # a failed transfer never leaves a truncated file at file_path.
        part_path = file_path + ".part"
        try:
            with httpx.Client(timeout=timeout, follow_redirects=True) as client:
                with client.stream("GET", url) as response:
                    if not response.is_success:
                        # Streamed bodies are not read automatically; load the
                        # error body so the handler below can inspect it.
                        response.read()
                    response.raise_for_status()
                    content_type = response.headers.get("content-type", "application/octet-stream")
                    total_size = 0
                    with open(part_path, "wb") as f:
                        for chunk in response.iter_bytes():
                            f.write(chunk)
                            total_size += len(chunk)
            os.replace(part_path, file_path)
        except httpx.HTTPStatusError as e:
            error_msg = f"下载失败: HTTP {e.response.status_code}"
            try:
                payload = e.response.json()
            except ValueError:  # body is not valid (or not decodable) JSON
                payload = None
            # Only a JSON object can carry a "message" field.
            error_detail = payload.get("message", "") if isinstance(payload, dict) else ""
            if error_detail:
                error_msg += f" - {error_detail}"
            raise RuntimeError(error_msg) from e
        except Exception as e:
            # Best-effort removal of a partially written temp file; it may
            # not exist if the failure happened before any byte was written.
            try:
                os.remove(part_path)
            except OSError:
                pass
            raise RuntimeError(f"HTTP下载失败: {str(e)}") from e

        return {
            "success": True,
            "url": url,
            "file_path": file_path,
            "filename": filename,
            "size": total_size,
            "content_type": content_type,
            "download_method": "http",
            "message": f"文件 '{filename}' 下载成功",
        }

    return download_file