from sqlalchemy import select, and_, or_, func, desc, asc
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
import structlog
from src.data.models import Contract, Trade, Settlement, ContractType
# Module-level structlog logger. It is not referenced by the repository
# classes defined below in this module section.
logger = structlog.get_logger()
class ContractRepository:
    """Repository for Contract operations.

    All methods are stateless ``@staticmethod`` helpers that operate on the
    caller-supplied ``AsyncSession``; none of them commit. Writes only
    ``flush`` so the caller controls the transaction boundary.
    """

    # Escape character used for LIKE/ILIKE patterns built from user input.
    # "/" avoids the dialect-specific quoting issues a backslash can cause.
    _LIKE_ESCAPE = "/"

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in *term* so it is matched literally."""
        esc = ContractRepository._LIKE_ESCAPE
        # The escape character itself must be escaped first, otherwise the
        # escapes added for '%' and '_' would be doubled.
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    async def create(session: AsyncSession, **kwargs) -> Contract:
        """Create a new contract and flush it so DB defaults (e.g. id) are populated.

        Args:
            session: Active async session; the caller commits.
            **kwargs: Column values forwarded to the ``Contract`` constructor.

        Returns:
            The newly added ``Contract`` instance.
        """
        contract = Contract(**kwargs)
        session.add(contract)
        await session.flush()
        return contract

    @staticmethod
    async def get_by_symbol(session: AsyncSession, symbol: str) -> Optional[Contract]:
        """Return the contract with exactly this symbol, or None.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        result = await session.execute(
            select(Contract).where(Contract.symbol == symbol)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_by_id(session: AsyncSession, contract_id: int) -> Optional[Contract]:
        """Return the contract with this primary key, or None."""
        result = await session.execute(
            select(Contract).where(Contract.id == contract_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_active_contracts(session: AsyncSession) -> List[Contract]:
        """Return all contracts flagged active and not yet settled."""
        result = await session.execute(
            select(Contract).where(
                and_(
                    Contract.is_active.is_(True),
                    Contract.is_settled.is_(False),
                )
            )
        )
        return list(result.scalars().all())

    @staticmethod
    async def search_contracts(
        session: AsyncSession,
        search_term: str,
        contract_type: Optional[ContractType] = None
    ) -> List[Contract]:
        """Case-insensitive substring search over symbol and description.

        ``search_term`` is matched literally: any ``%`` or ``_`` it contains
        is escaped rather than treated as a SQL wildcard.

        Args:
            session: Active async session.
            search_term: Text to look for anywhere in symbol or description.
            contract_type: If given, restrict results to this contract type.

        Returns:
            Matching contracts (unordered).
        """
        esc = ContractRepository._LIKE_ESCAPE
        pattern = f"%{ContractRepository._escape_like(search_term)}%"
        conditions = [
            or_(
                Contract.symbol.ilike(pattern, escape=esc),
                Contract.description.ilike(pattern, escape=esc),
            )
        ]
        if contract_type is not None:
            conditions.append(Contract.contract_type == contract_type)
        result = await session.execute(
            select(Contract).where(and_(*conditions))
        )
        return list(result.scalars().all())

    @staticmethod
    async def count_all(session: AsyncSession) -> int:
        """Return the total number of contracts."""
        result = await session.execute(
            select(func.count(Contract.id))
        )
        # COUNT(...) always yields exactly one row.
        return result.scalar_one()

    @staticmethod
    async def get_recent(session: AsyncSession, limit: int = 10) -> List[Contract]:
        """Return up to ``limit`` contracts, newest ``created_at`` first."""
        result = await session.execute(
            select(Contract)
            .order_by(desc(Contract.created_at))
            .limit(limit)
        )
        return list(result.scalars().all())
class TradeRepository:
    """Repository for Trade operations.

    Stateless ``@staticmethod`` helpers over a caller-supplied
    ``AsyncSession``. Writes only ``flush``; the caller commits.
    """

    @staticmethod
    async def create(session: AsyncSession, trade: Trade) -> Trade:
        """Add a single trade and flush it; returns the same instance."""
        session.add(trade)
        await session.flush()
        return trade

    @staticmethod
    async def create_bulk(session: AsyncSession, trades: List[Trade]) -> int:
        """Add many trades in one flush.

        Returns:
            The number of trades passed in.
        """
        session.add_all(trades)
        await session.flush()
        return len(trades)

    @staticmethod
    async def get_by_trade_id(session: AsyncSession, trade_id: str) -> Optional[Trade]:
        """Return the trade with this ``trade_id``, or None.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        result = await session.execute(
            select(Trade).where(Trade.trade_id == trade_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def count_all(session: AsyncSession) -> int:
        """Return the total number of trades."""
        result = await session.execute(
            select(func.count(Trade.id))
        )
        # COUNT(...) always yields exactly one row.
        return result.scalar_one()

    @staticmethod
    async def get_recent(session: AsyncSession, limit: int = 10) -> List[Trade]:
        """Return up to ``limit`` trades across all contracts, newest first."""
        result = await session.execute(
            select(Trade)
            .order_by(desc(Trade.timestamp))
            .limit(limit)
        )
        return list(result.scalars().all())

    @staticmethod
    async def get_by_contract_and_timerange(
        session: AsyncSession,
        contract_id: int,
        start_time: datetime,
        end_time: datetime
    ) -> List[Trade]:
        """Return a contract's trades with ``start_time <= timestamp <= end_time``.

        Results are in ascending timestamp order. Trades sharing a timestamp
        are ordered by ``id`` so the ordering is deterministic.
        """
        result = await session.execute(
            select(Trade)
            .where(
                and_(
                    Trade.contract_id == contract_id,
                    Trade.timestamp >= start_time,
                    Trade.timestamp <= end_time,
                )
            )
            .order_by(Trade.timestamp, Trade.id)
        )
        return list(result.scalars().all())

    @staticmethod
    async def get_price_at_timestamp(
        session: AsyncSession,
        contract_id: int,
        timestamp: datetime,
        tolerance_minutes: int = 5
    ) -> Optional[Trade]:
        """Return the trade closest to ``timestamp`` within ± ``tolerance_minutes``.

        The lookup uses two bounded, index-friendly queries: the latest trade
        at or before ``timestamp``, and the earliest trade after it. The
        closer of the two is returned. This avoids dialect-specific
        epoch-extraction SQL. An exact-timestamp trade is always preferred.
        When both neighbours are equally distant, the earlier one wins.

        Args:
            session: Active async session.
            contract_id: Contract whose trades are searched.
            timestamp: Target instant; must be comparable/subtractable with
                the stored ``Trade.timestamp`` values (same tz-awareness).
            tolerance_minutes: Half-width of the search window. Negative
                values are treated as 0 (exact match only).

        Returns:
            The nearest trade, or None if none lies inside the window.
        """
        window = timedelta(minutes=max(tolerance_minutes, 0))

        # Nearest trade at or before the target (an exact match lands here).
        before_result = await session.execute(
            select(Trade)
            .where(
                and_(
                    Trade.contract_id == contract_id,
                    Trade.timestamp <= timestamp,
                    Trade.timestamp >= timestamp - window,
                )
            )
            .order_by(desc(Trade.timestamp))
            .limit(1)
        )
        before = before_result.scalar_one_or_none()
        if before is not None and before.timestamp == timestamp:
            return before

        # Nearest trade strictly after the target.
        after_result = await session.execute(
            select(Trade)
            .where(
                and_(
                    Trade.contract_id == contract_id,
                    Trade.timestamp > timestamp,
                    Trade.timestamp <= timestamp + window,
                )
            )
            .order_by(asc(Trade.timestamp))
            .limit(1)
        )
        after = after_result.scalar_one_or_none()

        if before is None:
            return after
        if after is None:
            return before
        if (after.timestamp - timestamp) < (timestamp - before.timestamp):
            return after
        return before

    @staticmethod
    async def get_latest_trade(
        session: AsyncSession,
        contract_id: int
    ) -> Optional[Trade]:
        """Return the contract's most recent trade, or None if it has none.

        Ties on timestamp are broken by the highest ``id``.
        """
        result = await session.execute(
            select(Trade)
            .where(Trade.contract_id == contract_id)
            .order_by(desc(Trade.timestamp), desc(Trade.id))
            .limit(1)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_ohlc(
        session: AsyncSession,
        contract_id: int,
        start_time: datetime,
        end_time: datetime
    ) -> Dict[str, Any]:
        """Compute OHLC, volume and trade count over an inclusive time range.

        Returns:
            ``{}`` when the range has no trades. Otherwise a dict with keys
            ``open``/``close`` (first/last trade price in timestamp order),
            ``high``/``low`` (max/min price), ``volume`` (sum of
            ``Trade.volume``) and ``trades_count`` (an int).
        """
        trades = await TradeRepository.get_by_contract_and_timerange(
            session, contract_id, start_time, end_time
        )
        if not trades:
            return {}
        prices = [t.price for t in trades]
        return {
            'open': prices[0],
            'high': max(prices),
            'low': min(prices),
            'close': prices[-1],
            'volume': sum(t.volume for t in trades),
            'trades_count': len(trades),
        }
class SettlementRepository:
    """Repository for Settlement operations.

    Stateless ``@staticmethod`` helpers over a caller-supplied
    ``AsyncSession``. Writes only ``flush``; the caller commits.
    """

    @staticmethod
    async def create(session: AsyncSession, **kwargs) -> Settlement:
        """Create a settlement from column kwargs and flush it.

        Returns:
            The newly added ``Settlement`` instance.
        """
        settlement = Settlement(**kwargs)
        session.add(settlement)
        await session.flush()
        return settlement

    @staticmethod
    async def get_by_contract(
        session: AsyncSession,
        contract_id: int
    ) -> Optional[Settlement]:
        """Return the settlement for ``contract_id``, or None.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one settlement
                references the contract.
        """
        result = await session.execute(
            select(Settlement).where(Settlement.contract_id == contract_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def update_verification(
        session: AsyncSession,
        settlement_id: int,
        verified: bool,
        source: str,
        notes: Optional[str] = None
    ) -> Settlement:
        """Set verification status/source on a settlement and flush.

        Args:
            session: Active async session; the caller commits.
            settlement_id: Primary key of the settlement to update.
            verified: New verification flag.
            source: New verification source.
            notes: Replacement notes. ``None`` (the default) leaves the
                existing notes untouched instead of erasing them.

        Returns:
            The updated ``Settlement``.

        Raises:
            sqlalchemy.exc.NoResultFound: if no settlement has this id.
        """
        result = await session.execute(
            select(Settlement).where(Settlement.id == settlement_id)
        )
        settlement = result.scalar_one()
        settlement.verified = verified
        settlement.verification_source = source
        # Only overwrite notes when the caller actually supplied some.
        if notes is not None:
            settlement.notes = notes
        await session.flush()
        return settlement