"""
测试 JSON 存储层 - 数据持久化和文件操作
这些测试验证:
1. JSON 文件的读写操作
2. 数据完整性和事务性
3. 错误处理和恢复机制
4. 并发访问控制
"""
import pytest
import json
import tempfile
import os
from pathlib import Path
from decimal import Decimal
from datetime import date, datetime
from typing import List, Dict, Any
from unittest.mock import patch, mock_open
# 这些导入会失败,因为我们还没有实现存储层
from accounting_mcp.storage import JSONStorage, StorageError
from accounting_mcp.models import Transaction, Account, Category, TransactionType
class TestJSONStorage:
    """Exercise the ``JSONStorage`` base class: initialization, round trips, failures."""

    def test_storage_initialization_with_existing_file(self, tmp_path: Path) -> None:
        """Storage created over an existing JSON file exposes that file's content."""
        data_file = tmp_path / "test_data.json"
        initial_data = {"test": "data"}

        # Seed the file before the storage object is constructed.
        with open(data_file, 'w') as f:
            json.dump(initial_data, f)

        storage = JSONStorage(data_file)

        assert storage.file_path == data_file
        assert storage.load_data() == initial_data

    def test_storage_initialization_with_new_file(self, tmp_path: Path) -> None:
        """A missing file is created on init and loads back as ``default_data``."""
        data_file = tmp_path / "new_data.json"

        storage = JSONStorage(data_file, default_data={"empty": True})

        assert storage.file_path == data_file
        assert data_file.exists()
        assert storage.load_data() == {"empty": True}

    def test_storage_save_and_load(self, tmp_path: Path) -> None:
        """Data passed to ``save_data`` is returned unchanged by ``load_data``."""
        data_file = tmp_path / "test_data.json"
        storage = JSONStorage(data_file)
        test_data = {"key": "value", "number": 42}

        storage.save_data(test_data)
        loaded_data = storage.load_data()

        assert loaded_data == test_data

    def test_storage_atomic_write(self, tmp_path: Path) -> None:
        """A failed write raises ``StorageError`` and leaves the previous data readable."""
        data_file = tmp_path / "test_data.json"
        storage = JSONStorage(data_file)

        # Baseline content that must survive the failed write below.
        storage.save_data({"version": 1})

        # Make every ``open`` call fail while the second save runs.
        with patch('builtins.open', side_effect=IOError("磁盘已满")):
            with pytest.raises(StorageError):
                storage.save_data({"version": 2})

        # The patch is gone here, so this reads the real file again.
        loaded_data = storage.load_data()
        assert loaded_data == {"version": 1}

    def test_storage_backup_on_corruption(self, tmp_path: Path) -> None:
        """Loading a corrupted data file falls back to the previously created backup."""
        data_file = tmp_path / "test_data.json"
        backup_file = tmp_path / "test_data_backup.json"

        # Start from valid data.
        storage = JSONStorage(data_file)
        storage.save_data({"valid": "data"})

        # Create the backup explicitly and check it lands at the expected path.
        storage._create_backup()
        assert backup_file.exists()

        # Corrupt the primary file with text that is not valid JSON.
        with open(data_file, 'w') as f:
            f.write("invalid json {")

        # The load is expected to recover the content from the backup.
        loaded_data = storage.load_data()
        assert loaded_data == {"valid": "data"}

    def test_storage_permission_error_handling(self, tmp_path: Path) -> None:
        """Saving to a read-only file raises ``StorageError`` mentioning permissions."""
        data_file = tmp_path / "readonly_data.json"
        storage = JSONStorage(data_file)

        # Create the file, then drop write permission on it.
        storage.save_data({"test": "data"})
        os.chmod(data_file, 0o444)
        # NOTE(review): a process running as root usually bypasses file mode
        # bits, so the write below may not fail there -- confirm the CI user.
        try:
            with pytest.raises(StorageError, match="权限"):
                storage.save_data({"new": "data"})
        finally:
            # Restore write permission even when the assertion above fails, so
            # a failing run never leaves a read-only file behind for cleanup.
            os.chmod(data_file, 0o644)
class TestTransactionStorage:
    """Tests for persisting, querying and paging transactions."""

    def test_save_and_load_transactions(self, tmp_path: Path) -> None:
        """A saved batch of transactions is loaded back newest-first."""
        from accounting_mcp.storage import TransactionStorage

        store = TransactionStorage(tmp_path / "transactions.json")
        lunch = Transaction(
            amount=Decimal("-50.00"),
            category="food",
            description="午餐",
        )
        salary = Transaction(
            amount=Decimal("1000.00"),
            category="income",
            description="工资",
        )

        store.save_transactions([lunch, salary])
        restored = store.load_transactions()

        assert len(restored) == 2

        # Reverse chronological order: the income entry was created last, so
        # it is expected at the front.
        newest = restored[0]
        assert newest.amount == Decimal("1000.00")
        assert newest.category == "income"

        oldest = restored[1]
        assert oldest.amount == Decimal("-50.00")
        assert oldest.category == "food"

    def test_add_single_transaction(self, tmp_path: Path) -> None:
        """One transaction added via ``add_transaction`` can be loaded back."""
        from accounting_mcp.storage import TransactionStorage

        store = TransactionStorage(tmp_path / "transactions.json")
        fare = Transaction(
            amount=Decimal("-30.00"),
            category="transport",
            description="地铁",
        )

        store.add_transaction(fare)

        restored = store.load_transactions()
        assert len(restored) == 1

        only = restored[0]
        assert only.amount == Decimal("-30.00")
        assert only.category == "transport"

    def test_transaction_filtering(self, tmp_path: Path) -> None:
        """Transactions can be filtered by category and by date range."""
        from accounting_mcp.storage import TransactionStorage

        store = TransactionStorage(tmp_path / "transactions.json")

        # Three entries on consecutive days, two of them in the "food" category.
        specs = (
            ("-50.00", "food", date(2025, 1, 15)),
            ("-20.00", "transport", date(2025, 1, 16)),
            ("-30.00", "food", date(2025, 1, 17)),
        )
        entries = [
            Transaction(amount=Decimal(amount), category=category, date=day)
            for amount, category, day in specs
        ]
        for entry in entries:
            store.add_transaction(entry)

        # Category filter.
        food_only = store.get_transactions_by_category("food")
        assert len(food_only) == 2
        for entry in food_only:
            assert entry.category == "food"

        # Date-range filter; two entries are expected for Jan 16..17.
        in_range = store.get_transactions_by_date_range(
            start_date=date(2025, 1, 16),
            end_date=date(2025, 1, 17),
        )
        assert len(in_range) == 2

    def test_transaction_pagination(self, tmp_path: Path) -> None:
        """``get_transactions`` pages through 25 entries with limit/offset."""
        from accounting_mcp.storage import TransactionStorage

        store = TransactionStorage(tmp_path / "transactions.json")

        # Insert 25 entries, numbered 1..25 in creation order.
        for seq in range(1, 26):
            store.add_transaction(
                Transaction(
                    amount=Decimal(f"-{seq}.00"),
                    category="food",
                    description=f"交易{seq}",
                )
            )

        pages = [
            store.get_transactions(limit=10, offset=start)
            for start in (0, 10, 20)
        ]
        assert [len(page) for page in pages] == [10, 10, 5]

        # Newest first: the first page runs from entry 25 down to entry 16.
        first_page = pages[0]
        assert first_page[0].description == "交易25"
        assert first_page[-1].description == "交易16"
class TestAccountStorage:
    """Tests for persisting account state."""

    def test_save_and_load_account(self, tmp_path: Path) -> None:
        """An account with one recorded expense survives a save/load round trip."""
        from accounting_mcp.storage import AccountStorage

        store = AccountStorage(tmp_path / "account.json")

        acct = Account()
        expense = Transaction(amount=Decimal("-50.00"), category="food")
        acct.add_transaction(expense)

        store.save_account(acct)
        restored = store.load_account()

        assert restored.balance == Decimal("-50.00")
        assert restored.total_transactions == 1

    def test_update_account_balance(self, tmp_path: Path) -> None:
        """``update_balance`` changes the stored balance and transaction count."""
        from accounting_mcp.storage import AccountStorage

        store = AccountStorage(tmp_path / "account.json")

        # Persist a fresh account first so there is something to update.
        store.save_account(Account())

        store.update_balance(Decimal("100.00"), 1)

        refreshed = store.load_account()
        assert refreshed.balance == Decimal("100.00")
        assert refreshed.total_transactions == 1
class TestCategoryStorage:
    """Tests for loading and validating categories."""

    def test_load_default_categories(self, tmp_path: Path) -> None:
        """A storage over a fresh path yields a non-empty default category set."""
        from accounting_mcp.storage import CategoryStorage

        storage = CategoryStorage(tmp_path / "categories.json")
        categories = storage.load_categories()

        # Some default categories must exist.
        assert len(categories) > 0

        # The categories the rest of the suite relies on must be present.
        category_ids = [c.id for c in categories]
        assert "food" in category_ids
        assert "transport" in category_ids
        assert "income" in category_ids

    def test_validate_category(self, tmp_path: Path) -> None:
        """``is_valid_category`` returns a real bool for known and unknown ids."""
        from accounting_mcp.storage import CategoryStorage

        storage = CategoryStorage(tmp_path / "categories.json")

        # Identity checks instead of ``== True`` / ``== False`` (PEP 8, E712):
        # they also reject truthy/falsy stand-ins such as 1 or 0.
        # Known category.
        assert storage.is_valid_category("food") is True

        # Unknown and empty ids.
        assert storage.is_valid_category("invalid_category") is False
        assert storage.is_valid_category("") is False
# 集成测试
class TestStorageIntegration:
    """Integration tests that combine several storage classes."""

    def test_complete_workflow(self, tmp_path: Path) -> None:
        """Run the create-account / add-transaction / update-account flow end to end."""
        from accounting_mcp.storage import TransactionStorage, AccountStorage

        tx_store = TransactionStorage(tmp_path / "transactions.json")
        account_store = AccountStorage(tmp_path / "account.json")

        # Step 1: persist an empty account.
        acct = Account()
        account_store.save_account(acct)

        # Step 2: record a single expense.
        lunch = Transaction(
            amount=Decimal("-50.00"),
            category="food",
            description="午餐",
        )
        tx_store.add_transaction(lunch)

        # Step 3: apply the expense to the account and persist it again.
        acct.add_transaction(lunch)
        account_store.save_account(acct)

        # Step 4: both stores must agree on what happened.
        stored_transactions = tx_store.load_transactions()
        stored_account = account_store.load_account()

        assert len(stored_transactions) == 1
        assert stored_account.balance == Decimal("-50.00")
        assert stored_account.total_transactions == 1

    def test_concurrent_access_simulation(self, tmp_path: Path) -> None:
        """Five threads add transactions at once; at most one may fail."""
        from accounting_mcp.storage import TransactionStorage
        import threading
        import time

        store = TransactionStorage(tmp_path / "transactions.json")
        outcomes = []
        max_retries = 3

        def worker(i):
            # Try up to ``max_retries`` times, backing off a little longer
            # after each failed attempt; record the error only on the last one.
            for attempt in range(1, max_retries + 1):
                try:
                    store.add_transaction(
                        Transaction(
                            amount=Decimal(f"-{i}.00"),
                            category="test",
                            description=f"并发测试{i}",
                        )
                    )
                except Exception as e:
                    if attempt == max_retries:
                        outcomes.append(f"error-{i}: {e}")
                    else:
                        time.sleep(0.01 * attempt)
                else:
                    outcomes.append(f"success-{i}")
                    return

        # Build every thread first, then start them together and wait for all.
        workers = [
            threading.Thread(target=worker, args=(i,))
            for i in range(5)
        ]
        for t in workers:
            t.start()
        for t in workers:
            t.join()

        # Most writes must succeed; one conflict is tolerated.
        success_count = sum(1 for r in outcomes if r.startswith("success"))
        assert success_count >= 4

        # Every write reported as successful must actually be stored.
        stored = store.load_transactions()
        assert len(stored) == success_count
# 测试工具函数
@pytest.fixture
def sample_transaction() -> Transaction:
    """Provide a sample food expense transaction for tests."""
    fields = {
        "amount": Decimal("-50.00"),
        "category": "food",
        "description": "测试交易",
    }
    return Transaction(**fields)
@pytest.fixture
def sample_account() -> Account:
    """Provide a sample account that already has one food expense recorded."""
    acct = Account()
    expense = Transaction(amount=Decimal("-50.00"), category="food")
    acct.add_transaction(expense)
    return acct