"""
文件系统接口模块
提供异步文件操作、目录遍历和文件信息获取功能。
"""
import os
import asyncio
import aiofiles
from pathlib import Path
from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple
from datetime import datetime
import mimetypes
import hashlib
from .security import SecurityValidator
from middleware.logging import log_performance
class FileInfo:
    """Snapshot of metadata for a single file or directory.

    All attributes are populated eagerly at construction time.
    ``content_hash`` is intentionally left empty here; callers are expected
    to fill it in (e.g. via a separate hashing routine).
    """

    def __init__(self, path: Path):
        """
        Collect metadata for *path*.

        Args:
            path: Filesystem path to inspect.
        """
        self.path = path
        self.name = path.name
        self.is_dir = path.is_dir()
        self.is_file = path.is_file()
        self.size = 0
        self.modified_time: Optional[datetime] = None
        self.extension = ""
        self.mime_type = ""
        self.content_hash = ""
        if self.is_file:
            try:
                # stat() can still fail if the file vanishes between the
                # is_file() check above and this call (TOCTOU race); the
                # previous exists() pre-check did not prevent that.
                stat_result = path.stat()
            except OSError:
                pass  # keep the zeroed defaults for a vanished file
            else:
                self.size = stat_result.st_size
                self.modified_time = datetime.fromtimestamp(stat_result.st_mtime)
                self.extension = path.suffix.lower()
                self.mime_type = (
                    mimetypes.guess_type(str(path))[0] or "application/octet-stream"
                )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the metadata to a JSON-friendly dict."""
        return {
            "name": self.name,
            "path": str(self.path),
            "is_dir": self.is_dir,
            "is_file": self.is_file,
            "size": self.size,
            "modified_time": self.modified_time.isoformat() if self.modified_time else None,
            "extension": self.extension,
            "mime_type": self.mime_type,
            "content_hash": self.content_hash,
        }
class FolderStructure:
    """Aggregated view of a directory tree: folders, files and totals."""

    def __init__(self, root_path: str):
        """
        Initialize an empty structure rooted at the resolved *root_path*.

        Args:
            root_path: Root directory of the tree being described.
        """
        self.root_path = Path(root_path).resolve()
        self.folders = []       # FileInfo entries for directories
        self.files = []         # FileInfo entries for regular files
        self.total_size = 0     # sum of file sizes in bytes
        self.file_count = 0
        self.folder_count = 0
        self.generated_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the structure (and every contained entry) to a dict."""
        folder_dicts = [entry.to_dict() for entry in self.folders]
        file_dicts = [entry.to_dict() for entry in self.files]
        return {
            "root_path": str(self.root_path),
            "folders": folder_dicts,
            "files": file_dicts,
            "total_size": self.total_size,
            "file_count": self.file_count,
            "folder_count": self.folder_count,
            "generated_at": self.generated_at.isoformat(),
        }
class FileSystemInterface:
"""文件系统接口类"""
def __init__(self, security_validator: SecurityValidator = None):
"""
初始化文件系统接口
Args:
security_validator: 安全验证器实例
"""
self.security_validator = security_validator or SecurityValidator()
@log_performance("validate_path")
async def validate_path(self, path: str) -> Tuple[bool, str]:
"""
验证路径安全性
Args:
path: 要验证的路径
Returns:
(是否有效, 错误信息)
"""
# 安全验证
if not self.security_validator.is_safe_path(path):
return False, "路径包含不安全字符"
try:
path_obj = Path(path)
# 检查路径是否存在
if not path_obj.exists():
return False, f"路径不存在: {path}"
# 检查是否为目录
if not path_obj.is_dir():
return False, f"路径不是目录: {path}"
# 检查读取权限
if not os.access(path, os.R_OK):
return False, f"没有读取权限: {path}"
return True, ""
except (FileNotFoundError, PermissionError, OSError) as e:
return False, f"路径验证失败: {str(e)}"
@log_performance("get_folder_structure")
async def get_folder_structure(
self,
root_path: str,
exclude_dirs: Optional[List[str]] = None,
max_depth: int = 10,
include_hidden: bool = False
) -> FolderStructure:
"""
获取文件夹结构
Args:
root_path: 根目录路径
exclude_dirs: 要排除的目录列表
max_depth: 最大遍历深度
include_hidden: 是否包含隐藏文件/目录
Returns:
文件夹结构对象
"""
if exclude_dirs is None:
exclude_dirs = ['.git', '__pycache__', 'node_modules', '.vscode', '.idea']
# 验证路径
is_valid, error_msg = await self.validate_path(root_path)
if not is_valid:
raise ValueError(error_msg)
root_path_obj = Path(root_path).resolve()
structure = FolderStructure(str(root_path_obj))
async for file_info in self._walk_directory_async(
root_path_obj,
exclude_dirs,
max_depth,
include_hidden
):
if file_info.is_dir:
structure.folders.append(file_info)
structure.folder_count += 1
else:
structure.files.append(file_info)
structure.file_count += 1
structure.total_size += file_info.size
return structure
async def _walk_directory_async(
self,
root_path: Path,
exclude_dirs: List[str],
max_depth: int,
include_hidden: bool,
current_depth: int = 0
) -> AsyncGenerator[FileInfo, None]:
"""
异步遍历目录
Args:
root_path: 根目录路径
exclude_dirs: 排除目录列表
max_depth: 最大深度
include_hidden: 是否包含隐藏文件
current_depth: 当前深度
"""
if current_depth >= max_depth:
return
try:
# 获取目录中的所有项目
items = []
for item in root_path.iterdir():
if not include_hidden and item.name.startswith('.'):
continue
if item.is_dir() and item.name in exclude_dirs:
continue
items.append(item)
# 并行处理所有项目
tasks = []
for item in items:
task = self._process_item_async(item, exclude_dirs, max_depth, include_hidden, current_depth)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
try:
if isinstance(result, FileInfo):
yield result
except (FileNotFoundError, IOError, OSError) as e:
# 记录错误但继续处理其他文件
print(f"处理文件时出错: {result}")
except PermissionError:
# 忽略权限错误,继续处理其他目录
pass
except (FileNotFoundError, PermissionError, OSError) as e:
print(f"遍历目录 {root_path} 时出错: {e}")
async def _process_item_async(
self,
item: Path,
exclude_dirs: List[str],
max_depth: int,
include_hidden: bool,
current_depth: int
) -> FileInfo:
"""
异步处理单个文件/目录项
Args:
item: 文件/目录路径
exclude_dirs: 排除目录列表
max_depth: 最大深度
include_hidden: 是否包含隐藏文件
current_depth: 当前深度
Returns:
文件信息对象
"""
file_info = FileInfo(item)
# 如果是目录且需要递归,生成子项目
if file_info.is_dir and current_depth < max_depth:
try:
async for sub_file_info in self._walk_directory_async(
item, exclude_dirs, max_depth, include_hidden, current_depth + 1
):
if sub_file_info.is_file:
# 可以在这里添加对子文件的处理逻辑
pass
except (FileNotFoundError, PermissionError, OSError) as e:
print(f"处理子目录 {item} 时出错: {e}")
return file_info
@log_performance("read_file_content")
async def read_file_content(
self,
file_path: str,
max_size: int = 1024 * 1024, # 1MB
encoding: str = 'utf-8'
) -> Tuple[str, str]:
"""
异步读取文件内容
Args:
file_path: 文件路径
max_size: 最大读取大小(字节)
encoding: 文件编码
Returns:
(文件内容, 错误信息)
"""
try:
path_obj = Path(file_path)
# 验证文件
if not path_obj.exists():
return "", f"文件不存在: {file_path}"
if not path_obj.is_file():
return "", f"路径不是文件: {file_path}"
# 检查文件大小
size = path_obj.stat().st_size
if size > max_size:
return "", f"文件过大: {size} > {max_size}"
# 异步读取文件
async with aiofiles.open(path_obj, 'r', encoding=encoding) as f:
content = await f.read()
return content, ""
except UnicodeDecodeError:
return "", f"文件编码不是 {encoding}"
except PermissionError:
return "", f"没有读取权限: {file_path}"
except (FileNotFoundError, PermissionError, OSError) as e:
return "", f"读取文件失败: {str(e)}"
@log_performance("write_file_content")
async def write_file_content(
self,
file_path: str,
content: str,
encoding: str = 'utf-8',
backup: bool = True
) -> Tuple[bool, str]:
"""
异步写入文件内容
Args:
file_path: 文件路径
content: 文件内容
encoding: 文件编码
backup: 是否创建备份
Returns:
(是否成功, 错误信息)
"""
try:
path_obj = Path(file_path)
# 安全验证
if not self.security_validator.is_safe_path(file_path):
return False, "路径不安全"
# 确保目录存在
path_obj.parent.mkdir(parents=True, exist_ok=True)
# 创建备份
if backup and path_obj.exists():
backup_path = path_obj.with_suffix(f"{path_obj.suffix}.backup")
await self._copy_file_async(path_obj, backup_path)
# 异步写入文件
async with aiofiles.open(path_obj, 'w', encoding=encoding) as f:
await f.write(content)
return True, ""
except PermissionError:
return False, f"没有写入权限: {file_path}"
except (FileNotFoundError, PermissionError, OSError) as e:
return False, f"写入文件失败: {str(e)}"
async def _copy_file_async(self, source: Path, destination: Path) -> bool:
"""异步复制文件"""
try:
async with aiofiles.open(source, 'rb') as src:
content = await src.read()
async with aiofiles.open(destination, 'wb') as dst:
await dst.write(content)
return True
except (FileNotFoundError, IOError, IndexError) as e:
return False
@log_performance("get_file_hash")
async def get_file_hash(self, file_path: str, algorithm: str = 'md5') -> Optional[str]:
"""
计算文件哈希值
Args:
file_path: 文件路径
algorithm: 哈希算法
Returns:
文件哈希值
"""
try:
path_obj = Path(file_path)
if not path_obj.exists() or not path_obj.is_file():
return None
hash_obj = hashlib.new(algorithm)
async with aiofiles.open(path_obj, 'rb') as f:
while chunk := await f.read(8192): # 8KB chunks
hash_obj.update(chunk)
return hash_obj.hexdigest()
except (FileNotFoundError, IOError, OSError) as e:
return None
@log_performance("should_update_file")
async def should_update_file(
self,
target_file: str,
source_files: List[str]
) -> bool:
"""
判断是否需要更新目标文件
Args:
target_file: 目标文件路径
source_files: 源文件路径列表
Returns:
是否需要更新
"""
try:
target_path = Path(target_file)
# 如果目标文件不存在,需要更新
if not target_path.exists():
return True
# 获取目标文件的修改时间
target_mtime = target_path.stat().st_mtime
# 检查源文件是否有更新
for source_file in source_files:
source_path = Path(source_file)
if source_path.exists():
source_mtime = source_path.stat().st_mtime
if source_mtime > target_mtime:
return True
return False
except (FileNotFoundError, IOError, IndexError) as e:
return True # 出错时默认更新