test_mcp_tools.py•24.6 kB
"""Unit tests for mcp_tools module."""
import base64
import hashlib
import tempfile
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from fastapi import FastAPI
from yaraflux_mcp_server.mcp_tools import base as base_module
from yaraflux_mcp_server.mcp_tools.file_tools import (
delete_file,
download_file,
extract_strings,
get_file_info,
get_hex_view,
list_files,
upload_file,
)
from yaraflux_mcp_server.mcp_tools.rule_tools import (
add_yara_rule,
delete_yara_rule,
get_yara_rule,
import_threatflux_rules,
list_yara_rules,
update_yara_rule,
validate_yara_rule,
)
from yaraflux_mcp_server.mcp_tools.scan_tools import get_scan_result, scan_data, scan_url
from yaraflux_mcp_server.mcp_tools.storage_tools import clean_storage, get_storage_info
from yaraflux_mcp_server.storage import get_storage_client
from yaraflux_mcp_server.yara_service import YaraError
class TestMcpTools:
"""Tests for the mcp_tools module functionality."""
def test_tool_decorator(self):
"""Test that the tool decorator works correctly."""
# Create a function and apply the decorator
@base_module.register_tool()
def test_function():
return "test"
# Verify the function is registered as an MCP tool
assert test_function.__name__ in base_module.ToolRegistry._tools
# Verify the function works as expected
assert test_function() == "test"
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_list_yara_rules_success(self, mock_yara_service):
"""Test list_yara_rules function with successful result."""
# Set up mock return values
mock_rule = MagicMock()
mock_rule.dict.return_value = {"name": "test_rule", "source": "custom"}
mock_rule.model_dump.return_value = {"name": "test_rule", "source": "custom"}
mock_yara_service.list_rules.return_value = [mock_rule]
# Call the function
result = list_yara_rules()
# Verify the result
assert len(result) == 1
assert result[0]["name"] == "test_rule"
assert result[0]["source"] == "custom"
# Verify the mock was called correctly
mock_yara_service.list_rules.assert_called_once_with(None)
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_list_yara_rules_with_source(self, mock_yara_service):
"""Test list_yara_rules function with source filter."""
# Set up mock return values
mock_rule = MagicMock()
mock_rule.dict.return_value = {"name": "test_rule", "source": "custom"}
mock_rule.model_dump.return_value = {"name": "test_rule", "source": "custom"}
mock_yara_service.list_rules.return_value = [mock_rule]
# Call the function with source
result = list_yara_rules(source="custom")
# Verify the result
assert len(result) == 1
assert result[0]["name"] == "test_rule"
assert result[0]["source"] == "custom"
# Verify the mock was called correctly
mock_yara_service.list_rules.assert_called_once_with("custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_list_yara_rules_error(self, mock_yara_service):
"""Test list_yara_rules function with error."""
# Set up mock to raise an exception
mock_yara_service.list_rules.side_effect = YaraError("Test error")
# Call the function
result = list_yara_rules()
# Verify the result is an empty list
assert result == []
# Verify the mock was called correctly
mock_yara_service.list_rules.assert_called_once_with(None)
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_success(self, mock_yara_service):
"""Test get_yara_rule function with successful result."""
# Set up mock return values
mock_rule = MagicMock()
mock_rule.name = "test_rule"
mock_rule.dict.return_value = {"name": "test_rule", "source": "custom"}
mock_rule.model_dump.return_value = {"name": "test_rule", "source": "custom"}
mock_yara_service.get_rule.return_value = "rule test_rule { condition: true }"
mock_yara_service.list_rules.return_value = [mock_rule]
# Call the function
result = get_yara_rule("test_rule")
# Verify the result
assert result["success"] is True
assert result["result"]["name"] == "test_rule"
assert result["result"]["source"] == "custom"
assert result["result"]["content"] == "rule test_rule { condition: true }"
assert "metadata" in result["result"]
assert result["result"]["metadata"]["name"] == "test_rule"
# Verify the mocks were called correctly
mock_yara_service.get_rule.assert_called_once_with("test_rule", "custom")
mock_yara_service.list_rules.assert_called_once_with("custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_not_found(self, mock_yara_service):
"""Test get_yara_rule function with rule not found in metadata."""
# Set up mock return values
mock_rule = MagicMock()
mock_rule.name = "other_rule" # Different name than what we're looking for
mock_rule.dict.return_value = {"name": "other_rule", "source": "custom"}
mock_rule.model_dump.return_value = {"name": "other_rule", "source": "custom"}
mock_yara_service.get_rule.return_value = "rule test_rule { condition: true }"
mock_yara_service.list_rules.return_value = [mock_rule]
# Call the function
result = get_yara_rule("test_rule")
# Verify the result
assert result["success"] is True
assert result["result"]["name"] == "test_rule"
assert result["result"]["source"] == "custom"
assert result["result"]["content"] == "rule test_rule { condition: true }"
assert "metadata" in result["result"]
assert result["result"]["metadata"] == {} # Empty metadata because rule wasn't found in list
# Verify the mocks were called correctly
mock_yara_service.get_rule.assert_called_once_with("test_rule", "custom")
mock_yara_service.list_rules.assert_called_once_with("custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_error(self, mock_yara_service):
"""Test get_yara_rule function with error."""
# Set up mock to raise an exception
mock_yara_service.get_rule.side_effect = YaraError("Test error")
# Call the function
result = get_yara_rule("test_rule")
# Verify the result
assert result["success"] is False
assert "Test error" in result["message"]
# Verify the mock was called correctly
mock_yara_service.get_rule.assert_called_once_with("test_rule", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_validate_yara_rule_valid(self, mock_yara_service):
"""Test validate_yara_rule function with valid rule."""
# Call the function
result = validate_yara_rule("rule test { condition: true }")
# Verify the result
assert result["valid"] is True
assert result["message"] == "Rule is valid"
# Get the temp rule name that was generated - can't test exact name as it uses timestamp
mock_calls = mock_yara_service.add_rule.call_args_list
assert len(mock_calls) > 0
assert mock_yara_service.delete_rule.called
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_validate_yara_rule_invalid(self, mock_yara_service):
"""Test validate_yara_rule function with invalid rule."""
# Set up mock to raise an exception
mock_yara_service.add_rule.side_effect = YaraError("Invalid syntax")
# Call the function
result = validate_yara_rule("rule test { invalid }")
# Verify the result
assert result["valid"] is False
assert "Invalid syntax" in result["message"]
# Verify the mock was called correctly
mock_yara_service.add_rule.assert_called_once()
# Delete should not be called if add fails
mock_yara_service.delete_rule.assert_not_called()
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_success(self, mock_yara_service):
"""Test add_yara_rule function with successful result."""
# Set up mock return values
mock_metadata = MagicMock()
mock_metadata.dict.return_value = {"name": "test_rule", "source": "custom"}
mock_metadata.model_dump.return_value = {"name": "test_rule", "source": "custom"}
mock_yara_service.add_rule.return_value = mock_metadata
# Call the function
result = add_yara_rule("test_rule", "rule test { condition: true }")
# Verify the result
assert result["success"] is True
assert "added successfully" in result["message"]
assert result["metadata"]["name"] == "test_rule"
assert result["metadata"]["source"] == "custom"
# Verify the mock was called correctly
mock_yara_service.add_rule.assert_called_once_with("test_rule.yar", "rule test { condition: true }", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_error(self, mock_yara_service):
"""Test add_yara_rule function with error."""
# Set up mock to raise an exception
mock_yara_service.add_rule.side_effect = YaraError("Test error")
# Call the function
result = add_yara_rule("test_rule", "rule test { invalid }")
# Verify the result
assert result["success"] is False
assert result["message"] == "Test error"
# Verify the mock was called correctly
# Check that add_rule was called - the exact name might have .yar appended
assert mock_yara_service.add_rule.called
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_success(self, mock_yara_service):
"""Test update_yara_rule function with successful result."""
# Set up mock return values
mock_metadata = MagicMock()
mock_metadata.dict.return_value = {"name": "test_rule", "source": "custom"}
mock_metadata.model_dump.return_value = {"name": "test_rule", "source": "custom"}
mock_yara_service.update_rule.return_value = mock_metadata
# Call the function
result = update_yara_rule("test_rule", "rule test { condition: true }")
# Verify the result
assert result["success"] is True
assert "Rule test_rule updated successfully" in result["message"]
assert result["metadata"]["name"] == "test_rule"
assert result["metadata"]["source"] == "custom"
# Verify the mock was called correctly
mock_yara_service.update_rule.assert_called_once_with("test_rule", "rule test { condition: true }", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_error(self, mock_yara_service):
"""Test update_yara_rule function with error."""
# Set up mock to raise an exception
mock_yara_service.update_rule.side_effect = YaraError("Test error")
# Call the function
result = update_yara_rule("test_rule", "rule test { invalid }")
# Verify the result
assert result["success"] is False
assert result["message"] == "Test error"
# Verify the mock was called correctly
mock_yara_service.update_rule.assert_called_once_with("test_rule", "rule test { invalid }", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_success(self, mock_yara_service):
"""Test delete_yara_rule function with successful result."""
# Set up mock return values
mock_yara_service.delete_rule.return_value = True
# Call the function
result = delete_yara_rule("test_rule")
# Verify the result
assert result["success"] is True
assert "Rule test_rule deleted successfully" in result["message"]
# Verify the mock was called correctly
mock_yara_service.delete_rule.assert_called_once_with("test_rule", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_not_found(self, mock_yara_service):
"""Test delete_yara_rule function with rule not found."""
# Set up mock return values
mock_yara_service.delete_rule.return_value = False
# Call the function
result = delete_yara_rule("test_rule")
# Verify the result
assert result["success"] is False
assert "Rule test_rule not found" in result["message"]
# Verify the mock was called correctly
mock_yara_service.delete_rule.assert_called_once_with("test_rule", "custom")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_error(self, mock_yara_service):
"""Test delete_yara_rule function with error."""
# Set up mock to raise an exception
mock_yara_service.delete_rule.side_effect = YaraError("Test error")
# Call the function
result = delete_yara_rule("test_rule")
# Verify the result
assert result["success"] is False
assert result["message"] == "Test error"
# Verify the mock was called correctly
mock_yara_service.delete_rule.assert_called_once_with("test_rule", "custom")
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.yara_service")
def test_scan_url_success(self, mock_yara_service):
"""Test scan_url function with successful result."""
# Set up mock return values
mock_result = MagicMock()
mock_result.scan_id = "test-id"
mock_result.file_name = "test.exe"
mock_result.file_size = 1024
mock_result.file_hash = "abc123"
mock_result.scan_time = 0.5
mock_result.timeout_reached = False
mock_match = MagicMock()
mock_match.dict.return_value = {"rule": "test_rule", "tags": ["test"]}
mock_match.model_dump.return_value = {"rule": "test_rule", "tags": ["test"]}
mock_result.matches = [mock_match]
mock_yara_service.fetch_and_scan.return_value = mock_result
# Call the function
result = scan_url("https://example.com/test.exe")
# Verify the result
assert result["success"] is True
assert result["scan_id"] == "test-id"
assert result["file_name"] == "test.exe"
assert result["file_size"] == 1024
assert result["file_hash"] == "abc123"
assert result["scan_time"] == 0.5
assert result["timeout_reached"] is False
assert len(result["matches"]) == 1
# Just check if matches exist, the format could be different
assert len(result["matches"]) > 0
assert result["match_count"] == 1
# Verify the mock was called correctly
mock_yara_service.fetch_and_scan.assert_called_once_with(
url="https://example.com/test.exe", rule_names=None, sources=None, timeout=None
)
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.yara_service")
def test_scan_url_with_params(self, mock_yara_service):
"""Test scan_url function with additional parameters."""
# Set up mock return values
mock_result = MagicMock()
mock_result.scan_id = "test-id"
mock_result.file_name = "test.exe"
mock_result.file_size = 1024
mock_result.file_hash = "abc123"
mock_result.scan_time = 0.5
mock_result.timeout_reached = False
mock_result.matches = []
mock_yara_service.fetch_and_scan.return_value = mock_result
# Call the function with parameters
result = scan_url("https://example.com/test.exe", rule_names=["rule1", "rule2"], sources=["custom"], timeout=10)
# Verify the result
assert result["success"] is True
assert result["scan_id"] == "test-id"
assert result["match_count"] == 0
# Verify the mock was called correctly with parameters
mock_yara_service.fetch_and_scan.assert_called_once_with(
url="https://example.com/test.exe", rule_names=["rule1", "rule2"], sources=["custom"], timeout=10
)
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.yara_service")
def test_scan_url_yara_error(self, mock_yara_service):
"""Test scan_url function with YaraError."""
# Set up mock to raise a YaraError
mock_yara_service.fetch_and_scan.side_effect = YaraError("Test error")
# Call the function
result = scan_url("https://example.com/test.exe")
# Verify the result
assert result["success"] is False
assert result["message"] == "Test error"
# Verify the mock was called correctly
mock_yara_service.fetch_and_scan.assert_called_once()
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.yara_service")
def test_scan_url_general_error(self, mock_yara_service):
"""Test scan_url function with general error."""
# Set up mock to raise a general exception
mock_yara_service.fetch_and_scan.side_effect = Exception("Test error")
# Call the function
result = scan_url("https://example.com/test.exe")
# Verify the result
assert result["success"] is False
assert "Unexpected error" in result["message"]
# Verify the mock was called correctly
mock_yara_service.fetch_and_scan.assert_called_once()
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.base64")
def test_scan_data_base64(self, mock_base64):
"""Test scan_data function with base64 encoding."""
# Set up mock return values
mock_base64.b64decode.return_value = b"test data"
# Call the function
result = scan_data("dGVzdCBkYXRh", "test.txt", encoding="base64")
if not result:
assert False
def test_scan_data_text(self):
"""Test scan_data function with text encoding."""
# Call the function
result = scan_data("test data", "test.txt", encoding="text")
if not result:
assert False
# The API now returns 1 for match_count in maintenance mode
def test_scan_data_invalid_encoding(self):
"""Test scan_data function with invalid encoding."""
# Call the function with invalid encoding
result = scan_data("test data", "test.txt", encoding="invalid")
# Verify the result
assert result["success"] is False
assert "Unsupported encoding" in result["message"]
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.base64")
def test_scan_data_base64_error(self, mock_base64):
"""Test scan_data function with base64 decoding error."""
# Set up mock to raise an exception
mock_base64.b64decode.side_effect = Exception("Invalid base64")
# Call the function
result = scan_data("invalid base64", "test.txt", encoding="base64")
# Verify the result
assert result["success"] is False
assert "Invalid base64 format" in result["message"]
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.get_storage_client")
def test_get_scan_result_success(self, mock_get_storage_client):
"""Test get_scan_result function with successful result."""
# Set up mock return values
mock_storage = MagicMock()
mock_get_storage_client.return_value = mock_storage
mock_storage.get_result.return_value = {"id": "test-id", "result": "success"}
# Call the function
result = get_scan_result("test-id")
# Verify the result
assert result["success"] is True
assert result["result"]["id"] == "test-id"
assert result["result"]["result"] == "success"
# Verify the mock was called correctly
mock_get_storage_client.assert_called_once()
mock_storage.get_result.assert_called_once_with("test-id")
@patch("yaraflux_mcp_server.mcp_tools.scan_tools.get_storage_client")
def test_get_scan_result_error(self, mock_get_storage_client):
"""Test get_scan_result function with error."""
# Set up mock to raise an exception
mock_storage = MagicMock()
mock_get_storage_client.return_value = mock_storage
mock_storage.get_result.side_effect = Exception("Test error")
# Call the function
result = get_scan_result("test-id")
# Verify the result
assert result["success"] is False
assert result["message"] == "Test error"
# Verify the mock was called correctly
mock_get_storage_client.assert_called_once()
mock_storage.get_result.assert_called_once_with("test-id")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.tempfile.TemporaryDirectory")
def test_import_threatflux_rules_github(self, mock_tempdir, mock_yara_service, mock_httpx):
"""Test import_threatflux_rules from GitHub."""
# Set up mocks
mock_tempdir.return_value.__enter__.return_value = "/tmp/test"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"rules": ["malware/test.yar"]}
mock_httpx.get.return_value = mock_response
# Set up rule response
mock_rule_response = MagicMock()
mock_rule_response.status_code = 200
mock_rule_response.text = "rule test { condition: true }"
mock_httpx.get.side_effect = [mock_response, mock_rule_response]
# Call the function
result = import_threatflux_rules()
# Verify the result
assert result["success"] is True
assert "Imported" in result["message"]
# Verify yara_service was called to load rules
mock_yara_service.load_rules.assert_called_once()
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
@patch("yaraflux_mcp_server.mcp_tools.file_tools.base64")
def test_upload_file_base64(self, mock_base64, mock_get_storage_client):
"""Test upload_file function with base64 encoding."""
# Set up mocks
mock_base64.b64decode.return_value = b"test data"
mock_storage = MagicMock()
mock_get_storage_client.return_value = mock_storage
mock_storage.save_file.return_value = {"file_id": "test-id", "file_name": "test.txt"}
# Call the function
result = upload_file("dGVzdCBkYXRh", "test.txt", encoding="base64")
# Verify the result
assert result["success"] is True
assert "uploaded successfully" in result["message"]
assert result["file_info"]["file_id"] == "test-id"
# Verify mocks were called correctly
mock_base64.b64decode.assert_called_once_with("dGVzdCBkYXRh")
mock_storage.save_file.assert_called_once_with("test.txt", b"test data", {})
def test_upload_file_text(self):
"""Test upload_file function with text encoding."""
# Set up mocks
mock_storage = MagicMock()
with patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client", return_value=mock_storage):
mock_storage.save_file.return_value = {"file_id": "test-id", "file_name": "test.txt"}
# Call the function
result = upload_file("test data", "test.txt", encoding="text")
# Verify the result
assert result["success"] is True
assert "uploaded successfully" in result["message"]
assert result["file_info"]["file_id"] == "test-id"
# Verify mock was called correctly
mock_storage.save_file.assert_called_once()
def test_upload_file_invalid_encoding(self):
"""Test upload_file function with invalid encoding."""
# Call the function with invalid encoding
result = upload_file("test data", "test.txt", encoding="invalid")
# Verify the result
assert result["success"] is False
assert "Unsupported encoding" in result["message"]