"""
存储层 - JSON 文件存储实现
提供基于 JSON 文件的数据持久化功能,包括:
- JSONStorage: 基础 JSON 文件存储
- TransactionStorage: 交易记录存储
- AccountStorage: 账户信息存储
- CategoryStorage: 分类信息存储
- 原子写入、备份恢复、并发控制等特性
"""
import copy
import json
import os
import shutil
import threading
from datetime import date, datetime
from decimal import Decimal
from pathlib import Path
from typing import Dict, Any, List, Optional, Union

from .models import Transaction, Account, Category, create_default_categories
class StorageError(Exception):
    """Raised when a storage operation fails."""
class JSONStorage:
    """JSON file store with rename-based writes and a one-generation backup.

    Before every overwrite the current file is copied to a sibling
    ``<stem>_backup<suffix>`` file.  ``load_data`` falls back to that backup,
    and finally to ``default_data``, when the main file cannot be read.
    ``load_data`` and ``save_data`` run under a per-instance re-entrant lock.
    """

    def __init__(self, file_path: Union[str, Path], default_data: Optional[Dict[str, Any]] = None):
        """Bind the store to ``file_path`` and create the file if it is missing.

        Args:
            file_path: Location of the JSON file; parent directories are
                created as needed.
            default_data: Data written to a newly created file and used as the
                fallback by ``load_data``.  Defaults to an empty dict.

        Raises:
            StorageError: If the initial file cannot be written.
        """
        self.file_path = Path(file_path)
        self.default_data = default_data or {}
        # Re-entrant lock serialising load/save calls on this instance.
        self._lock = threading.RLock()
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.file_path.exists():
            self.save_data(self.default_data)

    def load_data(self) -> Dict[str, Any]:
        """Return the stored data, recovering from the backup when necessary.

        Returns:
            The parsed contents of the main file.  If that file is missing,
            unreadable, not valid UTF-8 or not valid JSON, the backup's
            contents are returned instead and the main file is rewritten from
            them on a best-effort basis.  If the backup is unusable as well, a
            deep copy of ``default_data`` is returned.
        """
        with self._lock:
            try:
                return self._load_from_file(self.file_path)
            except (ValueError, OSError):
                # ValueError covers json.JSONDecodeError as well as the
                # UnicodeDecodeError raised for files that are not valid UTF-8.
                pass

            try:
                data = self._load_from_file(self._get_backup_path())
            except (ValueError, OSError):
                # No usable backup (missing or damaged).  Hand out an
                # independent copy so callers cannot mutate the shared defaults
                # through nested lists or dicts.
                return copy.deepcopy(self.default_data)

            try:
                # Restore the main file from the backup.
                self._write_file(self.file_path, data)
            except OSError:
                # The restore is best-effort: the recovered data is still
                # valid even if the main file cannot be rewritten right now.
                pass
            return data

    def save_data(self, data: Dict[str, Any]) -> None:
        """Back up the current file, then replace it with ``data``.

        Args:
            data: JSON-serialisable data to persist.

        Raises:
            StorageError: If the file is read-only, or if the backup or the
                write fails.  The original exception is attached as the cause.
        """
        with self._lock:
            try:
                if self.file_path.exists():
                    if not os.access(self.file_path, os.W_OK):
                        raise StorageError("权限错误: 文件只读,无法写入")
                    # Keep the previous contents before they are overwritten.
                    self._create_backup()
                self._atomic_write(self.file_path, data)
            except Exception as e:
                raise StorageError(f"保存数据失败: {e}") from e

    def _load_from_file(self, file_path: Path) -> Dict[str, Any]:
        """Parse and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _write_file(self, file_path: Path, data: Dict[str, Any]) -> None:
        """Serialise ``data`` as indented, non-ASCII-escaped JSON into ``file_path``."""
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    def _atomic_write(self, file_path: Path, data: Dict[str, Any]) -> None:
        """Write ``data`` to a sibling ``.tmp`` file, then rename it over ``file_path``.

        Raises:
            StorageError: If writing or renaming fails.  The temporary file is
                removed first when possible.
        """
        temp_path = file_path.with_suffix(file_path.suffix + '.tmp')
        try:
            self._write_file(temp_path, data)
            # Readers never see a half-written main file: the finished
            # temporary file replaces it in a single rename.
            temp_path.replace(file_path)
        except Exception as e:
            try:
                if temp_path.exists():
                    temp_path.unlink()
            except OSError:
                # Cleanup is best-effort; report the original failure instead.
                pass
            label = "权限错误" if isinstance(e, PermissionError) else "原子写入失败"
            raise StorageError(f"{label}: {e}") from e

    def _create_backup(self) -> None:
        """Copy the main file (with metadata) to the backup path, if it exists."""
        if self.file_path.exists():
            shutil.copy2(self.file_path, self._get_backup_path())

    def _get_backup_path(self) -> Path:
        """Return the backup path: same directory, ``<stem>_backup<suffix>``."""
        stem = self.file_path.stem      # file name without its suffix
        suffix = self.file_path.suffix  # e.g. ".json"
        return self.file_path.with_name(f"{stem}_backup{suffix}")
class TransactionStorage:
    """Persistence for transaction records kept under the ``"transactions"`` key."""

    def __init__(self, file_path: Union[str, Path]):
        """Open (or create) the transaction file at ``file_path``."""
        # Lock held around the read-modify-write cycle in add_transaction.
        self._lock = threading.RLock()
        self.storage = JSONStorage(file_path, {"transactions": []})

    def load_transactions(self) -> List[Transaction]:
        """Return every parseable transaction, newest timestamp first.

        Records for which ``Transaction.from_dict`` raises are skipped.
        """
        raw_records = self.storage.load_data().get("transactions", [])
        parsed = []
        for raw in raw_records:
            try:
                record = Transaction.from_dict(raw)
            except Exception:
                # Skip records that cannot be parsed.
                continue
            parsed.append(record)
        return sorted(parsed, key=lambda item: item.timestamp, reverse=True)

    def save_transactions(self, transactions: List[Transaction]) -> None:
        """Replace the stored list with ``transactions``."""
        payload = {"transactions": [item.to_dict() for item in transactions]}
        self.storage.save_data(payload)

    def add_transaction(self, transaction: Transaction) -> None:
        """Append one transaction; the load/append/save cycle runs under the instance lock."""
        with self._lock:
            current = self.load_transactions()
            current.append(transaction)
            self.save_transactions(current)

    def get_transactions_by_category(self, category: str) -> List[Transaction]:
        """Return the transactions whose ``category`` equals ``category``."""
        matches = []
        for item in self.load_transactions():
            if item.category == category:
                matches.append(item)
        return matches

    def get_transactions_by_date_range(
        self,
        start_date: date,
        end_date: date
    ) -> List[Transaction]:
        """Return the transactions dated within ``[start_date, end_date]`` (inclusive)."""
        matches = []
        for item in self.load_transactions():
            if start_date <= item.date <= end_date:
                matches.append(item)
        return matches

    def get_transactions(
        self,
        limit: int = 20,
        offset: int = 0,
        category: Optional[str] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None
    ) -> List[Transaction]:
        """Return one page of transactions after applying the optional filters.

        Args:
            limit: Maximum number of records in the page.
            offset: Index of the first record of the page in the filtered list.
            category: Keep only this category when truthy.
            start_date: Keep only records dated on or after this day when given.
            end_date: Keep only records dated on or before this day when given.
        """
        selected = []
        for item in self.load_transactions():
            if category and item.category != category:
                continue
            if start_date and not item.date >= start_date:
                continue
            if end_date and not item.date <= end_date:
                continue
            selected.append(item)
        page_end = offset + limit
        return selected[offset:page_end]
class AccountStorage:
    """Persistence for the single account record."""

    def __init__(self, file_path: Union[str, Path]):
        """Open (or create) the account file, seeding it with a default ``Account``."""
        self.storage = JSONStorage(file_path, Account().to_dict())

    def load_account(self) -> Account:
        """Read the stored data and build an ``Account`` from it."""
        return Account.from_dict(self.storage.load_data())

    def save_account(self, account: Account) -> None:
        """Persist ``account`` in its dict form."""
        payload = account.to_dict()
        self.storage.save_data(payload)

    def update_balance(self, new_balance: Decimal, transaction_count: int) -> None:
        """Overwrite balance and transaction count, stamp the update time, and save."""
        current = self.load_account()
        current.balance = new_balance
        current.total_transactions = transaction_count
        current.last_updated = datetime.now()
        self.save_account(current)
class CategoryStorage:
    """Persistence for the category list."""

    def __init__(self, file_path: Union[str, Path]):
        """Open (or create) the category file, seeding it with the default categories."""
        seed = {"categories": [item.to_dict() for item in create_default_categories()]}
        self.storage = JSONStorage(file_path, seed)

    def load_categories(self) -> List[Category]:
        """Return the stored categories, falling back to (and saving) the defaults.

        Records for which ``Category.from_dict`` raises are skipped.  When
        nothing usable remains, the default categories are written to storage
        and returned.
        """
        payload = self.storage.load_data()
        # Two layouts are accepted: a bare list of categories, or an object
        # holding that list under the "categories" key.
        raw_items = payload if isinstance(payload, list) else payload.get("categories", [])

        loaded = []
        for raw in raw_items:
            try:
                loaded.append(Category.from_dict(raw))
            except Exception:
                # Skip records that cannot be parsed.
                continue

        if loaded:
            return loaded
        defaults = create_default_categories()
        self.save_categories(defaults)
        return defaults

    def save_categories(self, categories: List[Category]) -> None:
        """Replace the stored list with ``categories``."""
        payload = {"categories": [item.to_dict() for item in categories]}
        self.storage.save_data(payload)

    def is_valid_category(self, category_id: str) -> bool:
        """Return True when ``category_id`` is non-blank and matches a stored category id."""
        if category_id and category_id.strip():
            return any(item.id == category_id for item in self.load_categories())
        return False

    def get_category(self, category_id: str) -> Optional[Category]:
        """Return the first category whose id equals ``category_id``, or None."""
        return next(
            (item for item in self.load_categories() if item.id == category_id),
            None,
        )

    def add_category(self, category: Category) -> None:
        """Append ``category`` to the stored list.

        Raises:
            StorageError: If a category with the same id already exists.
        """
        current = self.load_categories()
        if any(existing.id == category.id for existing in current):
            raise StorageError(f"分类ID '{category.id}' 已存在")
        self.save_categories(current + [category])
# 存储管理器 - 统一管理所有存储
class StorageManager:
    """Facade that owns the transaction, account and category stores."""

    def __init__(self, data_dir: Union[str, Path] = "data"):
        """Create ``data_dir`` if needed and open the three stores inside it."""
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        root = self.data_dir
        self.transactions = TransactionStorage(root / "transactions.json")
        self.account = AccountStorage(root / "account.json")
        self.categories = CategoryStorage(root / "categories.json")

    def add_transaction(self, transaction: Transaction) -> None:
        """Validate the category, store the transaction, then update the account.

        Raises:
            StorageError: If ``transaction.category`` is not a known category id.
        """
        if not self.categories.is_valid_category(transaction.category):
            raise StorageError(f"无效的分类: {transaction.category}")

        self.transactions.add_transaction(transaction)

        # Apply the transaction to the account record and persist it.
        current = self.account.load_account()
        current.add_transaction(transaction)
        self.account.save_account(current)

    def get_balance(self) -> Decimal:
        """Return the balance of the stored account."""
        return self.account.load_account().balance

    def get_transaction_count(self) -> int:
        """Return the transaction count recorded on the stored account."""
        return self.account.load_account().total_transactions

    def get_monthly_summary(self, year: int, month: int) -> Dict[str, Any]:
        """Summarise the transactions dated within the given calendar month.

        Returns:
            A dict with ``year``, ``month``, ``total_income`` (sum of positive
            amounts), ``total_expense`` (sum of absolute negative amounts),
            ``net_flow``, ``transaction_count`` and ``category_breakdown``
            (absolute amounts summed per category name, or per category id
            when the id is unknown).  Monetary values are floats.
        """
        from calendar import monthrange

        first_day = date(year, month, 1)
        last_day = date(year, month, monthrange(year, month)[1])
        period = self.transactions.get_transactions_by_date_range(first_day, last_day)

        # Totals: positive amounts are income, negative amounts are expenses.
        income = 0
        expense = 0
        for entry in period:
            if entry.amount > 0:
                income += entry.amount
            elif entry.amount < 0:
                expense += abs(entry.amount)

        # Per-category totals, keyed by display name where one is known.
        display_names = {}
        for item in self.categories.load_categories():
            display_names[item.id] = item.name

        breakdown = {}
        for entry in period:
            label = display_names.get(entry.category, entry.category)
            breakdown[label] = breakdown.get(label, Decimal('0.00')) + abs(entry.amount)

        summary = {
            "year": year,
            "month": month,
            "total_income": float(income),
            "total_expense": float(expense),
            "net_flow": float(income - expense),
            "transaction_count": len(period),
            "category_breakdown": {
                label: float(total) for label, total in breakdown.items()
            },
        }
        return summary