"""测试服务类"""
import pytest
import tempfile
import os
from unittest.mock import patch
from src.jxls_mcp.models import DataField, DataStruct, GenerateTemplateRequest
from src.jxls_mcp.services import TemplateGenerator, ParameterValidator
class TestParameterValidator:
    """Tests for the ParameterValidator service.

    Each test builds an input (data structure, template name, sample data or
    output path) and checks that the validator either accepts it silently or
    raises ValueError whose message matches the expected error text.
    """

    def setup_method(self):
        # Fresh validator per test so no state can leak between tests.
        self.validator = ParameterValidator()

    @staticmethod
    def _struct(*fields):
        """Build a DataStruct with the collection/item names shared by these tests."""
        return DataStruct(
            collectName="data",
            itemVariable="item",
            dataFields=list(fields),
        )

    def test_validate_json_data_struct(self):
        """JSON format: a field mapped by `field` is accepted."""
        data_struct = self._struct(DataField(name="测试", field="test"))
        # Must not raise.
        self.validator.validate_data_struct(data_struct, "json")

    def test_validate_array_data_struct(self):
        """Array format: a field mapped by `index` is accepted."""
        data_struct = self._struct(DataField(name="测试", index=0))
        # Must not raise.
        self.validator.validate_data_struct(data_struct, "array")

    def test_json_format_with_index_field(self):
        """JSON format rejects fields that carry an `index`."""
        # JSON format must not provide an index.
        data_struct = self._struct(DataField(name="测试", index=0))
        with pytest.raises(ValueError, match="JSON格式时字段.*不应提供index属性"):
            self.validator.validate_data_struct(data_struct, "json")

    def test_array_format_with_field_field(self):
        """Array format rejects fields that carry a `field`."""
        # Array format must not provide a field name.
        data_struct = self._struct(DataField(name="测试", field="test"))
        with pytest.raises(ValueError, match="数组格式时字段.*不应提供field属性"):
            self.validator.validate_data_struct(data_struct, "array")

    def test_duplicate_field_names(self):
        """Field display names must be unique."""
        data_struct = self._struct(
            DataField(name="重复", field="test1"),
            DataField(name="重复", field="test2"),  # duplicate name
        )
        with pytest.raises(ValueError, match="字段名称必须唯一"):
            self.validator.validate_data_struct(data_struct, "json")

    def test_duplicate_indices(self):
        """Array format: field indices must be unique."""
        data_struct = self._struct(
            DataField(name="测试1", index=0),
            DataField(name="测试2", index=0),  # duplicate index
        )
        with pytest.raises(ValueError, match="数组格式时字段索引必须唯一"):
            self.validator.validate_data_struct(data_struct, "array")

    def test_validate_safe_template_name(self):
        """A plain alphanumeric/underscore template name is accepted."""
        # Must not raise.
        self.validator.validate_template_name("valid_template_name")

    def test_validate_unsafe_template_name(self):
        """Template names containing unsafe characters are rejected."""
        unsafe_names = ["test/name", "test\\name", "test:name", "test*name"]
        for unsafe_name in unsafe_names:
            with pytest.raises(ValueError, match="模板名称不能包含字符"):
                self.validator.validate_template_name(unsafe_name)

    def test_validate_reserved_template_name(self):
        """Reserved names such as "CON" are rejected."""
        with pytest.raises(ValueError, match="模板名称不能使用保留名称"):
            self.validator.validate_template_name("CON")

    def test_validate_json_sample_data(self):
        """JSON sample data made of dicts is accepted."""
        data_struct = self._struct(DataField(name="测试", field="test"))
        sample_data = [{"test": "value1"}, {"test": "value2"}]
        # Must not raise.
        self.validator.validate_sample_data(sample_data, data_struct, "json")

    def test_validate_invalid_json_sample_data(self):
        """JSON sample data items that are not dicts are rejected."""
        data_struct = self._struct(DataField(name="测试", field="test"))
        sample_data = ["not_a_dict"]  # not a dict
        with pytest.raises(ValueError, match="示例数据第.*项必须是对象类型"):
            self.validator.validate_sample_data(sample_data, data_struct, "json")

    def test_validate_array_sample_data(self):
        """Array sample data long enough for every index is accepted."""
        data_struct = self._struct(DataField(name="测试", index=0))
        sample_data = [["value1"], ["value2"]]
        # Must not raise.
        self.validator.validate_sample_data(sample_data, data_struct, "array")

    def test_validate_short_array_sample_data(self):
        """Array sample rows shorter than the highest field index are rejected."""
        # index=2 requires at least 3 elements per row.
        data_struct = self._struct(DataField(name="测试", index=2))
        sample_data = [["v1", "v2"]]  # only 2 elements
        with pytest.raises(ValueError, match="示例数据第.*项数组长度不足"):
            self.validator.validate_sample_data(sample_data, data_struct, "array")

    def test_validate_absolute_output_path(self):
        """An absolute output path is accepted."""
        # The context manager removes the directory even if validation raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            absolute_path = os.path.join(temp_dir, "test.xlsx")
            # Must not raise.
            self.validator.validate_output_path(absolute_path)

    def test_validate_path_traversal_attack(self):
        """Paths containing ".." segments are rejected."""
        unsafe_path = "../../../etc/passwd"
        with pytest.raises(ValueError, match="输出路径不能包含"):
            self.validator.validate_output_path(unsafe_path)
class TestTemplateGenerator:
    """Tests for the TemplateGenerator service.

    Every test gets its own temporary output directory, created in
    setup_method and removed in teardown_method.
    """

    def setup_method(self):
        # Isolated output directory per test.
        self.temp_dir = tempfile.mkdtemp()
        self.generator = TemplateGenerator(self.temp_dir)

    def teardown_method(self):
        # Remove everything written into the per-test directory.
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @staticmethod
    def _simple_struct():
        """Return the one-field JSON-format DataStruct shared by several tests."""
        return DataStruct(
            collectName="data",
            itemVariable="item",
            dataFields=[DataField(name="测试", field="test")],
        )

    @staticmethod
    def _assert_template_written(response):
        """Assert the generator reports success and the template file exists."""
        assert response.success is True
        assert response.templatePath is not None
        assert os.path.exists(response.templatePath)

    def test_generate_json_template(self):
        """Generate a JSON-format template and check the jx:each annotation."""
        request = GenerateTemplateRequest(
            templateName="test_json",
            dataStruct=DataStruct(
                collectName="employees",
                itemVariable="employee",
                dataFields=[
                    DataField(name="姓名", field="name"),
                    DataField(name="年龄", field="age"),
                ],
            ),
            dataFormat="json",
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)
        assert response.jxlsAnnotations is not None
        assert response.dataStruct is not None
        assert response.dataStruct.columnCount == 2
        assert "employees" in response.jxlsAnnotations.each
        assert "employee" in response.jxlsAnnotations.each

    def test_generate_array_template(self):
        """Generate an array-format template and check the jx:each annotation."""
        request = GenerateTemplateRequest(
            templateName="test_array",
            dataStruct=DataStruct(
                collectName="dataList",
                itemVariable="row",
                dataFields=[
                    DataField(name="列1", index=0),
                    DataField(name="列2", index=1),
                    DataField(name="列3", index=2),
                ],
            ),
            dataFormat="array",
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)
        assert response.jxlsAnnotations is not None
        assert response.dataStruct is not None
        assert response.dataStruct.columnCount == 3
        assert "dataList" in response.jxlsAnnotations.each
        assert "row" in response.jxlsAnnotations.each

    def test_generate_template_with_sample_data(self):
        """Generation succeeds when sample data is supplied."""
        request = GenerateTemplateRequest(
            templateName="test_with_sample",
            dataStruct=self._simple_struct(),
            dataFormat="json",
            sampleData=[{"test": "sample_value"}],
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)

    def test_generate_template_with_invalid_request(self):
        """An unsafe template name is rejected when the request model is built."""
        # The request model validates on construction, so the error surfaces
        # here rather than from generate_template().
        with pytest.raises(ValueError, match="模板名称不能包含字符"):
            GenerateTemplateRequest(
                templateName="test/invalid",  # contains an unsafe character
                dataStruct=self._simple_struct(),
                dataFormat="json",
            )

    def test_generate_template_with_missing_output_directory(self):
        """A generator whose output directory does not exist yet must not crash."""
        # The missing directory lives inside the per-test temp dir. A system
        # path like "/non/existent/path" could really be created (and never
        # cleaned up) when the generator succeeds in making it.
        missing_dir = os.path.join(self.temp_dir, "non", "existent", "path")
        non_existent_generator = TemplateGenerator(missing_dir)
        request = GenerateTemplateRequest(
            templateName="test_missing_dir",
            dataStruct=self._simple_struct(),
            dataFormat="json",
        )
        # The generator may create the directory (success) or report failure;
        # the point is that it returns a response instead of raising.
        response = non_existent_generator.generate_template(request)
        assert response.success in (True, False)

    def test_generate_template_with_custom_output_path(self):
        """A relative outputPath is accepted."""
        # NOTE(review): TemplateGenerator decides whether this relative path
        # resolves against its output directory or the process CWD; if it is
        # the CWD, this test writes outside self.temp_dir and the file is not
        # cleaned up. Confirm.
        custom_path = "subdir/custom_template.xlsx"
        request = GenerateTemplateRequest(
            templateName="test_custom_path",
            dataStruct=self._simple_struct(),
            dataFormat="json",
            outputPath=custom_path,
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)

    def test_generate_template_with_absolute_path(self):
        """An absolute outputPath outside the generator's directory is honored."""
        import shutil
        # Separate directory so the target is outside self.temp_dir.
        temp_dir = tempfile.mkdtemp()
        try:
            absolute_path = os.path.join(temp_dir, "absolute_template.xlsx")
            request = GenerateTemplateRequest(
                templateName="test_absolute",
                dataStruct=self._simple_struct(),
                dataFormat="json",
                outputPath=absolute_path,
            )
            response = self.generator.generate_template(request)
            self._assert_template_written(response)
            assert "absolute_template.xlsx" in response.templatePath
        finally:
            # Clean up even when an assertion above fails.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def test_generate_template_with_custom_filename(self):
        """An outputPath without extension still yields a file ending in .xlsx."""
        custom_filename = "my_custom_template"
        custom_path = os.path.join(self.temp_dir, custom_filename)
        request = GenerateTemplateRequest(
            templateName="test_custom_filename",
            dataStruct=self._simple_struct(),
            dataFormat="json",
            outputPath=custom_path,
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)
        assert custom_filename in response.templatePath
        assert response.templatePath.endswith('.xlsx')

    def test_generate_template_with_directory_path(self):
        """A directory-like outputPath gets a generated file name inside it."""
        # The test expects an extension-less path to be treated as a directory.
        directory_path = os.path.join(self.temp_dir, "test_dir")
        request = GenerateTemplateRequest(
            templateName="test_directory",
            dataStruct=self._simple_struct(),
            dataFormat="json",
            outputPath=directory_path,
        )
        response = self.generator.generate_template(request)
        self._assert_template_written(response)
        # The file name is expected to contain the templateName plus a timestamp.
        assert "test_directory_" in response.templatePath
        # ...and to be placed directly inside the requested directory.
        assert os.path.basename(os.path.dirname(response.templatePath)) == "test_dir"