test_stock_api.pyโข19.2 kB
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from app.api.v1.stock.models_stock import StockDataCreate, VectorSearchQuery
from tests.conftest import create_test_stock_data
class TestStockAPI:
"""Test suite for Stock API endpoints"""
def test_get_stock_data_empty(self, client: TestClient):
"""Test getting stock data when database is empty"""
response = client.get("/v1/stock")
assert response.status_code == 200
assert response.json() == []
def test_create_stock_data(self, client: TestClient, session):
"""Test creating new stock data"""
stock_data = {
"name": "Apple Analysis",
"description": "Comprehensive analysis of Apple Inc.",
"symbol": "AAPL",
"data_type": "analysis",
"content": "Apple shows strong growth potential with innovative products.",
"extra_metadata": {"analyst": "test", "confidence": 0.85},
}
response = client.post("/v1/stock", json=stock_data)
assert response.status_code == 201
data = response.json()
assert data["symbol"] == "AAPL"
assert data["data_type"] == "analysis"
assert data["content"] == stock_data["content"]
# Check that embedding_id is generated (don't check exact value)
assert data["embedding_id"] is not None
assert data["embedding_id"].startswith("embedding:")
def test_get_stock_data_with_filters(
self, client: TestClient, session, sample_stock_data
):
"""Test getting stock data with filters"""
# Create test data
for stock_data in sample_stock_data:
create_test_stock_data(session, stock_data)
# Test symbol filter
response = client.get("/v1/stock?symbol=AAPL")
assert response.status_code == 200
data = response.json()
assert all(item["symbol"] == "AAPL" for item in data)
# Test data type filter
response = client.get("/v1/stock?data_type=analysis")
assert response.status_code == 200
data = response.json()
assert all(item["data_type"] == "analysis" for item in data)
# Test pagination
response = client.get("/v1/stock?limit=1&offset=0")
assert response.status_code == 200
data = response.json()
assert len(data) <= 1
def test_get_stock_data_by_id(self, client: TestClient, session, sample_stock_data):
"""Test getting specific stock data by ID"""
# Create test data
db_stock_data = create_test_stock_data(session, sample_stock_data[0])
response = client.get(f"/v1/stock/{db_stock_data.id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == db_stock_data.id
assert data["symbol"] == db_stock_data.symbol
assert data["content"] == db_stock_data.content
def test_get_stock_data_by_id_not_found(self, client: TestClient):
"""Test getting non-existent stock data"""
response = client.get("/v1/stock/99999")
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_update_stock_data(self, client: TestClient, session, sample_stock_data):
"""Test updating stock data"""
# Create test data
db_stock_data = create_test_stock_data(session, sample_stock_data[0])
update_data = {
"name": "Updated Apple Analysis",
"content": "Updated analysis content with new insights.",
}
response = client.patch(f"/v1/stock/{db_stock_data.id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == update_data["name"]
assert data["content"] == update_data["content"]
# Check that embedding_id is updated (don't check exact value)
assert data["embedding_id"] is not None
assert data["embedding_id"].startswith("embedding:")
def test_update_stock_data_not_found(self, client: TestClient):
"""Test updating non-existent stock data"""
update_data = {"name": "Updated Name"}
response = client.patch("/v1/stock/99999", json=update_data)
assert response.status_code == 404
def test_delete_stock_data(self, client: TestClient, session, sample_stock_data):
"""Test deleting stock data"""
# Create test data
db_stock_data = create_test_stock_data(session, sample_stock_data[0])
response = client.delete(f"/v1/stock/{db_stock_data.id}")
assert response.status_code == 204
# Verify deletion
response = client.get(f"/v1/stock/{db_stock_data.id}")
assert response.status_code == 404
def test_delete_stock_data_not_found(self, client: TestClient):
"""Test deleting non-existent stock data"""
response = client.delete("/v1/stock/99999")
assert response.status_code == 404
def test_vector_search(self, client: TestClient, session, sample_stock_data):
"""Test vector similarity search"""
# Create test data with embeddings
for stock_data in sample_stock_data:
create_test_stock_data(session, stock_data)
search_query = {
"query": "Apple growth potential",
"symbols": ["AAPL"],
"limit": 5,
"similarity_threshold": 0.1, # Lower threshold to get results
"include_news": True,
"include_analysis": True,
}
response = client.post("/v1/stock/search", json=search_query)
assert response.status_code == 200
data = response.json()
# The search might return 0 results due to similarity threshold, which is acceptable
assert isinstance(data, list)
def test_get_stock_price(self, client: TestClient):
"""Test getting stock price"""
mock_price_data = {
"symbol": "AAPL",
"price": 150.25,
"change": 2.50,
"change_percent": 1.69,
"volume": 50000000,
"market_cap": 2500000000000,
"pe_ratio": 25.5,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_stock_price"
) as mock_get_price:
from app.api.v1.stock.models_stock import StockPrice
mock_get_price.return_value = StockPrice(**mock_price_data)
response = client.get("/v1/stock/price/AAPL")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AAPL"
assert data["price"] == 150.25
assert data["change"] == 2.50
def test_get_stock_price_not_found(self, client: TestClient):
"""Test getting price for invalid symbol"""
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_stock_price"
) as mock_get_price:
mock_get_price.return_value = None
response = client.get("/v1/stock/price/INVALID")
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_get_stock_news(self, client: TestClient):
"""Test getting stock news"""
with patch(
"app.api.v1.stock.services.market_service.market_service.get_stock_news"
) as mock_get_news:
from app.api.v1.stock.models_stock import StockNews
mock_news_data = [
StockNews(
title="Apple Reports Strong Earnings",
summary="Apple Inc. announced record quarterly earnings.",
url="https://example.com/apple-news",
source="TechNews",
published_at=datetime.now(timezone.utc),
symbols=["AAPL"],
sentiment_score=0.8,
relevance_score=0.9,
)
]
mock_get_news.return_value = mock_news_data
response = client.get("/v1/stock/news?limit=20")
assert response.status_code == 200
data = response.json()
assert len(data) > 0
assert data[0]["title"] == "Apple Reports Strong Earnings"
assert data[0]["sentiment_score"] == 0.8
def test_get_stock_news_with_symbol_filter(self, client: TestClient):
"""Test getting news with symbol filter"""
with patch(
"app.api.v1.stock.services.market_service.market_service.get_stock_news"
) as mock_get_news:
mock_get_news.return_value = []
response = client.get("/v1/stock/news?symbol=AAPL&limit=10")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
# Verify that the service was called with correct parameters
mock_get_news.assert_called_once_with("AAPL", 10)
def test_analyze_stock(self, client: TestClient, session):
"""Test stock analysis endpoint"""
mock_analysis_data = {
"symbol": "AAPL",
"analysis_type": "comprehensive",
"insights": "Strong growth potential with positive market sentiment.",
"confidence_score": 0.85,
"recommendation": "BUY",
"target_price": 160.0,
"risk_level": "MEDIUM",
}
with patch(
"app.api.v1.stock.controllers_stock.StockController.analyze_stock"
) as mock_analyze:
from app.api.v1.stock.models_stock import StockAnalysis
mock_analyze.return_value = StockAnalysis(**mock_analysis_data)
response = client.get("/v1/stock/analysis/AAPL")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AAPL"
assert data["recommendation"] == "BUY"
assert data["confidence_score"] == 0.85
def test_get_market_summary(self, client: TestClient, session):
"""Test market summary endpoint"""
mock_summary_data = {
"date": datetime.now(timezone.utc).isoformat(),
"total_symbols": 8,
"top_gainers": [],
"top_losers": [],
"market_sentiment": 0.2,
"news_count": 25,
"trending_topics": ["technology", "earnings", "growth"],
}
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_market_summary"
) as mock_summary:
from app.api.v1.stock.models_stock import MarketSummary
mock_summary.return_value = MarketSummary(**mock_summary_data)
response = client.get("/v1/stock/market/summary")
assert response.status_code == 200
data = response.json()
assert data["total_symbols"] == 8
assert data["market_sentiment"] == 0.2
assert "technology" in data["trending_topics"]
def test_get_trending_symbols(self, client: TestClient, session):
"""Test trending symbols endpoint"""
mock_trending_data = [
{
"symbol": "AAPL",
"activity_count": 15,
"avg_sentiment": 0.6,
"latest_activity": datetime.now(timezone.utc).isoformat(),
},
{
"symbol": "TSLA",
"activity_count": 12,
"avg_sentiment": 0.3,
"latest_activity": datetime.now(timezone.utc).isoformat(),
},
]
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_trending_symbols"
) as mock_trending:
mock_trending.return_value = mock_trending_data
response = client.get("/v1/stock/market/trending")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["symbol"] == "AAPL"
assert data[0]["activity_count"] == 15
def test_ingest_news_data_background(self, client: TestClient, session):
"""Test news ingestion background task"""
response = client.post("/v1/stock/ingest/news")
assert response.status_code == 200
data = response.json()
assert data["status"] == "processing"
assert "background" in data["message"].lower()
def test_ingest_news_data_sync(self, client: TestClient, session):
"""Test synchronous news ingestion"""
with patch(
"app.api.v1.stock.services.market_service.market_service.get_stock_news"
) as mock_get_news, patch(
"app.api.v1.stock.services.vector_service.vector_service.get_embedding"
) as mock_get_embedding:
# Mock news data
from app.api.v1.stock.models_stock import StockNews
mock_news = [
StockNews(
title="Test News",
summary="Test summary",
url="https://example.com",
source="TestSource",
published_at=datetime.now(timezone.utc),
symbols=["AAPL"],
sentiment_score=0.5,
relevance_score=0.7,
)
]
mock_get_news.return_value = mock_news
mock_get_embedding.return_value = ("test_embedding", [0.1, 0.2, 0.3])
response = client.post("/v1/stock/ingest/news/sync")
assert response.status_code == 200
data = response.json()
assert data["status"] == "completed"
assert "ingested_count" in data
def test_get_cache_stats(self, client: TestClient, session):
"""Test cache statistics endpoint"""
mock_stats_data = {
"cache_hits": 150,
"cache_misses": 25,
"cache_size": 500,
"hit_rate": 85.7,
"last_updated": datetime.now(timezone.utc).isoformat(),
}
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_cache_stats"
) as mock_stats:
from app.api.v1.stock.models_stock import CacheStats
mock_stats.return_value = CacheStats(**mock_stats_data)
response = client.get("/v1/stock/cache/stats")
assert response.status_code == 200
data = response.json()
assert data["cache_hits"] == 150
assert data["hit_rate"] == 85.7
def test_cleanup_cache(self, client: TestClient, session):
"""Test cache cleanup endpoint"""
response = client.post("/v1/stock/cache/cleanup")
assert response.status_code == 200
data = response.json()
assert data["status"] == "processing"
assert "cleanup" in data["message"].lower()
def test_stock_health_check(self, client: TestClient, session):
"""Test stock service health check"""
with patch(
"app.api.v1.stock.services.vector_service.vector_service.get_cache_stats"
) as mock_stats:
mock_stats.return_value = {
"search_cache_entries": 10,
"search_cache_hits": 100,
"embedding_cache_entries": 5,
"embedding_total_usage": 50,
"memory_cache_size": 3,
"cache_ttl_hours": 24,
"model_name": "test_model",
}
response = client.get("/v1/stock/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "cache_stats" in data
def test_validation_errors(self, client: TestClient):
"""Test API validation errors"""
# Test invalid stock data creation - missing required fields
invalid_stock_data = {
"symbol": "AAPL",
"data_type": "analysis",
# Missing required 'content' field
}
response = client.post("/v1/stock", json=invalid_stock_data)
assert response.status_code == 422 # Validation error
# Test invalid limit parameter
response = client.get("/v1/stock?limit=150") # Exceeds maximum
assert response.status_code == 422
def test_pagination_limits(self, client: TestClient):
"""Test pagination parameter limits"""
# Test maximum limit
response = client.get("/v1/stock?limit=150")
assert response.status_code == 422 # Should exceed maximum
# Test negative offset
response = client.get("/v1/stock?offset=-1")
assert response.status_code == 422
# Test valid pagination
response = client.get("/v1/stock?limit=50&offset=0")
assert response.status_code == 200
class TestStockAPIIntegration:
"""Integration tests for Stock API"""
def test_create_and_search_workflow(self, client: TestClient, session):
"""Test complete workflow: create data, then search for it"""
# Create stock data
stock_data = {
"name": "Apple Growth Analysis",
"description": "Analysis of Apple's growth potential",
"symbol": "AAPL",
"data_type": "analysis",
"content": "Apple Inc. demonstrates exceptional growth potential in the smartphone and services markets.",
"extra_metadata": {"analyst": "integration_test"},
}
# Create the data
create_response = client.post("/v1/stock", json=stock_data)
assert create_response.status_code == 201
created_data = create_response.json()
# Search for the data
search_query = {
"query": "Apple growth potential",
"symbols": ["AAPL"],
"limit": 10,
"similarity_threshold": 0.1, # Lower threshold
"include_analysis": True,
}
search_response = client.post("/v1/stock/search", json=search_query)
assert search_response.status_code == 200
search_results = search_response.json()
# The search might return 0 results due to similarity threshold, which is acceptable
assert isinstance(search_results, list)
def test_error_handling_chain(self, client: TestClient):
"""Test error handling across multiple endpoints"""
# Test 404 chain
response = client.get("/v1/stock/99999")
assert response.status_code == 404
response = client.patch("/v1/stock/99999", json={"name": "Updated"})
assert response.status_code == 404
response = client.delete("/v1/stock/99999")
assert response.status_code == 404
# Test service errors
with patch(
"app.api.v1.stock.controllers_stock.StockController.get_stock_price"
) as mock_price:
mock_price.side_effect = Exception("Service unavailable")
response = client.get("/v1/stock/price/AAPL")
assert response.status_code == 500