#!/usr/bin/env python3
"""Test database operations."""
import sys
from pathlib import Path
# Add src to path so we can import our modules
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from database.db import init_db, insert_cve, get_cve, search_cves, get_stats
def test_database():
    """Exercise init/insert/get/search/stats against a throwaway SQLite file.

    Creates ``test_cve.db`` next to this script and inserts three sample
    CVEs. It then checks that one can be retrieved by ID, that a keyword
    search returns the matching rows, and that the statistics count every
    inserted row. The connection is always closed and the scratch database
    always removed, even when a check fails or a database call raises.

    Returns:
        True if every check passed, False otherwise. The ``__main__`` block
        maps this to the process exit status.

    NOTE(review): if this module is collected by pytest, a ``False`` return
    does not fail the test (pytest ignores/warns on return values). Consider
    also asserting if pytest is the intended runner.
    """
    print("Testing database operations...\n")

    # Use a dedicated scratch file so the real database is never touched.
    test_db_path = Path(__file__).parent / "test_cve.db"

    # Start from a clean slate in case a previous run was interrupted.
    if test_db_path.exists():
        test_db_path.unlink()

    test_cves = [
        {
            "cve_id": "CVE-2024-0001",
            "description": "A critical vulnerability in Apache HTTP Server allows remote code execution.",
            "severity": "CRITICAL",
            "cvss_score": 9.8,
            "published_date": "2024-01-15",
            "modified_date": "2024-01-20",
            "references_json": '["https://example.com/advisory"]',
        },
        {
            "cve_id": "CVE-2024-0002",
            "description": "SQL injection vulnerability in WordPress plugin.",
            "severity": "HIGH",
            "cvss_score": 8.5,
            "published_date": "2024-02-10",
            "modified_date": "2024-02-15",
            "references_json": '["https://example.com/wp-vuln"]',
        },
        {
            "cve_id": "CVE-2024-0003",
            "description": "Cross-site scripting (XSS) in Apache Tomcat.",
            "severity": "MEDIUM",
            "cvss_score": 6.1,
            "published_date": "2024-03-05",
            "modified_date": "2024-03-10",
            "references_json": '["https://example.com/tomcat"]',
        },
    ]

    conn = None
    try:
        # 1. Initialize database
        print("1. Initializing database...")
        conn = init_db(test_db_path)
        print(" ✓ Database initialized\n")

        # 2. Insert test CVEs
        print("2. Inserting test CVEs...")
        for cve in test_cves:
            insert_cve(conn, cve)
            print(f" ✓ Inserted {cve['cve_id']}")
        conn.commit()
        print()

        # 3. Retrieve a CVE
        print("3. Retrieving CVE-2024-0001...")
        result = get_cve(conn, "CVE-2024-0001")
        if not result or result['cve_id'] != "CVE-2024-0001":
            print(" ✗ Failed to retrieve CVE\n")
            return False
        print(f" ✓ Found: {result['cve_id']} - {result['severity']}\n")

        # 4. Search CVEs
        print("4. Searching for 'Apache'...")
        results = search_cves(conn, "Apache")
        # CVE-2024-0001 and CVE-2024-0003 both contain "Apache" in their
        # descriptions. Assuming search_cves matches on description text,
        # both must be in the results.
        found_ids = {r['cve_id'] for r in results}
        expected_ids = {"CVE-2024-0001", "CVE-2024-0003"}
        if not expected_ids <= found_ids:
            missing = ", ".join(sorted(expected_ids - found_ids))
            print(f" ✗ Search missed expected CVEs: {missing}\n")
            return False
        print(f" ✓ Found {len(results)} results")
        for r in results:
            print(f" - {r['cve_id']}: {r['description'][:50]}...")
        print()

        # 5. Get statistics
        print("5. Getting statistics...")
        stats = get_stats(conn)
        # The database was created fresh, so the count must equal the number
        # of distinct rows inserted above.
        if stats['total_cves'] != len(test_cves):
            print(f" ✗ Expected {len(test_cves)} CVEs, got {stats['total_cves']}\n")
            return False
        print(f" ✓ Total CVEs: {stats['total_cves']}")
        print(f" ✓ Date range: {stats['oldest_cve']} to {stats['newest_cve']}\n")
    finally:
        # Runs on success, on an early `return False`, and on exceptions.
        # A failed run therefore never leaks the connection or leaves a
        # stale test_cve.db behind. The connection is closed before the
        # file is unlinked.
        if conn is not None:
            conn.close()
        if test_db_path.exists():
            test_db_path.unlink()

    print("✅ All database tests passed!")
    return True
if __name__ == "__main__":
    # Run the checks directly and report the outcome as a shell exit status:
    # 0 when every check passed, 1 otherwise.
    passed = test_database()
    exit_code = 0 if passed else 1
    sys.exit(exit_code)